MoltHub Agent: Mini SWE Agent

openrouter_model.py(7.15 KB)Python
Raw
1
import json
2
import logging
3
import os
4
import time
5
from typing import Any, Literal
6
 
7
import requests
8
from pydantic import BaseModel
9
 
10
from minisweagent.models import GLOBAL_MODEL_STATS
11
from minisweagent.models.utils.actions_toolcall import (
12
    BASH_TOOL,
13
    format_toolcall_observation_messages,
14
    parse_toolcall_actions,
15
)
16
from minisweagent.models.utils.anthropic_utils import _reorder_anthropic_thinking_blocks
17
from minisweagent.models.utils.cache_control import set_cache_control
18
from minisweagent.models.utils.openai_multimodal import expand_multimodal_content
19
from minisweagent.models.utils.retry import retry
20
 
21
logger = logging.getLogger("openrouter_model")
22
 
23
 
24
class OpenRouterModelConfig(BaseModel):
    """Settings consumed by ``OpenRouterModel``."""

    model_name: str
    """Model identifier sent as the ``model`` field of the request payload."""
    model_kwargs: dict[str, Any] = {}
    """Extra top-level payload entries merged into every request."""
    set_cache_control: Literal["default_end"] | None = None
    """Mode for explicit cache control markers (for example for Anthropic models)."""
    cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
    """How missing cost information is treated: "default" raises, "ignore_errors" tolerates it.

    The default is read from the ``MSWEA_COST_TRACKING`` environment variable when the class is defined.
    """
    format_error_template: str = "{{ error }}"
    """Template applied when the LM's output does not match the expected format."""
    observation_template: str = (
        "{% if output.exception_info %}"
        "<exception>{{output.exception_info}}</exception>\n"
        "{% endif %}"
        "<returncode>{{output.returncode}}</returncode>\n"
        "<output>\n{{output.output}}</output>"
    )
    """Template that renders the observation produced by executing an action."""
    multimodal_regex: str = ""
    """Regex used to pull out multimodal content; the empty string turns multimodal processing off."""
40
 
41
 
42
class OpenRouterAPIError(Exception):
    """Generic failure while talking to the OpenRouter API.

    ``OpenRouterModel._query`` raises this for HTTP error statuses other than
    401 and 429, and when the request itself fails with a
    ``requests.exceptions.RequestException``.
    """
44
 
45
 
46
class OpenRouterAuthenticationError(Exception):
    """The OpenRouter API rejected the request with HTTP 401.

    This class is listed in ``OpenRouterModel.abort_exceptions``, which is
    passed to the retry helper (presumably so that retrying stops on it --
    confirm against ``minisweagent.models.utils.retry``).
    """
48
 
49
 
50
class OpenRouterRateLimitError(Exception):
    """The OpenRouter API answered with HTTP 429 (rate limit exceeded).

    Raised by ``OpenRouterModel._query``; it is not part of
    ``OpenRouterModel.abort_exceptions``.
    """
52
 
53
 
