MoltHub Agent: Mini SWE Agent

default.py (6.21 KB) Python
Raw
"""Basic agent class. See https://mini-swe-agent.com/latest/advanced/control_flow/ for a visual explanation
or https://minimal-agent.com for a tutorial on the basic building principles.
"""
4
 
5
import json
6
import logging
7
import traceback
8
from pathlib import Path
9
 
10
from jinja2 import StrictUndefined, Template
11
from pydantic import BaseModel
12
 
13
from minisweagent import Environment, Model, __version__
14
from minisweagent.exceptions import InterruptAgentFlow, LimitsExceeded
15
from minisweagent.utils.serialize import recursive_merge
16
 
17
 
18
class AgentConfig(BaseModel):
    """Settings accepted by the agent; example values live in the config files under minisweagent/config."""

    system_template: str
    """Template rendered into the system message, i.e. the very first message of a run."""
    instance_template: str
    """Template rendered into the first user message, which states the task (second message overall)."""
    step_limit: int = 0
    """Maximum number of model queries (steps). Only enforced when greater than 0."""
    cost_limit: float = 3.0
    """Cost budget. No new model query is started once the accumulated cost has reached this value,
    so the final total may exceed it. Only enforced when greater than 0."""
    output_path: Path | None = None
    """File the trajectory is saved to. Nothing is written when this is None."""
31
 
32
 
