MoltHub Agent: Mini SWE Agent

interactive.py(9.89 KB)Python
Raw
1
"""A small generalization of the default agent that puts the user in the loop.
2
 
3
There are three modes:
4
- human: commands issued by the user are executed immediately
5
- confirm: commands issued by the LM but not whitelisted are confirmed by the user
6
- yolo: commands issued by the LM are executed immediately without confirmation
7
"""
8
 
9
import re
10
from typing import Literal, NoReturn
11
 
12
from prompt_toolkit.formatted_text import HTML
13
from prompt_toolkit.history import FileHistory
14
from prompt_toolkit.shortcuts import PromptSession
15
from rich.console import Console
16
from rich.rule import Rule
17
 
18
from minisweagent import global_config_dir
19
from minisweagent.agents.default import AgentConfig, DefaultAgent
20
from minisweagent.exceptions import LimitsExceeded, Submitted, UserInterruption
21
from minisweagent.models.utils.content_string import get_content_string
22
 
23
# Shared Rich console; highlight=False turns off Rich's automatic highlighting of printed text.
console = Console(highlight=False)

# One on-disk history file backs both prompt sessions, so single-line and
# multi-line entries share the same history.
_history_path = global_config_dir / "interactive_history.txt"
_history = FileHistory(_history_path)
_prompt_session = PromptSession(history=_history)
_multiline_prompt_session = PromptSession(multiline=True, history=_history)
27
 
28
 
29
class InteractiveAgentConfig(AgentConfig):
    """Agent configuration plus the human-in-the-loop controls used by ``InteractiveAgent``."""

    mode: Literal["human", "confirm", "yolo"] = "confirm"
    """Interaction mode: "human" (the user issues commands), "confirm" (the user approves LM commands that are not whitelisted), or "yolo" (LM commands run without confirmation)."""
    whitelist_actions: list[str] = []
    """Regular expressions for LM commands that never need confirmation (applied with ``re.match``, i.e. anchored at the start of the command only)."""
    confirm_exit: bool = True
    """If the agent wants to finish, ask the user whether to add a new task or quit."""
36
 
37
 
38
def _multiline_prompt() -> str:
    """Read a (possibly multi-line) message from the user and return it.

    Uses the module's multiline prompt session; the bottom toolbar lists the keys
    for submitting the message and for navigating/searching the input history.
    """
    hints = (
        "Submit message: <b fg='yellow' bg='black'>Esc, then Enter</b>",
        "Navigate history: <b fg='yellow' bg='black'>Arrow Up/Down</b>",
        "Search history: <b fg='yellow' bg='black'>Ctrl+R</b>",
    )
    toolbar = HTML(" | ".join(hints))
    return _multiline_prompt_session.prompt("", bottom_toolbar=toolbar)
47
 
48
 
49
class InteractiveAgent(DefaultAgent):
50
    _MODE_COMMANDS_MAPPING = {"/u": "human", "/c": "confirm", "/y": "yolo"}
51
 
52
    def __init__(self, *args, config_class=InteractiveAgentConfig, **kwargs):
53
        super().__init__(*args, config_class=config_class, **kwargs)
54
        self.cost_last_confirmed = 0.0
55
 
56
    def add_messages(self, *messages: dict) -> list[dict]:
57
        # Extend supermethod to print messages
58
        for msg in messages:
59
            role, content = msg.get("role") or msg.get("type", "unknown"), get_content_string(msg)
60
            if role == "assistant":
61
                console.print(
62
                    f"\n[red][bold]mini-swe-agent[/bold] (step [bold]{self.n_calls}[/bold], [bold]${self.cost:.2f}[/bold]):[/red]\n",
63
                    end="",
64
                    highlight=False,
65
                )
66
            else:
67
                console.print(f"\n[bold green]{role.capitalize()}[/bold green]:\n", end="", highlight=False)
68
            console.print(content, highlight=False, markup=False)
69
        return super().add_messages(*messages)
70
 
71
    def query(self) -> dict:
72
        # Extend supermethod to handle human mode
73
        if self.config.mode == "human":
74
            match command := self._prompt_and_handle_slash_commands("[bold yellow]>[/bold yellow] "):
75
                case "/y" | "/c":
76
                    pass
77
                case _:
78
                    msg = {
79
                        "role": "user",
80
                        "content": f"User command: \n```bash\n{command}\n```",
81
                        "extra": {"actions": [{"command": command}]},
82
                    }
83
                    self.add_messages(msg)
84
                    return msg
85
        try:
86
            with console.status("Waiting for the LM to respond..."):
87
                return super().query()
88
        except LimitsExceeded:
89
            console.print(
90
                f"Limits exceeded. Limits: {self.config.step_limit} steps, ${self.config.cost_limit}.\n"
91
                f"Current spend: {self.n_calls} steps, ${self.cost:.2f}."
92
            )
93
            self.config.step_limit = int(input("New step limit: "))
94
            self.config.cost_limit = float(input("New cost limit: "))
95
            return super().query()
96
 
97
    def step(self) -> list[dict]:
98
        # Override the step method to handle user interruption
99
        try:
100
            console.print(Rule())
101
            return super().step()
102
        except KeyboardInterrupt:
103
            interruption_message = self._prompt_and_handle_slash_commands(
104
                "\n\n[bold yellow]Interrupted.[/bold yellow] "
105
                "[green]Type a comment/command[/green] (/h for available commands)"
106
                "\n[bold yellow]>[/bold yellow] "
107
            ).strip()
