| 1 | """Basic agent class. See https://mini-swe-agent.com/latest/advanced/control_flow/ for visual explanation
|
| 2 | or https://minimal-agent.com for a tutorial on the basic building principles.
|
| 3 | """
|
| 4 |
|
| 5 | import json
|
| 6 | import logging
|
| 7 | import traceback
|
| 8 | from pathlib import Path
|
| 9 |
|
| 10 | from jinja2 import StrictUndefined, Template
|
| 11 | from pydantic import BaseModel
|
| 12 |
|
| 13 | from minisweagent import Environment, Model, __version__
|
| 14 | from minisweagent.exceptions import InterruptAgentFlow, LimitsExceeded
|
| 15 | from minisweagent.utils.serialize import recursive_merge
|
| 16 |
|
| 17 |
|
class AgentConfig(BaseModel):
    """Settings accepted by `DefaultAgent`.

    Example configurations can be found in the files under minisweagent/config.
    """

    # Jinja2 template rendered into the system message, i.e. the first message of the conversation.
    system_template: str
    # Jinja2 template rendered into the first user message, which states the task
    # (the second message of the conversation overall).
    instance_template: str
    # Upper bound on the number of model queries; a value of 0 (or below) means "no limit".
    step_limit: int = 0
    # The agent stops once its accumulated model cost has reached this amount. The check runs before
    # each query, so the last query may push the total past the limit. 0 (or below) means "no limit".
    cost_limit: float = 3.0
    # When set, the trajectory is written to this path as JSON after every step.
    output_path: Path | None = None
