# portkey_model.py — MoltHub Agent: Mini SWE Agent (Portkey gateway model adapter)
1
import json
2
import logging
3
import os
4
import time
5
from pathlib import Path
6
from typing import Any, Literal
7
 
8
import litellm
9
from pydantic import BaseModel
10
 
11
from minisweagent.models import GLOBAL_MODEL_STATS
12
from minisweagent.models.utils.actions_toolcall import (
13
    BASH_TOOL,
14
    format_toolcall_observation_messages,
15
    parse_toolcall_actions,
16
)
17
from minisweagent.models.utils.anthropic_utils import _reorder_anthropic_thinking_blocks
18
from minisweagent.models.utils.cache_control import set_cache_control
19
from minisweagent.models.utils.openai_multimodal import expand_multimodal_content
20
from minisweagent.models.utils.retry import retry
21
 
22
logger = logging.getLogger("portkey_model")

try:
    from portkey_ai import Portkey
except ImportError as e:
    # Re-raise with an actionable install hint, chaining the original error
    # (PEP 3134) so the real import failure stays visible in the traceback.
    raise ImportError(
        "The portkey-ai package is required to use PortkeyModel. Please install it with: pip install portkey-ai"
    ) from e
30
 
31
 
32
class PortkeyModelConfig(BaseModel):
    """Configuration for ``PortkeyModel``.

    Fields are set via keyword arguments at construction time; several
    defaults fall back to environment variables read at import time.
    """

    model_name: str  # Portkey model identifier sent with every completion request.
    model_kwargs: dict[str, Any] = {}  # Extra kwargs merged into each chat.completions.create call.
    provider: str = ""
    """The LLM provider to use (e.g., 'openai', 'anthropic', 'google').
    If not specified, will be auto-detected from model_name.
    Required by Portkey when not using a virtual key.
    """
    # NOTE(review): provider auto-detection is not implemented in this file —
    # presumably it happens on the Portkey side; confirm.
    litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
    """We currently use litellm to calculate costs. Here you can register additional models to litellm's model registry.
    Note that this might change if we get better support for Portkey and change how we calculate costs.
    """
    litellm_model_name_override: str = ""
    """We currently use litellm to calculate costs. Here you can override the model name to use for litellm in case it
    doesn't match the Portkey model name.
    Note that this might change if we get better support for Portkey and change how we calculate costs.
    """
    set_cache_control: Literal["default_end"] | None = None
    """Set explicit cache control markers, for example for Anthropic models"""
    cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
    """Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)"""
    format_error_template: str = "{{ error }}"
    """Template used when the LM's output is not in the expected format."""
    observation_template: str = (
        "{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}"
        "<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
    )
    """Template used to render the observation after executing an action."""
    multimodal_regex: str = ""
    """Regex to extract multimodal content. Empty string disables multimodal processing."""
62
 
63
 
