| 1 | import logging
|
| 2 | import time
|
| 3 | from collections.abc import Callable
|
| 4 |
|
| 5 | import litellm
|
| 6 |
|
| 7 | from minisweagent.models import GLOBAL_MODEL_STATS
|
| 8 | from minisweagent.models.litellm_model import LitellmModel, LitellmModelConfig
|
| 9 | from minisweagent.models.utils.actions_toolcall_response import (
|
| 10 | BASH_TOOL_RESPONSE_API,
|
| 11 | format_toolcall_observation_messages,
|
| 12 | parse_toolcall_actions_response,
|
| 13 | )
|
| 14 | from minisweagent.models.utils.retry import retry
|
| 15 |
|
# Module-level logger; it is handed to `retry(...)` in LitellmResponseModel.query
# (presumably so retry attempts/failures are logged under this name).
logger = logging.getLogger("litellm_response_model")
|
| 17 |
|
| 18 |
|
class LitellmResponseModelConfig(LitellmModelConfig):
    """Configuration for ``LitellmResponseModel``.

    Declares no fields of its own: every setting is inherited unchanged from
    ``LitellmModelConfig``. It is the default ``config_class`` used by
    ``LitellmResponseModel.__init__``.
    """
|
| 21 |
|
| 22 |
|
class LitellmResponseModel(LitellmModel):
    """``LitellmModel`` variant that queries the model through ``litellm.responses``.

    Differences from the parent visible in this class:

    * requests go through ``litellm.responses(input=...)`` with the bash tool given in
      its Responses-API schema (``BASH_TOOL_RESPONSE_API``);
    * the message returned by :meth:`query` is the whole dumped response object, which
      :meth:`_prepare_messages_for_api` flattens back into its output items on the next
      request;
    * actions are parsed from the response's ``output`` items.
    """

    def __init__(self, *, config_class: Callable = LitellmResponseModelConfig, **kwargs):
        super().__init__(config_class=config_class, **kwargs)

    def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
        """Flatten response objects into their output items for stateless API calls.

        A message whose ``"object"`` is ``"response"`` (what :meth:`query` stores) is
        expanded into the items of its ``"output"`` list; any other message is passed
        through as a single item. Every emitted item is a shallow copy with the
        ``"extra"`` bookkeeping key removed.

        Args:
            messages: Conversation history as stored by the agent.

        Returns:
            A new list of input items; *messages* itself is not modified.
        """
        result: list[dict] = []
        for msg in messages:
            if msg.get("object") == "response":
                # `or []` also covers a dumped response whose "output" is None,
                # which would otherwise raise TypeError when iterated.
                items = msg.get("output") or []
            else:
                items = [msg]
            result.extend({k: v for k, v in item.items() if k != "extra"} for item in items)
        return result

    def _query(self, messages: list[dict], **kwargs):
        """Send a single Responses-API request with the bash tool attached.

        Args:
            messages: Already-flattened input items (see :meth:`_prepare_messages_for_api`).
            **kwargs: Per-call options merged over ``config.model_kwargs`` (per-call wins).

        Returns:
            Whatever ``litellm.responses`` returns.

        Raises:
            litellm.exceptions.AuthenticationError: Re-raised after appending a hint on how
                to store an API key to its message.
        """
        try:
            return litellm.responses(
                model=self.config.model_name,
                input=messages,
                tools=[BASH_TOOL_RESPONSE_API],
                **(self.config.model_kwargs | kwargs),
            )
        except litellm.exceptions.AuthenticationError as e:
            e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`."
            # Bare `raise` re-raises the same exception with its original traceback.
            raise

    def query(self, messages: list[dict], **kwargs) -> dict:
        """Query the model and return the response as a storable message dict.

        Only the network call is retried. Cost accounting and action parsing run once, on
        the successful response, so the global cost is recorded exactly once and a parse
        error is neither retried nor billed again.

        Args:
            messages: Conversation history; flattened via :meth:`_prepare_messages_for_api`.
            **kwargs: Forwarded to :meth:`_query`.

        Returns:
            The dumped response with an ``"extra"`` dict holding the parsed ``"actions"``,
            the fields returned by ``_calculate_cost`` and a ``"timestamp"``.
        """
        # Deterministic, so compute it once instead of on every retry attempt.
        api_messages = self._prepare_messages_for_api(messages)
        for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
            with attempt:
                response = self._query(api_messages, **kwargs)
        cost_output = self._calculate_cost(response)
        GLOBAL_MODEL_STATS.add(cost_output["cost"])
        # The whole response object is stored as the message; it is flattened back into
        # its output items when the history is sent on the next request.
        message = response.model_dump() if hasattr(response, "model_dump") else dict(response)
        message["extra"] = {
            "actions": self._parse_actions(response),
            **cost_output,
            "timestamp": time.time(),
        }
        return message

    def _parse_actions(self, response) -> list[dict]:
        """Parse actions from the response's ``output`` items.

        A response without an ``output`` attribute is treated as having no output items.
        Parsing (and formatting of errors with ``config.format_error_template``) is
        delegated to ``parse_toolcall_actions_response``.
        """
        return parse_toolcall_actions_response(
            getattr(response, "output", []), format_error_template=self.config.format_error_template
        )

    def format_observation_messages(
        self, message: dict, outputs: list[dict], template_vars: dict | None = None
    ) -> list[dict]:
        """Format execution outputs into tool result messages.

        Args:
            message: A message returned by :meth:`query`; its ``extra["actions"]`` are
                paired with *outputs* (missing keys are treated as no actions).
            outputs: Execution results, presumably one per action in the same order —
                confirm against ``format_toolcall_observation_messages``.
            template_vars: Extra variables for rendering ``config.observation_template``.

        Returns:
            The messages built by ``format_toolcall_observation_messages``.
        """
        actions = message.get("extra", {}).get("actions", [])
        return format_toolcall_observation_messages(
            actions=actions,
            outputs=outputs,
            observation_template=self.config.observation_template,
            template_vars=template_vars,
            multimodal_regex=self.config.multimodal_regex,
        )
|
| 81 |
|