| 1 | """A small generalization of the default agent that puts the user in the loop.
|
| 2 |
|
| 3 | There are three modes:
|
| 4 | - human: commands issued by the user are executed immediately
|
| 5 | - confirm: commands issued by the LM but not whitelisted are confirmed by the user
|
| 6 | - yolo: commands issued by the LM are executed immediately without confirmation
|
| 7 | """
|
| 8 |
|
| 9 | import re
|
| 10 | from typing import Literal, NoReturn
|
| 11 |
|
| 12 | from prompt_toolkit.formatted_text import HTML
|
| 13 | from prompt_toolkit.history import FileHistory
|
| 14 | from prompt_toolkit.shortcuts import PromptSession
|
| 15 | from rich.console import Console
|
| 16 | from rich.rule import Rule
|
| 17 |
|
| 18 | from minisweagent import global_config_dir
|
| 19 | from minisweagent.agents.default import AgentConfig, DefaultAgent
|
| 20 | from minisweagent.exceptions import LimitsExceeded, Submitted, UserInterruption
|
| 21 | from minisweagent.models.utils.content_string import get_content_string
|
| 22 |
|
# Shared rich console used for all agent output; automatic highlighting is off.
console = Console(highlight=False)

# The single-line and the multiline prompt session are backed by the same
# FileHistory object, so both input modes see one common history file that
# lives in the global config directory.
_history = FileHistory(global_config_dir / "interactive_history.txt")
_prompt_session, _multiline_prompt_session = (
    PromptSession(history=_history),
    PromptSession(history=_history, multiline=True),
)
|
| 27 |
|
| 28 |
|
class InteractiveAgentConfig(AgentConfig):
    """Configuration for ``InteractiveAgent``: ``AgentConfig`` plus user-in-the-loop options."""

    mode: Literal["human", "confirm", "yolo"] = "confirm"
    """Interaction mode: "human" (the user issues the commands), "confirm" (LM commands that are not whitelisted need user confirmation), or "yolo" (LM commands run without confirmation)."""
    whitelist_actions: list[str] = []
    """Never confirm actions that match these regular expressions (applied with ``re.match``, i.e. anchored at the start of the command only)."""
    confirm_exit: bool = True
    """If the agent wants to finish, ask the user whether to submit or to continue with a new task."""
|
| 36 |
|
| 37 |
|
def _multiline_prompt() -> str:
    """Read a multiline message from the user through the shared multiline session.

    The prompt itself is empty; a bottom toolbar lists the key bindings for
    submitting the message and for navigating/searching the input history.
    """
    toolbar = HTML(
        "Submit message: <b fg='yellow' bg='black'>Esc, then Enter</b> | "
        "Navigate history: <b fg='yellow' bg='black'>Arrow Up/Down</b> | "
        "Search history: <b fg='yellow' bg='black'>Ctrl+R</b>"
    )
    return _multiline_prompt_session.prompt("", bottom_toolbar=toolbar)
