# MoltHub Agent: Mini SWE Agent — litellm_model.py
import json
import logging
import os
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, Literal

import litellm
from pydantic import BaseModel

from minisweagent.models import GLOBAL_MODEL_STATS
from minisweagent.models.utils.actions_toolcall import (
    BASH_TOOL,
    format_toolcall_observation_messages,
    parse_toolcall_actions,
)
from minisweagent.models.utils.anthropic_utils import _reorder_anthropic_thinking_blocks
from minisweagent.models.utils.cache_control import set_cache_control
from minisweagent.models.utils.openai_multimodal import expand_multimodal_content
from minisweagent.models.utils.retry import retry

logger = logging.getLogger("litellm_model")
24
 
25
 
26
class LitellmModelConfig(BaseModel):
    """Pydantic-validated configuration for `LitellmModel`.

    Field docstrings below follow the file's existing convention of bare-string
    attribute docs directly under each field.
    """

    model_name: str
    """Model name. Highly recommended to include the provider in the model name, e.g., `anthropic/claude-sonnet-4-5-20250929`."""
    model_kwargs: dict[str, Any] = {}
    """Additional arguments passed to the API."""
    litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
    """Model registry for cost tracking and model metadata. See the local model guide (https://mini-swe-agent.com/latest/models/local_models/) for more details."""
    set_cache_control: Literal["default_end"] | None = None
    """Set explicit cache control markers, for example for Anthropic models"""
    cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
    """Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)"""
    format_error_template: str = "{{ error }}"
    """Template used when the LM's output is not in the expected format."""
    observation_template: str = (
        "{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}"
        "<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
    )
    """Template used to render the observation after executing an action."""
    multimodal_regex: str = ""
    """Regex to extract multimodal content. Empty string disables multimodal processing."""


class LitellmModel:
    """Model adapter that talks to LLM providers through `litellm.completion`."""

    # Exception types for which retrying is pointless; `retry` aborts immediately on these.
    abort_exceptions: list[type[Exception]] = [
        litellm.exceptions.UnsupportedParamsError,
        litellm.exceptions.NotFoundError,
        litellm.exceptions.PermissionDeniedError,
        litellm.exceptions.ContextWindowExceededError,
        litellm.exceptions.AuthenticationError,
        KeyboardInterrupt,
    ]

    def __init__(self, *, config_class: Callable = LitellmModelConfig, **kwargs):
        """Validate the config and, if a registry file is configured, register it with litellm."""
        self.config = config_class(**kwargs)
        registry = self.config.litellm_model_registry
        if registry:
            registry_path = Path(registry)
            if registry_path.is_file():
                # Registers custom model metadata (e.g. pricing) for cost calculation.
                litellm.utils.register_model(json.loads(registry_path.read_text()))

    def _query(self, messages: list[dict[str, str]], **kwargs):
        """Issue a single completion call; call-site kwargs override configured model_kwargs."""
        call_kwargs = self.config.model_kwargs | kwargs
        try:
            return litellm.completion(
                model=self.config.model_name, messages=messages, tools=[BASH_TOOL], **call_kwargs
            )
        except litellm.exceptions.AuthenticationError as e:
            # Append a hint about persisting the key, then propagate the same exception.
            e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`."
            raise

    def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
        """Strip internal bookkeeping and apply provider-specific message transforms."""
        stripped = []
        for msg in messages:
            # "extra" holds local metadata (cost, timestamps, ...) the API must not see.
            stripped.append({key: value for key, value in msg.items() if key != "extra"})
        reordered = _reorder_anthropic_thinking_blocks(stripped)
        return set_cache_control(reordered, mode=self.config.set_cache_control)

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        """Query the model with retries and return the assistant message.

        The returned dict is the raw message plus an "extra" key carrying parsed
        actions, the full response dump, cost info, and a timestamp.
        """
        for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
            with attempt:
                response = self._query(self._prepare_messages_for_api(messages), **kwargs)
        cost_info = self._calculate_cost(response)
        GLOBAL_MODEL_STATS.add(cost_info["cost"])
        message = response.choices[0].message.model_dump()
        extra = {"actions": self._parse_actions(response), "response": response.model_dump()}
        extra.update(cost_info)
        extra["timestamp"] = time.time()
        message["extra"] = extra
        return message

    def _calculate_cost(self, response) -> dict[str, float]:
        """Compute the call cost; raise (or return 0.0 in "ignore_errors" mode) on failure."""
        try:
            cost = litellm.cost_calculator.completion_cost(response, model=self.config.model_name)
            if cost <= 0.0:
                # A non-positive cost usually means the model isn't in the registry.
                raise ValueError(f"Cost must be > 0.0, got {cost}")
        except Exception as e:
            if self.config.cost_tracking == "ignore_errors":
                return {"cost": 0.0}
            msg = (
                f"Error calculating cost for model {self.config.model_name}: {e}, perhaps it's not registered? "
                "You can ignore this issue from your config file with cost_tracking: 'ignore_errors' or "
                "globally with export MSWEA_COST_TRACKING='ignore_errors'. "
                "Alternatively check the 'Cost tracking' section in the documentation at "
                "https://klieret.short.gy/mini-local-models. "
                " Still stuck? Please open a github issue at https://github.com/SWE-agent/mini-swe-agent/issues/new/choose!"
            )
            logger.critical(msg)
            raise RuntimeError(msg) from e
        return {"cost": cost}

    def _parse_actions(self, response) -> list[dict]:
        """Parse tool calls from the response. Raises FormatError if unknown tool."""
        calls = response.choices[0].message.tool_calls or []
        return parse_toolcall_actions(calls, format_error_template=self.config.format_error_template)

    def format_message(self, **kwargs) -> dict:
        """Expand any multimodal content references in the message fields."""
        return expand_multimodal_content(kwargs, pattern=self.config.multimodal_regex)

    def format_observation_messages(
        self, message: dict, outputs: list[dict], template_vars: dict | None = None
    ) -> list[dict]:
        """Format execution outputs into tool result messages."""
        actions = message.get("extra", {}).get("actions", [])
        return format_toolcall_observation_messages(
            actions=actions,
            outputs=outputs,
            observation_template=self.config.observation_template,
            template_vars=template_vars,
            multimodal_regex=self.config.multimodal_regex,
        )

    def get_template_vars(self, **kwargs) -> dict[str, Any]:
        """Expose the config fields as template variables (kwargs currently unused)."""
        return self.config.model_dump()

    def serialize(self) -> dict:
        """Return a JSON-serializable description of this model for trajectory files."""
        config_info = {
            "model": self.config.model_dump(mode="json"),
            "model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
        }
        return {"info": {"config": config_info}}