64
class PortkeyModel:
    """Chat model that routes completion requests through the Portkey AI gateway.

    Returned messages are plain dicts (from the API response) augmented with an
    ``"extra"`` key holding parsed tool-call actions, the raw response dump,
    cost information, and a timestamp. Costs are computed with litellm; see
    ``PortkeyModelConfig`` for the knobs influencing cost calculation.
    """

    # Exceptions that abort the retry loop in query() immediately instead of
    # being retried.
    abort_exceptions: list[type[Exception]] = [KeyboardInterrupt, TypeError, ValueError]

    def __init__(self, *, config_class: type = PortkeyModelConfig, **kwargs):
        """Validate the config, check credentials, and build the Portkey client.

        Raises:
            ValueError: if the PORTKEY_API_KEY environment variable is not set.
        """
        self.config = config_class(**kwargs)
        if self.config.litellm_model_registry and Path(self.config.litellm_model_registry).is_file():
            # Teach litellm about additional models so cost calculation works for them.
            litellm.utils.register_model(json.loads(Path(self.config.litellm_model_registry).read_text()))

        self._api_key = os.getenv("PORTKEY_API_KEY")
        if not self._api_key:
            raise ValueError(
                "Portkey API key is required. Set it via the "
                "PORTKEY_API_KEY environment variable. You can permanently set it with "
                "`mini-extra config set PORTKEY_API_KEY YOUR_KEY`."
            )

        virtual_key = os.getenv("PORTKEY_VIRTUAL_KEY")
        client_kwargs = {"api_key": self._api_key}
        if virtual_key:
            # A virtual key takes precedence over an explicitly configured provider.
            client_kwargs["virtual_key"] = virtual_key
        elif self.config.provider:
            # If no virtual key but provider is specified, pass it
            client_kwargs["provider"] = self.config.provider

        self.client = Portkey(**client_kwargs)

    def _query(self, messages: list[dict[str, str]], **kwargs):
        """Issue a single chat completion request (no retries, no bookkeeping)."""
        return self.client.chat.completions.create(
            model=self.config.model_name,
            messages=messages,
            tools=[BASH_TOOL],
            # Per-call kwargs override the configured model_kwargs defaults.
            **(self.config.model_kwargs | kwargs),
        )

    def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
        """Strip internal metadata and apply provider-specific message fixups."""
        # "extra" is our own bookkeeping and must not be sent to the API.
        prepared = [{k: v for k, v in msg.items() if k != "extra"} for msg in messages]
        prepared = _reorder_anthropic_thinking_blocks(prepared)
        return set_cache_control(prepared, mode=self.config.set_cache_control)

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        """Query the model with retries; return the assistant message dict.

        The result is the API message plus an ``"extra"`` key carrying parsed
        actions, the full raw response, cost info, and a timestamp. Also
        records the cost in GLOBAL_MODEL_STATS.
        """
        for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
            with attempt:
                response = self._query(self._prepare_messages_for_api(messages), **kwargs)
        cost_output = self._calculate_cost(response)
        GLOBAL_MODEL_STATS.add(cost_output["cost"])
        message = response.choices[0].message.model_dump()
        message["extra"] = {
            "actions": self._parse_actions(response),
            "response": response.model_dump(),
            **cost_output,
            "timestamp": time.time(),
        }
        return message

    def _parse_actions(self, response) -> list[dict]:
        """Parse tool calls from the response. Raises FormatError if unknown tool."""
        tool_calls = response.choices[0].message.tool_calls or []
        return parse_toolcall_actions(tool_calls, format_error_template=self.config.format_error_template)

    def format_message(self, **kwargs) -> dict:
        """Build a message dict, expanding multimodal content if configured."""
        return expand_multimodal_content(kwargs, pattern=self.config.multimodal_regex)

    def format_observation_messages(
        self, message: dict, outputs: list[dict], template_vars: dict | None = None
    ) -> list[dict]:
        """Format execution outputs into tool result messages."""
        actions = message.get("extra", {}).get("actions", [])
        return format_toolcall_observation_messages(
            actions=actions,
            outputs=outputs,
            observation_template=self.config.observation_template,
            template_vars=template_vars,
            multimodal_regex=self.config.multimodal_regex,
        )

    def get_template_vars(self, **kwargs) -> dict[str, Any]:
        """Expose the model config fields as template variables."""
        return self.config.model_dump()

    def serialize(self) -> dict:
        """Serialize the config and class identity (e.g. for run logs)."""
        return {
            "info": {
                "config": {
                    "model": self.config.model_dump(mode="json"),
                    "model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
                },
            }
        }

    def _calculate_cost(self, response) -> dict[str, float]:
        """Compute the cost of one completion with litellm.

        Missing token counts are coerced to 0 (with a warning), and prompt
        tokens are reconciled against the reported total before pricing.

        Returns:
            ``{"cost": <float>}``; cost is 0.0 when calculation fails and
            ``cost_tracking`` is "ignore_errors".

        Raises:
            RuntimeError: if cost calculation fails and ``cost_tracking`` is
                not set to "ignore_errors".
        """
        response_for_cost_calc = response.model_copy()
        if self.config.litellm_model_name_override:
            if response_for_cost_calc.model:
                response_for_cost_calc.model = self.config.litellm_model_name_override
        prompt_tokens = response_for_cost_calc.usage.prompt_tokens
        if prompt_tokens is None:
            logger.warning(
                f"Prompt tokens are None for model {self.config.model_name}. Setting to 0. Full response: {response_for_cost_calc.model_dump()}"
            )
            prompt_tokens = 0
        total_tokens = response_for_cost_calc.usage.total_tokens
        completion_tokens = response_for_cost_calc.usage.completion_tokens
        if completion_tokens is None:
            logger.warning(
                f"Completion tokens are None for model {self.config.model_name}. Setting to 0. Full response: {response_for_cost_calc.model_dump()}"
            )
            completion_tokens = 0
        # Bug fix: total_tokens can be None just like the other counts (which
        # are guarded above); subtracting from None raised a TypeError, and
        # TypeError is in abort_exceptions, so it would abort the whole run.
        # When the total is missing we skip the reconciliation entirely.
        if total_tokens is not None and total_tokens - prompt_tokens - completion_tokens != 0:
            # This is most likely related to how portkey treats cached tokens: It doesn't count them towards the prompt tokens (?)
            logger.warning(
                f"WARNING: Total tokens - prompt tokens - completion tokens != 0: {response_for_cost_calc.model_dump()}."
                " This is probably a portkey bug or incompatibility with litellm cost tracking. "
                "Setting prompt tokens based on total tokens and completion tokens. You might want to double check your costs. "
                f"Full response: {response_for_cost_calc.model_dump()}"
            )
            response_for_cost_calc.usage.prompt_tokens = total_tokens - completion_tokens
        try:
            cost = litellm.cost_calculator.completion_cost(
                response_for_cost_calc, model=self.config.litellm_model_name_override or None
            )
            # Explicit raise instead of `assert` so the check survives `python -O`;
            # it is caught and handled by the except clause below like any other
            # cost-calculation failure.
            if cost < 0.0:
                raise ValueError(f"Cost is negative: {cost}")
        except Exception as e:
            cost = 0.0
            if self.config.cost_tracking != "ignore_errors":
                msg = (
                    f"Error calculating cost for model {self.config.model_name} based on {response_for_cost_calc.model_dump()}: {e}. "
                    "You can ignore this issue from your config file with cost_tracking: 'ignore_errors' or "
                    "globally with export MSWEA_COST_TRACKING='ignore_errors' to ignore this error. "
                    "Alternatively check the 'Cost tracking' section in the documentation at "
                    "https://klieret.short.gy/mini-local-models. "
                    "Still stuck? Please open a github issue at https://github.com/SWE-agent/mini-swe-agent/issues/new/choose!"
                )
                logger.critical(msg)
                raise RuntimeError(msg) from e
        return {"cost": cost}
198
 
198 lines