Skip to content

Commit

Permalink
[Feat] Custom MicroAgents. (#4983)
Browse files Browse the repository at this point in the history
Co-authored-by: diwu-sf <[email protected]>
  • Loading branch information
RajWorking and diwu-sf authored Dec 6, 2024
1 parent cf157c8 commit 2b06e4e
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 28 deletions.
9 changes: 8 additions & 1 deletion frontend/src/context/ws-client-provider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,22 @@ interface WsClientProviderProps {
enabled: boolean;
token: string | null;
ghToken: string | null;
selectedRepository: string | null;
settings: Settings | null;
}

export function WsClientProvider({
enabled,
token,
ghToken,
selectedRepository,
settings,
children,
}: React.PropsWithChildren<WsClientProviderProps>) {
const sioRef = React.useRef<Socket | null>(null);
const tokenRef = React.useRef<string | null>(token);
const ghTokenRef = React.useRef<string | null>(ghToken);
const selectedRepositoryRef = React.useRef<string | null>(selectedRepository);
const disconnectRef = React.useRef<ReturnType<typeof setTimeout> | null>(
null,
);
Expand Down Expand Up @@ -81,6 +84,9 @@ export function WsClientProvider({
if (ghToken) {
initEvent.github_token = ghToken;
}
if (selectedRepository) {
initEvent.selected_repository = selectedRepository;
}
const lastEvent = lastEventRef.current;
if (lastEvent) {
initEvent.latest_event_id = lastEvent.id;
Expand Down Expand Up @@ -158,6 +164,7 @@ export function WsClientProvider({
sioRef.current = sio;
tokenRef.current = token;
ghTokenRef.current = ghToken;
selectedRepositoryRef.current = selectedRepository;

return () => {
sio.off("connect", handleConnect);
Expand All @@ -166,7 +173,7 @@ export function WsClientProvider({
sio.off("connect_failed", handleError);
sio.off("disconnect", handleDisconnect);
};
}, [enabled, token, ghToken]);
}, [enabled, token, ghToken, selectedRepository]);

// Strict mode mounts and unmounts each component twice, so we have to wait in the destructor
// before actually disconnecting the socket and cancel the operation if the component gets remounted.
Expand Down
9 changes: 1 addition & 8 deletions frontend/src/routes/_oh.app/hooks/use-ws-status-change.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import {
WsClientProviderStatus,
} from "#/context/ws-client-provider";
import { createChatMessage } from "#/services/chat-service";
import { getCloneRepoCommand } from "#/services/terminal-service";
import { setCurrentAgentState } from "#/state/agent-slice";
import { addUserMessage } from "#/state/chat-slice";
import {
Expand Down Expand Up @@ -37,11 +36,6 @@ export const useWSStatusChange = () => {
send(createChatMessage(query, base64Files, timestamp));
};

const dispatchCloneRepoCommand = (ghToken: string, repository: string) => {
send(getCloneRepoCommand(ghToken, repository));
dispatch(clearSelectedRepository());
};

const dispatchInitialQuery = (query: string, additionalInfo: string) => {
if (additionalInfo) {
sendInitialQuery(`${query}\n\n[${additionalInfo}]`, files);
Expand All @@ -57,8 +51,7 @@ export const useWSStatusChange = () => {
let additionalInfo = "";

if (gitHubToken && selectedRepository) {
dispatchCloneRepoCommand(gitHubToken, selectedRepository);
additionalInfo = `Repository ${selectedRepository} has been cloned to /workspace. Please check the /workspace for files.`;
dispatch(clearSelectedRepository());
} else if (importedProjectZip) {
// if there's an uploaded project zip, add it to the chat
additionalInfo =
Expand Down
1 change: 1 addition & 0 deletions frontend/src/routes/_oh.app/route.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ function App() {
enabled
token={token}
ghToken={gitHubToken}
selectedRepository={selectedRepository}
settings={settings}
>
<EventHandler>
Expand Down
8 changes: 0 additions & 8 deletions frontend/src/services/terminal-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,3 @@ export function getGitHubTokenCommand(gitHubToken: string) {
const event = getTerminalCommand(command, true);
return event;
}

export function getCloneRepoCommand(gitHubToken: string, repository: string) {
const url = `https://${gitHubToken}@github.com/${repository}.git`;
const dirName = repository.split("/")[1];
const command = `git clone ${url} ${dirName} ; cd ${dirName} ; git checkout -b openhands-workspace`;
const event = getTerminalCommand(command, true);
return event;
}
3 changes: 3 additions & 0 deletions openhands/agenthub/codeact_agent/codeact_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,9 @@ def _get_messages(self, state: State) -> list[Message]:
- Messages from the same role are combined to prevent consecutive same-role messages
- For Anthropic models, specific messages are cached according to their documentation
"""
if not self.prompt_manager:
raise Exception('Prompt Manager not instantiated.')

messages: list[Message] = [
Message(
role='system',
Expand Down
2 changes: 2 additions & 0 deletions openhands/controller/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from openhands.llm.llm import LLM
from openhands.runtime.plugins import PluginRequirement
from openhands.utils.prompt import PromptManager


class Agent(ABC):
Expand All @@ -33,6 +34,7 @@ def __init__(
self.llm = llm
self.config = config
self._complete = False
self.prompt_manager: PromptManager | None = None

@property
def complete(self) -> bool:
Expand Down
41 changes: 41 additions & 0 deletions openhands/runtime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,47 @@ async def on_event(self, event: Event) -> None:
source = event.source if event.source else EventSource.AGENT
self.event_stream.add_event(observation, source) # type: ignore[arg-type]

def clone_repo(self, github_token: str | None, selected_repository: str | None):
if not github_token or not selected_repository:
return
url = f'https://{github_token}@github.com/{selected_repository}.git'
dir_name = selected_repository.split('/')[1]
action = CmdRunAction(
command=f'git clone {url} {dir_name} ; cd {dir_name} ; git checkout -b openhands-workspace'
)
self.log('info', 'Cloning repo: {selected_repository}')
self.run_action(action)

def get_custom_microagents(self, selected_repository: str | None) -> list[str]:
custom_microagents_content = []
custom_microagents_dir = Path('.openhands') / 'microagents'

dir_name = str(custom_microagents_dir)
if selected_repository:
dir_name = str(
Path(selected_repository.split('/')[1]) / custom_microagents_dir
)
oh_instructions_header = '---\nname: openhands_instructions\nagent: CodeActAgent\ntriggers:\n- ""\n---\n'
obs = self.read(FileReadAction(path='.openhands_instructions'))
if isinstance(obs, ErrorObservation):
self.log('error', 'Failed to read openhands_instructions')
else:
openhands_instructions = oh_instructions_header + obs.content
self.log('info', f'openhands_instructions: {openhands_instructions}')
custom_microagents_content.append(openhands_instructions)

files = self.list_files(dir_name)

self.log('info', f'Found {len(files)} custom microagents.')

for fname in files:
content = self.read(
FileReadAction(path=str(custom_microagents_dir / fname))
).content
custom_microagents_content.append(content)

return custom_microagents_content

def run_action(self, action: Action) -> Observation:
"""Run an action and return the resulting observation.
If the action is not runnable in any runtime, a NullObservation is returned.
Expand Down
2 changes: 2 additions & 0 deletions openhands/server/listen_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ async def oh_action(connection_id: str, data: dict):
latest_event_id = int(data.pop('latest_event_id', -1))
kwargs = {k.lower(): v for k, v in (data.get('args') or {}).items()}
session_init_data = SessionInitData(**kwargs)
session_init_data.github_token = github_token
session_init_data.selected_repository = data.get('selected_repository', None)
await init_connection(
connection_id, token, github_token, session_init_data, latest_event_id
)
Expand Down
19 changes: 18 additions & 1 deletion openhands/server/session/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from openhands.core.config import AgentConfig, AppConfig, LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action.agent import ChangeAgentStateAction
from openhands.events.action import ChangeAgentStateAction
from openhands.events.event import EventSource
from openhands.events.stream import EventStream
from openhands.runtime import get_runtime_cls
Expand Down Expand Up @@ -60,6 +60,8 @@ async def start(
max_budget_per_task: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None,
github_token: str | None = None,
selected_repository: str | None = None,
):
"""Starts the Agent session
Parameters:
Expand All @@ -86,6 +88,8 @@ async def start(
max_budget_per_task,
agent_to_llm_config,
agent_configs,
github_token,
selected_repository,
)

def _start_thread(self, *args):
Expand All @@ -104,13 +108,18 @@ async def _start(
max_budget_per_task: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None,
github_token: str | None = None,
selected_repository: str | None = None,
):
self._create_security_analyzer(config.security.security_analyzer)
await self._create_runtime(
runtime_name=runtime_name,
config=config,
agent=agent,
github_token=github_token,
selected_repository=selected_repository,
)

self._create_controller(
agent,
config.security.confirmation_mode,
Expand Down Expand Up @@ -165,6 +174,8 @@ async def _create_runtime(
runtime_name: str,
config: AppConfig,
agent: Agent,
github_token: str | None = None,
selected_repository: str | None = None,
):
"""Creates a runtime instance
Expand Down Expand Up @@ -199,6 +210,12 @@ async def _create_runtime(
return

if self.runtime is not None:
self.runtime.clone_repo(github_token, selected_repository)
if agent.prompt_manager:
agent.prompt_manager.load_microagent_files(
self.runtime.get_custom_microagents(selected_repository)
)

logger.debug(
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
)
Expand Down
3 changes: 2 additions & 1 deletion openhands/server/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ async def initialize_agent(self, session_init_data: SessionInitData):
self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer
max_iterations = session_init_data.max_iterations or self.config.max_iterations
# override default LLM config


default_llm_config = self.config.get_llm_config()
default_llm_config.model = session_init_data.llm_model or default_llm_config.model
Expand All @@ -94,6 +93,8 @@ async def initialize_agent(self, session_init_data: SessionInitData):
max_budget_per_task=self.config.max_budget_per_task,
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
agent_configs=self.config.get_agent_configs(),
github_token=session_init_data.github_token,
selected_repository=session_init_data.selected_repository,
)
except Exception as e:
logger.exception(f'Error creating controller: {e}')
Expand Down
2 changes: 2 additions & 0 deletions openhands/server/session/session_init_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ class SessionInitData:
llm_model: str | None = None
llm_api_key: str | None = None
llm_base_url: str | None = None
github_token: str | None = None
selected_repository: str | None = None
22 changes: 14 additions & 8 deletions openhands/utils/microagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,20 @@ class MicroAgentMetadata(pydantic.BaseModel):


class MicroAgent:
def __init__(self, path: str):
self.path = path
if not os.path.exists(path):
raise FileNotFoundError(f'Micro agent file {path} is not found')
with open(path, 'r') as file:
self._loaded = frontmatter.load(file)
self._content = self._loaded.content
self._metadata = MicroAgentMetadata(**self._loaded.metadata)
def __init__(self, path: str | None = None, content: str | None = None):
    """Build a micro agent from a front-matter file or from its raw text.

    Exactly one of the two arguments must be supplied (non-empty).

    Args:
        path: Location of a micro agent file on the local filesystem. The
            file is parsed with ``frontmatter.load``.
        content: Full text of a micro agent definition, parsed with
            ``frontmatter.parse``.

    Raises:
        FileNotFoundError: ``path`` was given but does not exist.
        ValueError: both arguments, or neither, were provided.
    """
    if path and not content:
        self.path: str | None = path
        if not os.path.exists(path):
            raise FileNotFoundError(f'Micro agent file {path} is not found')
        with open(path, 'r') as file:
            self._loaded = frontmatter.load(file)
        self._content = self._loaded.content
        self._metadata = MicroAgentMetadata(**self._loaded.metadata)
    elif content and not path:
        # No backing file; still define the attribute so that reading
        # ``.path`` on a content-built agent does not raise AttributeError.
        self.path = None
        raw_metadata, self._content = frontmatter.parse(content)
        self._metadata = MicroAgentMetadata(**raw_metadata)
    else:
        # ValueError is a subclass of Exception, so existing handlers that
        # catch Exception keep working.
        raise ValueError('You must pass either path or file content, but not both.')

def get_trigger(self, message: str) -> str | None:
message = message.lower()
Expand Down
7 changes: 6 additions & 1 deletion openhands/utils/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,18 @@ def __init__(
if f.endswith('.md')
]
for microagent_file in microagent_files:
microagent = MicroAgent(microagent_file)
microagent = MicroAgent(path=microagent_file)
if (
disabled_microagents is None
or microagent.name not in disabled_microagents
):
self.microagents[microagent.name] = microagent

def load_microagent_files(self, microagent_files: list[str]):
    """Register micro agents given as raw definition text.

    Each string is parsed into a ``MicroAgent`` and stored in
    ``self.microagents`` under the agent's name; a later entry with the same
    name replaces an earlier one.

    Args:
        microagent_files: Micro agent definitions as text (file contents,
            not file paths).
    """
    # The generator is consumed lazily, so each agent is stored before the
    # next one is parsed.
    parsed_agents = (MicroAgent(content=text) for text in microagent_files)
    for agent in parsed_agents:
        self.microagents[agent.name] = agent

def _load_template(self, template_name: str) -> Template:
if self.prompt_dir is None:
raise ValueError('Prompt directory is not set')
Expand Down

0 comments on commit 2b06e4e

Please sign in to comment.