| 1 | import json
|
| 2 | import logging
|
| 3 | import os
|
| 4 | import time
|
| 5 | from pathlib import Path
|
| 6 | from typing import Any, Literal
|
| 7 |
|
| 8 | import litellm
|
| 9 | from pydantic import BaseModel
|
| 10 |
|
| 11 | from minisweagent.models import GLOBAL_MODEL_STATS
|
| 12 | from minisweagent.models.utils.actions_toolcall import (
|
| 13 | BASH_TOOL,
|
| 14 | format_toolcall_observation_messages,
|
| 15 | parse_toolcall_actions,
|
| 16 | )
|
| 17 | from minisweagent.models.utils.anthropic_utils import _reorder_anthropic_thinking_blocks
|
| 18 | from minisweagent.models.utils.cache_control import set_cache_control
|
| 19 | from minisweagent.models.utils.openai_multimodal import expand_multimodal_content
|
| 20 | from minisweagent.models.utils.retry import retry
|
| 21 |
|
logger = logging.getLogger("portkey_model")

try:
    from portkey_ai import Portkey
except ImportError as e:
    # Fail fast at import time with an actionable message; chain the original
    # exception so the underlying import failure remains visible in tracebacks.
    raise ImportError(
        "The portkey-ai package is required to use PortkeyModel. Please install it with: pip install portkey-ai"
    ) from e
|
| 30 |
|
| 31 |
|
class PortkeyModelConfig(BaseModel):
    """Configuration for `PortkeyModel`, validated by pydantic on construction."""

    # Model identifier passed to the Portkey chat completions API (required).
    model_name: str
    # Extra keyword arguments merged into every completion request; per-call
    # kwargs take precedence. NOTE: pydantic copies field defaults per
    # instance, so the mutable `{}` default is not shared between configs.
    model_kwargs: dict[str, Any] = {}
    provider: str = ""
    """The LLM provider to use (e.g., 'openai', 'anthropic', 'google').
    If not specified, will be auto-detected from model_name.
    Required by Portkey when not using a virtual key.
    """
    # NOTE: os.getenv is evaluated once at class-definition (import) time;
    # changing the environment variable afterwards has no effect.
    litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
    """We currently use litellm to calculate costs. Here you can register additional models to litellm's model registry.
    Note that this might change if we get better support for Portkey and change how we calculate costs.
    """
    litellm_model_name_override: str = ""
    """We currently use litellm to calculate costs. Here you can override the model name to use for litellm in case it
    doesn't match the Portkey model name.
    Note that this might change if we get better support for Portkey and change how we calculate costs.
    """
    set_cache_control: Literal["default_end"] | None = None
    """Set explicit cache control markers, for example for Anthropic models"""
    # Default read from the environment at import time (same caveat as above).
    cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
    """Cost tracking mode for this model. Can be "default" or "ignore_errors" (ignore errors/missing cost info)"""
    format_error_template: str = "{{ error }}"
    """Template used when the LM's output is not in the expected format."""
    observation_template: str = (
        "{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}"
        "<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
    )
    """Template used to render the observation after executing an action."""
    multimodal_regex: str = ""
    """Regex to extract multimodal content. Empty string disables multimodal processing."""