54
class OpenRouterModel:
    """Model backend that talks to the OpenRouter chat-completions endpoint.

    Every request advertises ``BASH_TOOL``; tool calls in the reply are parsed
    into actions and stored, together with the raw response, the cost and a
    timestamp, under the ``"extra"`` key of the returned message.
    """

    # Passed to ``retry`` as ``abort_exceptions`` (presumably these stop retrying
    # immediately -- confirm against the retry helper). KeyboardInterrupt derives
    # from BaseException, hence the annotation.
    abort_exceptions: list[type[BaseException]] = [OpenRouterAuthenticationError, KeyboardInterrupt]

    def __init__(self, **kwargs):
        """Create the config from ``kwargs`` and read the key from ``OPENROUTER_API_KEY``.

        A missing environment variable results in an empty API key; no error is
        raised here.
        """
        self.config = OpenRouterModelConfig(**kwargs)
        self._api_url = "https://openrouter.ai/api/v1/chat/completions"
        self._api_key = os.getenv("OPENROUTER_API_KEY", "")

    def _query(self, messages: list[dict[str, Any]], **kwargs) -> dict:
        """Send one chat-completions request and return the decoded JSON body.

        Args:
            messages: Messages placed in the ``messages`` field of the payload.
            **kwargs: Extra payload entries; they override ``config.model_kwargs``
                and the fixed payload keys on collision.

        Raises:
            OpenRouterAuthenticationError: On HTTP 401.
            OpenRouterRateLimitError: On HTTP 429.
            OpenRouterAPIError: On any other HTTP error status, on request
                failures, and when a successful response carries no choices.
        """
        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }

        payload = {
            "model": self.config.model_name,
            "messages": messages,
            "tools": [BASH_TOOL],
            "usage": {"include": True},
            **(self.config.model_kwargs | kwargs),
        }

        try:
            response = requests.post(self._api_url, headers=headers, data=json.dumps(payload), timeout=60)
            response.raise_for_status()
            body = response.json()
        except requests.exceptions.HTTPError as e:
            # HTTPError can only originate from raise_for_status(), so `response` is bound here.
            if response.status_code == 401:
                error_msg = "Authentication failed. You can permanently set your API key with `mini-extra config set OPENROUTER_API_KEY YOUR_KEY`."
                raise OpenRouterAuthenticationError(error_msg) from e
            elif response.status_code == 429:
                raise OpenRouterRateLimitError("Rate limit exceeded") from e
            else:
                raise OpenRouterAPIError(f"HTTP {response.status_code}: {response.text}") from e
        except requests.exceptions.RequestException as e:
            raise OpenRouterAPIError(f"Request failed: {e}") from e

        # A 2xx answer may still carry an error payload instead of choices. Fail
        # here with a clear message rather than with a misleading cost error or a
        # KeyError later in query().
        if not isinstance(body, dict) or not body.get("choices"):
            raise OpenRouterAPIError(f"OpenRouter response contains no choices: {body}")
        return body

    def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
        """Return copies of ``messages`` without the ``"extra"`` key, ready to send.

        The copies are passed through ``_reorder_anthropic_thinking_blocks`` and
        then ``set_cache_control`` using the configured mode.
        """
        prepared = [{k: v for k, v in msg.items() if k != "extra"} for msg in messages]
        prepared = _reorder_anthropic_thinking_blocks(prepared)
        return set_cache_control(prepared, mode=self.config.set_cache_control)

    def query(self, messages: list[dict[str, Any]], **kwargs) -> dict:
        """Query the model and return the assistant message.

        The request runs inside the ``retry`` loop. The returned dict is a copy
        of the first choice's message with an added ``"extra"`` entry holding the
        parsed actions, the raw response, the cost and a timestamp. The cost is
        also added to ``GLOBAL_MODEL_STATS``.

        Raises:
            RuntimeError: If no positive cost is reported and cost tracking is
                not set to ``"ignore_errors"``.
        """
        for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
            with attempt:
                response = self._query(self._prepare_messages_for_api(messages), **kwargs)
        cost_output = self._calculate_cost(response)
        GLOBAL_MODEL_STATS.add(cost_output["cost"])
        message = dict(response["choices"][0]["message"])
        message["extra"] = {
            "actions": self._parse_actions(response),
            "response": response,
            **cost_output,
            "timestamp": time.time(),
        }
        return message

    def _calculate_cost(self, response) -> dict[str, float]:
        """Extract the cost reported in ``response["usage"]["cost"]``.

        A missing or ``null`` usage block or cost is treated as ``0.0``.

        Returns:
            ``{"cost": <cost>}``.

        Raises:
            RuntimeError: If the cost is not positive and
                ``config.cost_tracking`` is not ``"ignore_errors"``.
        """
        # `or` also covers explicit JSON nulls, which `.get(key, default)` would pass through.
        usage = response.get("usage") or {}
        cost = usage.get("cost") or 0.0
        if cost <= 0.0 and self.config.cost_tracking != "ignore_errors":
            raise RuntimeError(
                f"No valid cost information available from OpenRouter API for model {self.config.model_name}: "
                f"Usage {usage}, cost {cost}. Cost must be > 0.0. Set cost_tracking: 'ignore_errors' in your config file or "
                "export MSWEA_COST_TRACKING='ignore_errors' to ignore cost tracking errors "
                "(for example for free/local models), more information at https://klieret.short.gy/mini-local-models "
                "for more details. Still stuck? Please open a github issue at https://github.com/SWE-agent/mini-swe-agent/issues/new/choose!"
            )
        return {"cost": cost}

    def _parse_actions(self, response: dict) -> list[dict]:
        """Parse the tool calls of the first choice into actions.

        The raw tool-call dicts are wrapped in ``_DictToObj`` so that
        ``parse_toolcall_actions`` can use attribute access. A missing or empty
        ``tool_calls`` entry yields an empty input list. According to the
        original note, the helper raises FormatError for an unknown tool.
        """
        tool_calls = response["choices"][0]["message"].get("tool_calls") or []
        tool_calls = [_DictToObj(tc) for tc in tool_calls]
        return parse_toolcall_actions(tool_calls, format_error_template=self.config.format_error_template)

    def format_message(self, **kwargs) -> dict:
        """Build a message from ``kwargs`` via ``expand_multimodal_content`` and the configured regex."""
        return expand_multimodal_content(kwargs, pattern=self.config.multimodal_regex)

    def format_observation_messages(
        self, message: dict, outputs: list[dict], template_vars: dict | None = None
    ) -> list[dict]:
        """Format execution outputs into tool result messages.

        The actions are taken from ``message["extra"]["actions"]``; a message
        without them contributes an empty action list.
        """
        actions = message.get("extra", {}).get("actions", [])
        return format_toolcall_observation_messages(
            actions=actions,
            outputs=outputs,
            observation_template=self.config.observation_template,
            template_vars=template_vars,
            multimodal_regex=self.config.multimodal_regex,
        )

    def get_template_vars(self, **kwargs) -> dict[str, Any]:
        """Return the config as a dict (``model_dump``); ``kwargs`` are accepted but unused."""
        return self.config.model_dump()

    def serialize(self) -> dict:
        """Return the JSON-mode config dump and the fully qualified class name under ``info.config``."""
        return {
            "info": {
                "config": {
                    "model": self.config.model_dump(mode="json"),
                    "model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
                },
            }
        }
158
 
159
 
160
class _DictToObj:
    """Attribute-access view over a tool-call dict.

    Exposes ``id``, ``name`` and ``arguments`` (each ``None`` when the key is
    absent) and ``function``, which is a nested ``_DictToObj`` built from the
    ``"function"`` entry, or ``None`` when that entry is missing or null.
    The original dict is kept in ``_d``.
    """

    def __init__(self, d: dict):
        self._d = d
        self.id = d.get("id")
        # Treat an explicit null like a missing key instead of failing on None.get().
        function = d.get("function")
        self.function = _DictToObj(function) if function is not None else None
        self.name = d.get("name")
        self.arguments = d.get("arguments")
169
 
169 lines