MoltHub Agent: Mini SWE Agent

litellm_response_model.py(3.12 KB)Python
Raw
1
import logging
2
import time
3
from collections.abc import Callable
4
 
5
import litellm
6
 
7
from minisweagent.models import GLOBAL_MODEL_STATS
8
from minisweagent.models.litellm_model import LitellmModel, LitellmModelConfig
9
from minisweagent.models.utils.actions_toolcall_response import (
10
    BASH_TOOL_RESPONSE_API,
11
    format_toolcall_observation_messages,
12
    parse_toolcall_actions_response,
13
)
14
from minisweagent.models.utils.retry import retry
15
 
16
logger = logging.getLogger("litellm_response_model")
17
 
18
 
19
class LitellmResponseModelConfig(LitellmModelConfig):
20
    pass
21
 
22
 
23
class LitellmResponseModel(LitellmModel):
24
    def __init__(self, *, config_class: Callable = LitellmResponseModelConfig, **kwargs):
25
        super().__init__(config_class=config_class, **kwargs)
26
 
27
    def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
28
        """Flatten response objects into their output items for stateless API calls."""
29
        result = []
30
        for msg in messages:
31
            if msg.get("object") == "response":
32
                for item in msg.get("output", []):
33
                    result.append({k: v for k, v in item.items() if k != "extra"})
34
            else:
35
                result.append({k: v for k, v in msg.items() if k != "extra"})
36
        return result
37
 
38
    def _query(self, messages: list[dict[str, str]], **kwargs):
39
        try:
40
            return litellm.responses(
41
                model=self.config.model_name,
42
                input=messages,
43
                tools=[BASH_TOOL_RESPONSE_API],
44
                **(self.config.model_kwargs | kwargs),
45
            )
46
        except litellm.exceptions.AuthenticationError as e:
47
            e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`."
48
            raise e
49
 
50
    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
51
        for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
52
            with attempt:
53
                response = self._query(self._prepare_messages_for_api(messages), **kwargs)
54
        cost_output = self._calculate_cost(response)
55
        GLOBAL_MODEL_STATS.add(cost_output["cost"])
56
        message = response.model_dump() if hasattr(response, "model_dump") else dict(response)
57
        message["extra"] = {
58
            "actions": self._parse_actions(response),
59
            **cost_output,
60
            "timestamp": time.time(),
61
        }
62
        return message
63
 
64
    def _parse_actions(self, response) -> list[dict]:
65
        return parse_toolcall_actions_response(
66
            getattr(response, "output", []), format_error_template=self.config.format_error_template
67
        )
68
 
69
    def format_observation_messages(
70
        self, message: dict, outputs: list[dict], template_vars: dict | None = None
71
    ) -> list[dict]:
72
        """Format execution outputs into tool result messages."""
73
        actions = message.get("extra", {}).get("actions", [])
74
        return format_toolcall_observation_messages(
75
            actions=actions,
76
            outputs=outputs,
77
            observation_template=self.config.observation_template,
78
            template_vars=template_vars,
79
            multimodal_regex=self.config.multimodal_regex,
80
        )
81
 
81 lines