| 1 | import json
|
| 2 | import logging
|
| 3 | import os
|
| 4 | import time
|
| 5 | from typing import Any, Literal
|
| 6 |
|
| 7 | import requests
|
| 8 | from pydantic import BaseModel
|
| 9 |
|
| 10 | from minisweagent.models import GLOBAL_MODEL_STATS
|
| 11 | from minisweagent.models.utils.actions_toolcall import (
|
| 12 | BASH_TOOL,
|
| 13 | format_toolcall_observation_messages,
|
| 14 | parse_toolcall_actions,
|
| 15 | )
|
| 16 | from minisweagent.models.utils.anthropic_utils import _reorder_anthropic_thinking_blocks
|
| 17 | from minisweagent.models.utils.cache_control import set_cache_control
|
| 18 | from minisweagent.models.utils.openai_multimodal import expand_multimodal_content
|
| 19 | from minisweagent.models.utils.retry import retry
|
| 20 |
|
# Module-level logger; `OpenRouterModel.query` hands it to `retry`.
logger = logging.getLogger("openrouter_model")
|
| 22 |
|
| 23 |
|
class OpenRouterModelConfig(BaseModel):
    """Settings consumed by `OpenRouterModel`."""

    model_name: str
    """Model identifier, sent as the `model` field of every request payload."""
    model_kwargs: dict[str, Any] = {}
    """Extra payload entries for every request; per-call kwargs take precedence over these."""
    set_cache_control: Literal["default_end"] | None = None
    """Mode handed to `set_cache_control` to add explicit cache markers (e.g. for Anthropic models)."""
    cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
    """Either "default" or "ignore_errors" (tolerate missing/zero cost info).
    The default is read from the MSWEA_COST_TRACKING environment variable when this class is defined.
    """
    format_error_template: str = "{{ error }}"
    """Template rendered when the LM's output does not have the expected format."""
    observation_template: str = (
        "{% if output.exception_info %}"
        "<exception>{{output.exception_info}}</exception>\n"
        "{% endif %}"
        "<returncode>{{output.returncode}}</returncode>\n"
        "<output>\n"
        "{{output.output}}</output>"
    )
    """Template that renders the observation produced by executing an action."""
    multimodal_regex: str = ""
    """Regex used to extract multimodal content; an empty string turns multimodal processing off."""
|
| 40 |
|
| 41 |
|
class OpenRouterAPIError(Exception):
    """Signals a failed OpenRouter call: an HTTP error other than 401/429, or a request-level failure."""
|
| 44 |
|
| 45 |
|
class OpenRouterAuthenticationError(Exception):
    """Signals that OpenRouter rejected the API key (HTTP 401)."""
|
| 48 |
|
| 49 |
|
class OpenRouterRateLimitError(Exception):
    """Signals that OpenRouter answered with HTTP 429 (rate limit exceeded)."""