|
| 31 |
|
| 32 |
|
class DefaultAgent:
    """Agent that alternates model queries and action execution until an ``exit`` message appears.

    Each step sends the full message history to the model, executes the actions listed in the
    model's response in the environment, and appends the resulting observation messages. `run`
    stops as soon as the last message in the history has role ``"exit"``.
    """

    def __init__(self, model: Model, env: Environment, *, config_class: type = AgentConfig, **kwargs):
        """Create an agent.

        Args:
            model: Model that is queried for new messages and used to format messages.
            env: Environment that executes the actions proposed by the model.
            config_class: Configuration class instantiated with ``**kwargs``. See the `AgentConfig`
                class for permitted keyword arguments.
            **kwargs: Forwarded to ``config_class``.
        """
        self.config = config_class(**kwargs)
        self.messages: list[dict] = []
        self.model = model
        self.env = env
        # Additional variables exposed to every template (e.g. ``task``, set in `run`).
        self.extra_template_vars = {}
        self.logger = logging.getLogger("agent")
        # Running totals compared against ``step_limit`` / ``cost_limit`` in `query`.
        self.cost = 0.0
        self.n_calls = 0

    def get_template_vars(self, **kwargs) -> dict:
        """Collect the variables available to templates.

        Merges, in this order: the agent config, the environment's and the model's template
        variables, call statistics, `extra_template_vars`, and ``kwargs`` (presumably later
        sources take precedence -- depends on `recursive_merge`; confirm there).
        """
        return recursive_merge(
            self.config.model_dump(),
            self.env.get_template_vars(),
            self.model.get_template_vars(),
            {"n_model_calls": self.n_calls, "model_cost": self.cost},
            self.extra_template_vars,
            kwargs,
        )

    def _render_template(self, template: str) -> str:
        """Render a Jinja2 template with `get_template_vars()`; undefined variables raise an error."""
        return Template(template, undefined=StrictUndefined).render(**self.get_template_vars())

    def add_messages(self, *messages: dict) -> list[dict]:
        """Append ``messages`` to the history and return them as a list."""
        self.logger.debug(messages)  # set log level to debug to see
        self.messages.extend(messages)
        return list(messages)

    def handle_uncaught_exception(self, e: Exception) -> list[dict]:
        """Record an unexpected exception as an ``exit`` message and return the added message(s).

        The traceback is formatted from ``e`` itself, not from the exception currently being
        handled. This keeps the record correct even when this method is called outside an
        ``except`` block (e.g. by a subclass or external caller).
        """
        formatted_traceback = "".join(traceback.format_exception(type(e), e, e.__traceback__))
        return self.add_messages(
            self.model.format_message(
                role="exit",
                content=str(e),
                extra={
                    "exit_status": type(e).__name__,
                    "submission": "",
                    "exception_str": str(e),
                    "traceback": formatted_traceback,
                },
            )
        )

    def run(self, task: str = "", **kwargs) -> dict:
        """Run `step()` until the agent is finished.

        Args:
            task: Task description, exposed to templates as ``task``.
            **kwargs: Additional template variables (kept in `extra_template_vars`).

        Returns:
            The ``extra`` dict of the final ``exit`` message, e.g. with ``exit_status`` and
            ``submission`` keys.

        Raises:
            Any exception other than `InterruptAgentFlow` raised by a step. It is re-raised after
            being recorded as an ``exit`` message and after the trajectory has been saved.

        NOTE: ``cost`` and ``n_calls`` are not reset here, so the limits keep accumulating
        across repeated `run` calls on the same instance.
        """
        self.extra_template_vars |= {"task": task, **kwargs}
        self.messages = []
        self.add_messages(
            self.model.format_message(role="system", content=self._render_template(self.config.system_template)),
            self.model.format_message(role="user", content=self._render_template(self.config.instance_template)),
        )
        while True:
            try:
                self.step()
            except InterruptAgentFlow as e:
                # Flow-control exceptions carry the messages to append (e.g. an exit message on limits).
                self.add_messages(*e.messages)
            except Exception as e:
                self.handle_uncaught_exception(e)
                raise
            finally:
                # Persist the trajectory after every step, also when an exception propagates.
                self.save(self.config.output_path)
            if self.messages[-1].get("role") == "exit":
                break
        return self.messages[-1].get("extra", {})

    def step(self) -> list[dict]:
        """Query the LM once, execute its actions, and return the observation messages."""
        return self.execute_actions(self.query())

    def query(self) -> dict:
        """Query the model with the current history, record and return its message. Override to add hooks.

        Raises:
            LimitsExceeded: If ``step_limit`` or ``cost_limit`` has already been reached. A limit
                that is not positive is disabled.
        """
        if 0 < self.config.step_limit <= self.n_calls or 0 < self.config.cost_limit <= self.cost:
            raise LimitsExceeded(
                {
                    "role": "exit",
                    "content": "LimitsExceeded",
                    "extra": {"exit_status": "LimitsExceeded", "submission": ""},
                }
            )
        # Counted before the call, so a model query that raises still counts toward step_limit.
        self.n_calls += 1
        message = self.model.query(self.messages)
        self.cost += message.get("extra", {}).get("cost", 0.0)
        self.add_messages(message)
        return message

    def execute_actions(self, message: dict) -> list[dict]:
        """Execute the actions in ``message["extra"]["actions"]``, add the observation messages, return them.

        A message without actions still goes through ``format_observation_messages`` with an
        empty output list.
        """
        outputs = [self.env.execute(action) for action in message.get("extra", {}).get("actions", [])]
        return self.add_messages(*self.model.format_observation_messages(message, outputs, self.get_template_vars()))

    def serialize(self, *extra_dicts) -> dict:
        """Serialize agent state to a json-compatible nested dictionary for saving.

        ``exit_status`` and ``submission`` are read from the ``extra`` of the last message (empty
        strings if absent). The model's and environment's serializations and ``extra_dicts`` are
        recursively merged on top.
        """
        last_message = self.messages[-1] if self.messages else {}
        last_extra = last_message.get("extra", {})
        agent_data = {
            "info": {
                "model_stats": {
                    "instance_cost": self.cost,
                    "api_calls": self.n_calls,
                },
                "config": {
                    "agent": self.config.model_dump(mode="json"),
                    "agent_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
                },
                "mini_version": __version__,
                "exit_status": last_extra.get("exit_status", ""),
                "submission": last_extra.get("submission", ""),
            },
            "messages": self.messages,
            "trajectory_format": "mini-swe-agent-1.1",
        }
        return recursive_merge(agent_data, self.model.serialize(), self.env.serialize(), *extra_dicts)

    def save(self, path: Path | None, *extra_dicts) -> dict:
        """Save the trajectory of the agent to a file if path is given. Returns full serialized data.

        You can pass additional dictionaries with extra data to be (recursively) merged into the
        output data. Missing parent directories are created; an existing file is overwritten.
        """
        data = self.serialize(*extra_dicts)
        if path:
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(json.dumps(data, indent=2))
        return data
|
| 156 |
|