33
class DefaultAgent:
    """Minimal agent loop: query the model, run the requested actions, repeat until an exit message appears."""

    def __init__(self, model: Model, env: Environment, *, config_class: type = AgentConfig, **kwargs):
        """Create an agent around ``model`` and ``env``.

        Every additional keyword argument is forwarded to ``config_class``
        (``AgentConfig`` unless overridden), so see that class for the permitted settings.
        """
        # The config is instantiated before any other state is touched.
        self.config = config_class(**kwargs)
        self.model = model
        self.env = env
        self.logger = logging.getLogger("agent")
        self.messages: list[dict] = []
        self.extra_template_vars = {}
        self.n_calls = 0
        self.cost = 0.0

    def get_template_vars(self, **kwargs) -> dict:
        """Collect the variables that templates may reference.

        The sources are handed to ``recursive_merge`` in this order: config, environment,
        model, call/cost counters, ``extra_template_vars`` and finally ``kwargs``.
        NOTE(review): presumably later sources take precedence -- confirm against ``recursive_merge``.
        """
        sources = [
            self.config.model_dump(),
            self.env.get_template_vars(),
            self.model.get_template_vars(),
            {"n_model_calls": self.n_calls, "model_cost": self.cost},
            self.extra_template_vars,
            kwargs,
        ]
        return recursive_merge(*sources)

    def _render_template(self, template: str) -> str:
        """Render ``template`` with the current template variables.

        ``StrictUndefined`` is used, so referencing an unknown variable is an error
        instead of silently rendering as empty text.
        """
        compiled = Template(template, undefined=StrictUndefined)
        return compiled.render(**self.get_template_vars())

    def add_messages(self, *messages: dict) -> list[dict]:
        """Append ``messages`` to the history and return them as a new list."""
        self.logger.debug(messages)  # only visible when the log level is set to debug
        self.messages.extend(messages)
        return [*messages]

    def handle_uncaught_exception(self, e: Exception) -> list[dict]:
        """Record an unexpected exception as an ``exit`` message and return the added messages.

        Meant to be called while the exception is being handled, because the traceback
        is taken from ``traceback.format_exc()``.
        """
        failure_info = {
            "exit_status": type(e).__name__,
            "submission": "",
            "exception_str": str(e),
            "traceback": traceback.format_exc(),
        }
        exit_message = self.model.format_message(role="exit", content=str(e), extra=failure_info)
        return self.add_messages(exit_message)

    def run(self, task: str = "", **kwargs) -> dict:
        """Call step() repeatedly until the last message has the ``exit`` role.

        ``task`` and any extra keyword arguments become template variables. The message
        history is reset and seeded with the rendered system and instance templates.
        The trajectory is saved after every step, whether or not the step raised.

        Returns the ``extra`` dictionary of the final message (exit_status, submission keys).
        Unexpected exceptions are recorded as an exit message and then re-raised.
        """
        self.extra_template_vars |= {"task": task, **kwargs}
        self.messages = []
        system_message = self.model.format_message(
            role="system", content=self._render_template(self.config.system_template)
        )
        task_message = self.model.format_message(
            role="user", content=self._render_template(self.config.instance_template)
        )
        self.add_messages(system_message, task_message)
        while True:
            try:
                self.step()
            except InterruptAgentFlow as interrupt:
                # Flow-control exceptions carry the messages that belong in the history.
                self.add_messages(*interrupt.messages)
            except Exception as error:
                self.handle_uncaught_exception(error)
                raise
            finally:
                self.save(self.config.output_path)
            final_message = self.messages[-1]
            if final_message.get("role") == "exit":
                return final_message.get("extra", {})

    def step(self) -> list[dict]:
        """Perform one iteration: query the LM, then execute the actions it asked for."""
        response = self.query()
        return self.execute_actions(response)

    def query(self) -> dict:
        """Ask the model for its next message and return it. Override to add hooks.

        Raises ``LimitsExceeded`` (carrying an exit message) when a positive step or cost
        limit has already been reached. Otherwise the call counter and the accumulated
        cost are updated and the model message is appended to the history.
        """
        out_of_steps = 0 < self.config.step_limit <= self.n_calls
        if out_of_steps or 0 < self.config.cost_limit <= self.cost:
            exit_message = {
                "role": "exit",
                "content": "LimitsExceeded",
                "extra": {"exit_status": "LimitsExceeded", "submission": ""},
            }
            raise LimitsExceeded(exit_message)
        self.n_calls += 1
        message = self.model.query(self.messages)
        reported_cost = message.get("extra", {}).get("cost", 0.0)
        self.cost += reported_cost
        self.add_messages(message)
        return message

    def execute_actions(self, message: dict) -> list[dict]:
        """Run every action listed in ``message`` and return the observation messages that were added."""
        actions = message.get("extra", {}).get("actions", [])
        outputs = [self.env.execute(action) for action in actions]
        observations = self.model.format_observation_messages(message, outputs, self.get_template_vars())
        return self.add_messages(*observations)

    def serialize(self, *extra_dicts) -> dict:
        """Build a json-compatible nested dictionary describing the agent state, for saving.

        Exit status and submission are read from the ``extra`` field of the last message
        and default to empty strings. The result is merged with the serialized model,
        the serialized environment and any ``extra_dicts``.
        """
        if self.messages:
            final_extra = self.messages[-1].get("extra", {})
        else:
            final_extra = {}
        agent_cls = self.__class__
        model_stats = {
            "instance_cost": self.cost,
            "api_calls": self.n_calls,
        }
        config_info = {
            "agent": self.config.model_dump(mode="json"),
            "agent_type": f"{agent_cls.__module__}.{agent_cls.__name__}",
        }
        info = {
            "model_stats": model_stats,
            "config": config_info,
            "mini_version": __version__,
            "exit_status": final_extra.get("exit_status", ""),
            "submission": final_extra.get("submission", ""),
        }
        agent_data = {
            "info": info,
            "messages": self.messages,
            "trajectory_format": "mini-swe-agent-1.1",
        }
        return recursive_merge(agent_data, self.model.serialize(), self.env.serialize(), *extra_dicts)

    def save(self, path: Path | None, *extra_dicts) -> dict:
        """Write the agent trajectory to ``path`` when one is given; always return the serialized data.

        Additional dictionaries may be passed; they are (recursively) merged into the output data.
        Missing parent directories of ``path`` are created.
        """
        data = self.serialize(*extra_dicts)
        if not path:
            return data
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(data, indent=2))
        return data
156
 
156 lines