|
| 52 |
|
| 53 |
|
class OpenRouterModel:
    """Client for the OpenRouter chat-completions endpoint that offers the bash tool to the model.

    The API key is read from the ``OPENROUTER_API_KEY`` environment variable when
    the instance is created. All other settings live in ``self.config``
    (an ``OpenRouterModelConfig``).
    """

    # Handed to `retry` as `abort_exceptions` (presumably: never retried).
    # KeyboardInterrupt derives from BaseException, hence the wider element type.
    abort_exceptions: list[type[BaseException]] = [OpenRouterAuthenticationError, KeyboardInterrupt]

    def __init__(self, **kwargs):
        """Build the config from ``kwargs`` and read the API key from the environment."""
        self.config = OpenRouterModelConfig(**kwargs)
        self._api_url = "https://openrouter.ai/api/v1/chat/completions"
        self._api_key = os.getenv("OPENROUTER_API_KEY", "")

    def _query(self, messages: list[dict[str, str]], **kwargs):
        """Send one chat-completion request and return the decoded JSON body.

        Args:
            messages: Messages to send, already prepared for the API.
            **kwargs: Extra payload entries; they override ``config.model_kwargs``.

        Returns:
            The parsed JSON response.

        Raises:
            OpenRouterAuthenticationError: On HTTP 401.
            OpenRouterRateLimitError: On HTTP 429.
            OpenRouterAPIError: On any other HTTP error status or request failure.
        """
        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }

        payload = {
            "model": self.config.model_name,
            "messages": messages,
            "tools": [BASH_TOOL],
            # Request usage data so the response carries the `usage` object
            # that `_calculate_cost` reads.
            "usage": {"include": True},
            # Spread last: configured and per-call kwargs may override the keys above.
            **(self.config.model_kwargs | kwargs),
        }

        try:
            response = requests.post(self._api_url, headers=headers, data=json.dumps(payload), timeout=60)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.HTTPError as e:
            # HTTPError is only raised by raise_for_status(), so `response` is bound here.
            if response.status_code == 401:
                error_msg = "Authentication failed. You can permanently set your API key with `mini-extra config set OPENROUTER_API_KEY YOUR_KEY`."
                raise OpenRouterAuthenticationError(error_msg) from e
            elif response.status_code == 429:
                raise OpenRouterRateLimitError("Rate limit exceeded") from e
            else:
                raise OpenRouterAPIError(f"HTTP {response.status_code}: {response.text}") from e
        except requests.exceptions.RequestException as e:
            raise OpenRouterAPIError(f"Request failed: {e}") from e

    def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
        """Return copies of ``messages`` without the ``extra`` key, ready to be sent.

        The copies are passed through ``_reorder_anthropic_thinking_blocks`` and then
        through ``set_cache_control`` using the configured mode.
        """
        prepared = [{k: v for k, v in msg.items() if k != "extra"} for msg in messages]
        prepared = _reorder_anthropic_thinking_blocks(prepared)
        return set_cache_control(prepared, mode=self.config.set_cache_control)

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        """Query the model and return its reply message with bookkeeping attached.

        Only the HTTP request runs inside the ``retry`` attempts. Cost accounting
        and action parsing happen once, on the final response.

        Args:
            messages: Conversation history.
            **kwargs: Extra payload entries forwarded to ``_query``.

        Returns:
            A copy of the first choice's message with an ``extra`` dict holding the
            parsed ``actions``, the raw ``response``, the ``cost`` and a ``timestamp``.

        Raises:
            RuntimeError: If no positive cost is reported and cost errors are not ignored.
        """
        for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
            with attempt:
                response = self._query(self._prepare_messages_for_api(messages), **kwargs)
        cost_output = self._calculate_cost(response)
        GLOBAL_MODEL_STATS.add(cost_output["cost"])
        message = dict(response["choices"][0]["message"])
        message["extra"] = {
            "actions": self._parse_actions(response),
            "response": response,
            **cost_output,
            "timestamp": time.time(),
        }
        return message

    def _calculate_cost(self, response) -> dict[str, float]:
        """Extract the cost from ``response["usage"]["cost"]``.

        A missing or null ``usage`` object, or a missing or null ``cost``, counts
        as a cost of ``0.0``.

        Returns:
            ``{"cost": <cost>}``.

        Raises:
            RuntimeError: If the cost is not greater than zero and
                ``config.cost_tracking`` is not ``"ignore_errors"``.
        """
        # `or` also covers explicit JSON nulls, which `.get(key, default)` lets through.
        usage = response.get("usage") or {}
        cost = usage.get("cost") or 0.0
        if cost <= 0.0 and self.config.cost_tracking != "ignore_errors":
            raise RuntimeError(
                f"No valid cost information available from OpenRouter API for model {self.config.model_name}: "
                f"Usage {usage}, cost {cost}. Cost must be > 0.0. Set cost_tracking: 'ignore_errors' in your config file or "
                "export MSWEA_COST_TRACKING='ignore_errors' to ignore cost tracking errors "
                "(for example for free/local models), more information at https://klieret.short.gy/mini-local-models "
                "for more details. Still stuck? Please open a github issue at https://github.com/SWE-agent/mini-swe-agent/issues/new/choose!"
            )
        return {"cost": cost}

    def _parse_actions(self, response: dict) -> list[dict]:
        """Parse tool calls from the response. Raises FormatError if unknown tool."""
        # A missing or null "tool_calls" entry is treated as "no tool calls".
        tool_calls = response["choices"][0]["message"].get("tool_calls") or []
        # parse_toolcall_actions receives attribute-style objects, so wrap the plain dicts.
        tool_calls = [_DictToObj(tc) for tc in tool_calls]
        return parse_toolcall_actions(tool_calls, format_error_template=self.config.format_error_template)

    def format_message(self, **kwargs) -> dict:
        """Turn ``kwargs`` into a message via ``expand_multimodal_content`` and the configured regex."""
        return expand_multimodal_content(kwargs, pattern=self.config.multimodal_regex)

    def format_observation_messages(
        self, message: dict, outputs: list[dict], template_vars: dict | None = None
    ) -> list[dict]:
        """Format execution outputs into tool result messages."""
        # The actions were stored on the message by `query`; default to none.
        actions = message.get("extra", {}).get("actions", [])
        return format_toolcall_observation_messages(
            actions=actions,
            outputs=outputs,
            observation_template=self.config.observation_template,
            template_vars=template_vars,
            multimodal_regex=self.config.multimodal_regex,
        )

    def get_template_vars(self, **kwargs) -> dict[str, Any]:
        """Return the config as a dict; ``kwargs`` are accepted but unused."""
        return self.config.model_dump()

    def serialize(self) -> dict:
        """Return the JSON-mode config dump and this class's dotted path under ``info.config``."""
        return {
            "info": {
                "config": {
                    "model": self.config.model_dump(mode="json"),
                    "model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
                },
            }
        }
|
| 158 |
|
| 159 |
|
| 160 | class _DictToObj:
|
| 161 | """Simple wrapper to convert dict to object with attribute access."""
|
| 162 |
|
| 163 | def __init__(self, d: dict):
|
| 164 | self._d = d
|
| 165 | self.id = d.get("id")
|
| 166 | self.function = _DictToObj(d.get("function", {})) if "function" in d else None
|
| 167 | self.name = d.get("name")
|
| 168 | self.arguments = d.get("arguments")
|
| 169 |
|