| 1 | import json
|
| 2 | import logging
|
| 3 |
|
| 4 | import requests
|
| 5 |
|
| 6 | from minisweagent.models.openrouter_model import (
|
| 7 | OpenRouterAPIError,
|
| 8 | OpenRouterAuthenticationError,
|
| 9 | OpenRouterModel,
|
| 10 | OpenRouterModelConfig,
|
| 11 | OpenRouterRateLimitError,
|
| 12 | )
|
| 13 | from minisweagent.models.utils.actions_text import format_observation_messages, parse_regex_actions
|
| 14 |
|
| 15 | logger = logging.getLogger("openrouter_textbased_model")
|
| 16 |
|
| 17 |
|
class OpenRouterTextbasedModelConfig(OpenRouterModelConfig):
    """Settings for the text-based OpenRouter model.

    Extends ``OpenRouterModelConfig`` with the two values used when pulling an
    action out of the model's plain-text reply: the extraction regex and the
    message template used when the reply is not in the expected format.
    """

    # Captures the body of a fenced block tagged ``mswea_bash_command``.
    action_regex: str = r"```mswea_bash_command\s*\n(.*?)\n```"
    """Regex to extract the action from the LM's output."""

    # ``{{actions|length}}`` is left as literal template syntax here; it is
    # rendered later by whatever consumes this template.
    format_error_template: str = (
        "Please always provide EXACTLY ONE action in triple backticks, "
        "found {{actions|length}} actions."
    )
    """Template used when the LM's output is not in the expected format."""
|
| 25 |
|
| 26 |
|
class OpenRouterTextbasedModel(OpenRouterModel):
    """OpenRouter model that reads actions out of the plain-text reply.

    Actions are extracted from the message content with ``config.action_regex``
    via ``parse_regex_actions``.
    """

    def __init__(self, **kwargs):
        """Initialise the parent model, then swap in the text-based config.

        The same ``kwargs`` are handed to both the parent initialiser and
        ``OpenRouterTextbasedModelConfig``; the latter replaces ``self.config``.
        """
        super().__init__(**kwargs)
        self.config = OpenRouterTextbasedModelConfig(**kwargs)

    def _query(self, messages: list[dict[str, str]], **kwargs):
        """POST ``messages`` to ``self._api_url`` and return the decoded JSON body.

        Per-call ``kwargs`` are merged over ``config.model_kwargs`` and the
        result is spread into the request payload.

        Raises:
            OpenRouterAuthenticationError: the server answered with HTTP 401.
            OpenRouterRateLimitError: the server answered with HTTP 429.
            OpenRouterAPIError: any other HTTP error status, or any other
                ``requests`` exception raised while sending or decoding.
        """
        request_headers = {"Authorization": f"Bearer {self._api_key}", "Content-Type": "application/json"}

        body = {"model": self.config.model_name, "messages": messages, "usage": {"include": True}}
        # Applied last so configured / per-call options override the defaults above.
        body.update(self.config.model_kwargs | kwargs)

        try:
            response = requests.post(self._api_url, headers=request_headers, data=json.dumps(body), timeout=60)
            response.raise_for_status()
            return response.json()
        # HTTPError must be handled before its parent class RequestException.
        except requests.exceptions.HTTPError as http_err:
            status = response.status_code
            if status == 401:
                raise OpenRouterAuthenticationError(
                    "Authentication failed. You can permanently set your API key with `mini-extra config set OPENROUTER_API_KEY YOUR_KEY`."
                ) from http_err
            if status == 429:
                raise OpenRouterRateLimitError("Rate limit exceeded") from http_err
            raise OpenRouterAPIError(f"HTTP {status}: {response.text}") from http_err
        except requests.exceptions.RequestException as req_err:
            raise OpenRouterAPIError(f"Request failed: {req_err}") from req_err

    def _parse_actions(self, response: dict) -> list[dict]:
        """Parse actions from the first choice of the model response.

        A missing (falsy) message content is treated as an empty string.
        Parsing is delegated to ``parse_regex_actions``, which is expected to
        raise FormatError if there is not exactly one action.
        """
        first_message = response["choices"][0]["message"]
        text = first_message["content"] or ""
        cfg = self.config
        return parse_regex_actions(text, action_regex=cfg.action_regex, format_error_template=cfg.format_error_template)

    def format_observation_messages(
        self, message: dict, outputs: list[dict], template_vars: dict | None = None
    ) -> list[dict]:
        """Turn execution outputs into observation messages.

        Delegates to the module-level ``format_observation_messages`` helper
        using the observation template and multimodal regex from the config.
        ``message`` is accepted but not used by this implementation.
        """
        render_options = {
            "observation_template": self.config.observation_template,
            "template_vars": template_vars,
            "multimodal_regex": self.config.multimodal_regex,
        }
        return format_observation_messages(outputs, **render_options)
|
| 77 |
|