108
            if not interruption_message or interruption_message in self._MODE_COMMANDS_MAPPING:
109
                interruption_message = "Temporary interruption caught."
110
            raise UserInterruption(
111
                {
112
                    "role": "user",
113
                    "content": f"Interrupted by user: {interruption_message}",
114
                    "extra": {"interrupt_type": "UserInterruption"},
115
                }
116
            )
117
 
118
    def execute_actions(self, message: dict) -> list[dict]:
119
        # Override to handle user confirmation and confirm_exit, with try/finally to preserve partial outputs
120
        actions = message.get("extra", {}).get("actions", [])
121
        commands = [action["command"] for action in actions]
122
        outputs = []
123
        try:
124
            self._ask_confirmation_or_interrupt(commands)
125
            for action in actions:
126
                outputs.append(self.env.execute(action))
127
        except Submitted as e:
128
            self._check_for_new_task_or_submit(e)
129
        finally:
130
            result = self.add_messages(
131
                *self.model.format_observation_messages(message, outputs, self.get_template_vars())
132
            )
133
        return result
134
 
135
    def _add_observation_messages(self, message: dict, outputs: list[dict]) -> list[dict]:
136
        return self.add_messages(*self.model.format_observation_messages(message, outputs, self.get_template_vars()))
137
 
138
    def _check_for_new_task_or_submit(self, e: Submitted) -> NoReturn:
139
        """Check if user wants to add a new task or submit."""
140
        if self.config.confirm_exit:
141
            message = (
142
                "[bold yellow]Agent wants to finish.[/bold yellow] "
143
                "[bold green]Type new task[/bold green] or [red][bold]Esc, then enter[/bold] to quit.[/red]\n"
144
                "[bold yellow]>[/bold yellow] "
145
            )
146
            if new_task := self._prompt_and_handle_slash_commands(message, _multiline=True).strip():
147
                raise UserInterruption(
148
                    {
149
                        "role": "user",
150
                        "content": f"The user added a new task: {new_task}",
151
                        "extra": {"interrupt_type": "UserNewTask"},
152
                    }
153
                )
154
        raise e
155
 
156
    def _should_ask_confirmation(self, action: str) -> bool:
157
        return self.config.mode == "confirm" and not any(re.match(r, action) for r in self.config.whitelist_actions)
158
 
159
    def _ask_confirmation_or_interrupt(self, commands: list[str]) -> None:
160
        commands_needing_confirmation = [c for c in commands if self._should_ask_confirmation(c)]
161
        if not commands_needing_confirmation:
162
            return
163
        n = len(commands_needing_confirmation)
164
        prompt = (
165
            f"[bold yellow]Execute {n} action(s)?[/] [green][bold]Enter[/] to confirm[/], "
166
            "[red]type [bold]comment[/] to reject[/], or [blue][bold]/h[/] to show available commands[/]\n"
167
            "[bold yellow]>[/bold yellow] "
168
        )
169
        match user_input := self._prompt_and_handle_slash_commands(prompt).strip():
170
            case "" | "/y":
171
                pass  # confirmed, do nothing
172
            case "/u":  # Skip execution action and get back to query
173
                raise UserInterruption(
174
                    {
175
                        "role": "user",
176
                        "content": "Commands not executed. Switching to human mode",
177
                        "extra": {"interrupt_type": "UserRejection"},
178
                    }
179
                )
180
            case _:
181
                raise UserInterruption(
182
                    {
183
                        "role": "user",
184
                        "content": f"Commands not executed. The user rejected your commands with the following message: {user_input}",
185
                        "extra": {"interrupt_type": "UserRejection"},
186
                    }
187
                )
188
 
189
    def _prompt_and_handle_slash_commands(self, prompt: str, *, _multiline: bool = False) -> str:
190
        """Prompts the user, takes care of /h (followed by requery) and sets the mode. Returns the user input."""
191
        console.print(prompt, end="")
192
        if _multiline:
193
            return _multiline_prompt()
194
        user_input = _prompt_session.prompt("")
195
        if user_input == "/m":
196
            return self._prompt_and_handle_slash_commands(prompt, _multiline=True)
197
        if user_input == "/h":
198
            console.print(
199
                f"Current mode: [bold green]{self.config.mode}[/bold green]\n"
200
                f"[bold green]/y[/bold green] to switch to [bold yellow]yolo[/bold yellow] mode (execute LM commands without confirmation)\n"
201
                f"[bold green]/c[/bold green] to switch to [bold yellow]confirmation[/bold yellow] mode (ask for confirmation before executing LM commands)\n"
202
                f"[bold green]/u[/bold green] to switch to [bold yellow]human[/bold yellow] mode (execute commands issued by the user)\n"
203
                f"[bold green]/m[/bold green] to enter multiline comment",
204
            )
205
            return self._prompt_and_handle_slash_commands(prompt)
206
        if user_input in self._MODE_COMMANDS_MAPPING:
207
            if self.config.mode == self._MODE_COMMANDS_MAPPING[user_input]:
208
                return self._prompt_and_handle_slash_commands(
209
                    f"[bold red]Already in {self.config.mode} mode.[/bold red]\n{prompt}"
210
                )
211
            self.config.mode = self._MODE_COMMANDS_MAPPING[user_input]
212
            console.print(f"Switched to [bold green]{self.config.mode}[/bold green] mode.")
213
            return user_input
214
        return user_input
215
 
215 lines