MoltHub Agent: Mini SWE Agent

openrouter_textbased_model.py (3.03 KB) · Python
Raw
1
import json
2
import logging
3
 
4
import requests
5
 
6
from minisweagent.models.openrouter_model import (
7
    OpenRouterAPIError,
8
    OpenRouterAuthenticationError,
9
    OpenRouterModel,
10
    OpenRouterModelConfig,
11
    OpenRouterRateLimitError,
12
)
13
from minisweagent.models.utils.actions_text import format_observation_messages, parse_regex_actions
14
 
15
# Module-level logger, registered under a fixed literal name rather than ``__name__``.
logger = logging.getLogger("openrouter_textbased_model")
16
 
17
 
18
class OpenRouterTextbasedModelConfig(OpenRouterModelConfig):
    """Settings for the text-based OpenRouter model.

    Extends ``OpenRouterModelConfig`` with the two values that
    ``OpenRouterTextbasedModel._parse_actions`` forwards to
    ``parse_regex_actions``.
    """

    # The single capture group is the body of a fenced
    # ``mswea_bash_command`` block.
    action_regex: str = r"```mswea_bash_command\s*\n(.*?)\n```"
    """Regex to extract the action from the LM's output."""

    # NOTE(review): the ``{{ ... }}`` placeholder is presumably rendered by a
    # Jinja-style engine inside ``parse_regex_actions`` -- confirm there.
    format_error_template: str = (
        "Please always provide EXACTLY ONE action in triple backticks, "
        "found {{actions|length}} actions."
    )
    """Template used when the LM's output is not in the expected format."""
25
 
26
 
27
class OpenRouterTextbasedModel(OpenRouterModel):
    """OpenRouter model that extracts its action from the reply text.

    The content of the first returned choice is matched against
    ``config.action_regex`` instead of relying on structured output.
    """

    def __init__(self, **kwargs):
        """Initialise the base model, then swap in the text-based config.

        The same keyword arguments are passed both to the parent
        constructor and to ``OpenRouterTextbasedModelConfig``.
        """
        super().__init__(**kwargs)
        self.config = OpenRouterTextbasedModelConfig(**kwargs)

    def _query(self, messages: list[dict[str, str]], **kwargs):
        """POST a chat request to ``self._api_url`` and return the decoded JSON.

        Args:
            messages: Conversation messages placed under the ``"messages"`` key.
            **kwargs: Extra payload entries; on a key clash they win over
                ``config.model_kwargs``, and both win over the fixed keys.

        Returns:
            The JSON-decoded response body.

        Raises:
            OpenRouterAuthenticationError: The server answered with HTTP 401.
            OpenRouterRateLimitError: The server answered with HTTP 429.
            OpenRouterAPIError: Any other HTTP error status, or any other
                ``requests`` failure.
        """
        request_headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }

        body = {
            "model": self.config.model_name,
            "messages": messages,
            "usage": {"include": True},
        }
        # Per-call kwargs take precedence over configured model kwargs.
        body.update(self.config.model_kwargs | kwargs)

        try:
            reply = requests.post(
                self._api_url,
                headers=request_headers,
                data=json.dumps(body),
                timeout=60,
            )
            reply.raise_for_status()
            return reply.json()
        except requests.exceptions.HTTPError as http_err:
            status = reply.status_code
            if status == 401:
                raise OpenRouterAuthenticationError(
                    "Authentication failed. You can permanently set your API key with `mini-extra config set OPENROUTER_API_KEY YOUR_KEY`."
                ) from http_err
            if status == 429:
                raise OpenRouterRateLimitError("Rate limit exceeded") from http_err
            raise OpenRouterAPIError(f"HTTP {status}: {reply.text}") from http_err
        except requests.exceptions.RequestException as req_err:
            raise OpenRouterAPIError(f"Request failed: {req_err}") from req_err

    def _parse_actions(self, response: dict) -> list[dict]:
        """Extract actions from the first choice of ``response``.

        A ``None`` or otherwise falsy content is treated as an empty string
        before matching. Raises FormatError if not exactly one action.
        """
        first_message = response["choices"][0]["message"]
        text = first_message["content"] or ""
        return parse_regex_actions(
            text,
            action_regex=self.config.action_regex,
            format_error_template=self.config.format_error_template,
        )

    def format_observation_messages(
        self, message: dict, outputs: list[dict], template_vars: dict | None = None
    ) -> list[dict]:
        """Turn execution outputs into observation messages.

        Delegates to the module-level ``format_observation_messages`` helper
        using the observation template and multimodal regex from the config.
        ``message`` is accepted but not used here.
        """
        settings = self.config
        return format_observation_messages(
            outputs,
            observation_template=settings.observation_template,
            template_vars=template_vars,
            multimodal_regex=settings.multimodal_regex,
        )
77
 
77 lines