|
| 62 |
|
| 63 |
|
class PortkeyModel:
    """Language model backed by the Portkey AI gateway.

    Sends chat completion requests (with a single bash tool attached) through
    the Portkey client and uses litellm's pricing tables for cost accounting.
    """

    # Exceptions that abort the retry loop in `query` immediately instead of retrying.
    abort_exceptions: list[type[Exception]] = [KeyboardInterrupt, TypeError, ValueError]

    def __init__(self, *, config_class: type = PortkeyModelConfig, **kwargs):
        """Validate `kwargs` into `config_class` and create the Portkey client.

        Raises:
            ValueError: if the PORTKEY_API_KEY environment variable is not set.
        """
        self.config = config_class(**kwargs)
        # Optionally extend litellm's model registry so cost calculation works
        # for models litellm does not know about out of the box.
        if self.config.litellm_model_registry and Path(self.config.litellm_model_registry).is_file():
            litellm.utils.register_model(json.loads(Path(self.config.litellm_model_registry).read_text()))

        self._api_key = os.getenv("PORTKEY_API_KEY")
        if not self._api_key:
            raise ValueError(
                "Portkey API key is required. Set it via the "
                "PORTKEY_API_KEY environment variable. You can permanently set it with "
                "`mini-extra config set PORTKEY_API_KEY YOUR_KEY`."
            )

        virtual_key = os.getenv("PORTKEY_VIRTUAL_KEY")
        client_kwargs = {"api_key": self._api_key}
        if virtual_key:
            # A virtual key already encodes the provider, so it takes precedence.
            client_kwargs["virtual_key"] = virtual_key
        elif self.config.provider:
            # If no virtual key but provider is specified, pass it
            client_kwargs["provider"] = self.config.provider

        self.client = Portkey(**client_kwargs)

    def _query(self, messages: list[dict[str, str]], **kwargs):
        """Send one chat completion request. Per-call kwargs override config kwargs."""
        return self.client.chat.completions.create(
            model=self.config.model_name,
            messages=messages,
            tools=[BASH_TOOL],
            **(self.config.model_kwargs | kwargs),
        )

    def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
        """Strip internal bookkeeping and apply provider-specific message fixups."""
        # "extra" carries our own metadata (cost, parsed actions, ...) and must
        # not be forwarded to the API.
        prepared = [{k: v for k, v in msg.items() if k != "extra"} for msg in messages]
        prepared = _reorder_anthropic_thinking_blocks(prepared)
        return set_cache_control(prepared, mode=self.config.set_cache_control)

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        """Query the model (with retries) and return the assistant message dict.

        The returned message carries an "extra" key with the parsed actions,
        the full raw response, cost information, and a timestamp.
        """
        for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
            with attempt:
                response = self._query(self._prepare_messages_for_api(messages), **kwargs)
        cost_output = self._calculate_cost(response)
        GLOBAL_MODEL_STATS.add(cost_output["cost"])
        message = response.choices[0].message.model_dump()
        message["extra"] = {
            "actions": self._parse_actions(response),
            "response": response.model_dump(),
            **cost_output,
            "timestamp": time.time(),
        }
        return message

    def _parse_actions(self, response) -> list[dict]:
        """Parse tool calls from the response. Raises FormatError if unknown tool."""
        tool_calls = response.choices[0].message.tool_calls or []
        return parse_toolcall_actions(tool_calls, format_error_template=self.config.format_error_template)

    def format_message(self, **kwargs) -> dict:
        """Build a message dict, expanding multimodal content if a regex is configured."""
        return expand_multimodal_content(kwargs, pattern=self.config.multimodal_regex)

    def format_observation_messages(
        self, message: dict, outputs: list[dict], template_vars: dict | None = None
    ) -> list[dict]:
        """Format execution outputs into tool result messages."""
        actions = message.get("extra", {}).get("actions", [])
        return format_toolcall_observation_messages(
            actions=actions,
            outputs=outputs,
            observation_template=self.config.observation_template,
            template_vars=template_vars,
            multimodal_regex=self.config.multimodal_regex,
        )

    def get_template_vars(self, **kwargs) -> dict[str, Any]:
        """Expose the config fields as template variables."""
        return self.config.model_dump()

    def serialize(self) -> dict:
        """Serialize the model config (plus its import path) for trajectory files."""
        return {
            "info": {
                "config": {
                    "model": self.config.model_dump(mode="json"),
                    "model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
                },
            }
        }

    def _calculate_cost(self, response) -> dict[str, float]:
        """Calculate the cost of `response` using litellm's pricing tables.

        Returns:
            A dict with a single "cost" key (USD, 0.0 on ignored errors).

        Raises:
            RuntimeError: if cost calculation fails and `config.cost_tracking`
                is not "ignore_errors".
        """
        # Deep copy: we may patch `model` and `usage.prompt_tokens` below, and a
        # shallow copy shares the nested usage object with the original response
        # that `query` serializes into message["extra"]["response"].
        response_for_cost_calc = response.model_copy(deep=True)
        if self.config.litellm_model_name_override:
            if response_for_cost_calc.model:
                response_for_cost_calc.model = self.config.litellm_model_name_override
        prompt_tokens = response_for_cost_calc.usage.prompt_tokens
        if prompt_tokens is None:
            logger.warning(
                f"Prompt tokens are None for model {self.config.model_name}. Setting to 0. Full response: {response_for_cost_calc.model_dump()}"
            )
            prompt_tokens = 0
        completion_tokens = response_for_cost_calc.usage.completion_tokens
        if completion_tokens is None:
            logger.warning(
                f"Completion tokens are None for model {self.config.model_name}. Setting to 0. Full response: {response_for_cost_calc.model_dump()}"
            )
            completion_tokens = 0
        total_tokens = response_for_cost_calc.usage.total_tokens
        if total_tokens is None:
            # Guard like the other counters; otherwise the consistency check
            # below raises TypeError (which is in abort_exceptions).
            logger.warning(
                f"Total tokens are None for model {self.config.model_name}. Setting to prompt + completion tokens. Full response: {response_for_cost_calc.model_dump()}"
            )
            total_tokens = prompt_tokens + completion_tokens
        if total_tokens - prompt_tokens - completion_tokens != 0:
            # This is most likely related to how portkey treats cached tokens: It doesn't count them towards the prompt tokens (?)
            logger.warning(
                f"WARNING: Total tokens - prompt tokens - completion tokens != 0: {response_for_cost_calc.model_dump()}."
                " This is probably a portkey bug or incompatibility with litellm cost tracking. "
                "Setting prompt tokens based on total tokens and completion tokens. You might want to double check your costs. "
                f"Full response: {response_for_cost_calc.model_dump()}"
            )
            response_for_cost_calc.usage.prompt_tokens = total_tokens - completion_tokens
        try:
            cost = litellm.cost_calculator.completion_cost(
                response_for_cost_calc, model=self.config.litellm_model_name_override or None
            )
            assert cost >= 0.0, f"Cost is negative: {cost}"
        except Exception as e:
            cost = 0.0
            if self.config.cost_tracking != "ignore_errors":
                msg = (
                    f"Error calculating cost for model {self.config.model_name} based on {response_for_cost_calc.model_dump()}: {e}. "
                    "You can ignore this issue from your config file with cost_tracking: 'ignore_errors' or "
                    "globally with export MSWEA_COST_TRACKING='ignore_errors' to ignore this error. "
                    "Alternatively check the 'Cost tracking' section in the documentation at "
                    "https://klieret.short.gy/mini-local-models. "
                    "Still stuck? Please open a github issue at https://github.com/SWE-agent/mini-swe-agent/issues/new/choose!"
                )
                logger.critical(msg)
                raise RuntimeError(msg) from e
        return {"cost": cost}
|
| 198 |
|