|
| 47 |
|
| 48 |
|
| 49 | class InteractiveAgent(DefaultAgent):
|
| 50 | _MODE_COMMANDS_MAPPING = {"/u": "human", "/c": "confirm", "/y": "yolo"}
|
| 51 |
|
| 52 | def __init__(self, *args, config_class=InteractiveAgentConfig, **kwargs):
|
| 53 | super().__init__(*args, config_class=config_class, **kwargs)
|
| 54 | self.cost_last_confirmed = 0.0
|
| 55 |
|
| 56 | def add_messages(self, *messages: dict) -> list[dict]:
|
| 57 | # Extend supermethod to print messages
|
| 58 | for msg in messages:
|
| 59 | role, content = msg.get("role") or msg.get("type", "unknown"), get_content_string(msg)
|
| 60 | if role == "assistant":
|
| 61 | console.print(
|
| 62 | f"\n[red][bold]mini-swe-agent[/bold] (step [bold]{self.n_calls}[/bold], [bold]${self.cost:.2f}[/bold]):[/red]\n",
|
| 63 | end="",
|
| 64 | highlight=False,
|
| 65 | )
|
| 66 | else:
|
| 67 | console.print(f"\n[bold green]{role.capitalize()}[/bold green]:\n", end="", highlight=False)
|
| 68 | console.print(content, highlight=False, markup=False)
|
| 69 | return super().add_messages(*messages)
|
| 70 |
|
| 71 | def query(self) -> dict:
|
| 72 | # Extend supermethod to handle human mode
|
| 73 | if self.config.mode == "human":
|
| 74 | match command := self._prompt_and_handle_slash_commands("[bold yellow]>[/bold yellow] "):
|
| 75 | case "/y" | "/c":
|
| 76 | pass
|
| 77 | case _:
|
| 78 | msg = {
|
| 79 | "role": "user",
|
| 80 | "content": f"User command: \n```bash\n{command}\n```",
|
| 81 | "extra": {"actions": [{"command": command}]},
|
| 82 | }
|
| 83 | self.add_messages(msg)
|
| 84 | return msg
|
| 85 | try:
|
| 86 | with console.status("Waiting for the LM to respond..."):
|
| 87 | return super().query()
|
| 88 | except LimitsExceeded:
|
| 89 | console.print(
|
| 90 | f"Limits exceeded. Limits: {self.config.step_limit} steps, ${self.config.cost_limit}.\n"
|
| 91 | f"Current spend: {self.n_calls} steps, ${self.cost:.2f}."
|
| 92 | )
|
| 93 | self.config.step_limit = int(input("New step limit: "))
|
| 94 | self.config.cost_limit = float(input("New cost limit: "))
|
| 95 | return super().query()
|
| 96 |
|
| 97 | def step(self) -> list[dict]:
|
| 98 | # Override the step method to handle user interruption
|
| 99 | try:
|
| 100 | console.print(Rule())
|
| 101 | return super().step()
|
| 102 | except KeyboardInterrupt:
|
| 103 | interruption_message = self._prompt_and_handle_slash_commands(
|
| 104 | "\n\n[bold yellow]Interrupted.[/bold yellow] "
|
| 105 | "[green]Type a comment/command[/green] (/h for available commands)"
|
| 106 | "\n[bold yellow]>[/bold yellow] "
|
| 107 | ).strip()
|
| 108 | if not interruption_message or interruption_message in self._MODE_COMMANDS_MAPPING:
|
| 109 | interruption_message = "Temporary interruption caught."
|
| 110 | raise UserInterruption(
|
| 111 | {
|
| 112 | "role": "user",
|
| 113 | "content": f"Interrupted by user: {interruption_message}",
|
| 114 | "extra": {"interrupt_type": "UserInterruption"},
|
| 115 | }
|
| 116 | )
|
| 117 |
|
| 118 | def execute_actions(self, message: dict) -> list[dict]:
|
| 119 | # Override to handle user confirmation and confirm_exit, with try/finally to preserve partial outputs
|
| 120 | actions = message.get("extra", {}).get("actions", [])
|
| 121 | commands = [action["command"] for action in actions]
|
| 122 | outputs = []
|
| 123 | try:
|
| 124 | self._ask_confirmation_or_interrupt(commands)
|
| 125 | for action in actions:
|
| 126 | outputs.append(self.env.execute(action))
|
| 127 | except Submitted as e:
|
| 128 | self._check_for_new_task_or_submit(e)
|
| 129 | finally:
|
| 130 | result = self.add_messages(
|
| 131 | *self.model.format_observation_messages(message, outputs, self.get_template_vars())
|
| 132 | )
|
| 133 | return result
|
| 134 |
|
| 135 | def _add_observation_messages(self, message: dict, outputs: list[dict]) -> list[dict]:
|
| 136 | return self.add_messages(*self.model.format_observation_messages(message, outputs, self.get_template_vars()))
|
| 137 |
|
| 138 | def _check_for_new_task_or_submit(self, e: Submitted) -> NoReturn:
|
| 139 | """Check if user wants to add a new task or submit."""
|
| 140 | if self.config.confirm_exit:
|
| 141 | message = (
|
| 142 | "[bold yellow]Agent wants to finish.[/bold yellow] "
|
| 143 | "[bold green]Type new task[/bold green] or [red][bold]Esc, then enter[/bold] to quit.[/red]\n"
|
| 144 | "[bold yellow]>[/bold yellow] "
|
| 145 | )
|
| 146 | if new_task := self._prompt_and_handle_slash_commands(message, _multiline=True).strip():
|
| 147 | raise UserInterruption(
|
| 148 | {
|
| 149 | "role": "user",
|
| 150 | "content": f"The user added a new task: {new_task}",
|
| 151 | "extra": {"interrupt_type": "UserNewTask"},
|
| 152 | }
|
| 153 | )
|
| 154 | raise e
|
| 155 |
|
| 156 | def _should_ask_confirmation(self, action: str) -> bool:
|
| 157 | return self.config.mode == "confirm" and not any(re.match(r, action) for r in self.config.whitelist_actions)
|
| 158 |
|
| 159 | def _ask_confirmation_or_interrupt(self, commands: list[str]) -> None:
|
| 160 | commands_needing_confirmation = [c for c in commands if self._should_ask_confirmation(c)]
|
| 161 | if not commands_needing_confirmation:
|
| 162 | return
|
| 163 | n = len(commands_needing_confirmation)
|
| 164 | prompt = (
|
| 165 | f"[bold yellow]Execute {n} action(s)?[/] [green][bold]Enter[/] to confirm[/], "
|
| 166 | "[red]type [bold]comment[/] to reject[/], or [blue][bold]/h[/] to show available commands[/]\n"
|
| 167 | "[bold yellow]>[/bold yellow] "
|
| 168 | )
|
| 169 | match user_input := self._prompt_and_handle_slash_commands(prompt).strip():
|
| 170 | case "" | "/y":
|
| 171 | pass # confirmed, do nothing
|
| 172 | case "/u": # Skip execution action and get back to query
|
| 173 | raise UserInterruption(
|
| 174 | {
|
| 175 | "role": "user",
|
| 176 | "content": "Commands not executed. Switching to human mode",
|
| 177 | "extra": {"interrupt_type": "UserRejection"},
|
| 178 | }
|
| 179 | )
|
| 180 | case _:
|
| 181 | raise UserInterruption(
|
| 182 | {
|
| 183 | "role": "user",
|
| 184 | "content": f"Commands not executed. The user rejected your commands with the following message: {user_input}",
|
| 185 | "extra": {"interrupt_type": "UserRejection"},
|
| 186 | }
|
| 187 | )
|
| 188 |
|
| 189 | def _prompt_and_handle_slash_commands(self, prompt: str, *, _multiline: bool = False) -> str:
|
| 190 | """Prompts the user, takes care of /h (followed by requery) and sets the mode. Returns the user input."""
|
| 191 | console.print(prompt, end="")
|
| 192 | if _multiline:
|
| 193 | return _multiline_prompt()
|
| 194 | user_input = _prompt_session.prompt("")
|
| 195 | if user_input == "/m":
|
| 196 | return self._prompt_and_handle_slash_commands(prompt, _multiline=True)
|
| 197 | if user_input == "/h":
|
| 198 | console.print(
|
| 199 | f"Current mode: [bold green]{self.config.mode}[/bold green]\n"
|
| 200 | f"[bold green]/y[/bold green] to switch to [bold yellow]yolo[/bold yellow] mode (execute LM commands without confirmation)\n"
|
| 201 | f"[bold green]/c[/bold green] to switch to [bold yellow]confirmation[/bold yellow] mode (ask for confirmation before executing LM commands)\n"
|
| 202 | f"[bold green]/u[/bold green] to switch to [bold yellow]human[/bold yellow] mode (execute commands issued by the user)\n"
|
| 203 | f"[bold green]/m[/bold green] to enter multiline comment",
|
| 204 | )
|
| 205 | return self._prompt_and_handle_slash_commands(prompt)
|
| 206 | if user_input in self._MODE_COMMANDS_MAPPING:
|
| 207 | if self.config.mode == self._MODE_COMMANDS_MAPPING[user_input]:
|
| 208 | return self._prompt_and_handle_slash_commands(
|
| 209 | f"[bold red]Already in {self.config.mode} mode.[/bold red]\n{prompt}"
|
| 210 | )
|
| 211 | self.config.mode = self._MODE_COMMANDS_MAPPING[user_input]
|
| 212 | console.print(f"Switched to [bold green]{self.config.mode}[/bold green] mode.")
|
| 213 | return user_input
|
| 214 | return user_input
|
| 215 |
|