| 1 | """This file provides convenience functions for selecting models.
|
| 2 | You can ignore this file completely if you explicitly set your model in your run script.
|
| 3 | """
|
| 4 |
|
| 5 | import copy
|
| 6 | import importlib
|
| 7 | import os
|
| 8 | import threading
|
| 9 |
|
| 10 | from minisweagent import Model
|
| 11 |
|
| 12 |
|
class GlobalModelStats:
    """Process-wide tracker of model call counts and cost, with optional hard limits."""

    def __init__(self):
        """Read limits from environment variables; a value of 0 (the default) disables that limit."""
        self._cost = 0.0
        self._n_calls = 0
        self._lock = threading.Lock()
        self.cost_limit = float(os.getenv("MSWEA_GLOBAL_COST_LIMIT", "0"))
        self.call_limit = int(os.getenv("MSWEA_GLOBAL_CALL_LIMIT", "0"))
        limits_active = self.cost_limit > 0 or self.call_limit > 0
        if limits_active and not os.getenv("MSWEA_SILENT_STARTUP"):
            print(f"Global cost/call limit: ${self.cost_limit:.4f} / {self.call_limit}")

    def add(self, cost: float) -> None:
        """Add a model call with its cost, checking limits."""
        with self._lock:
            self._cost += cost
            self._n_calls += 1
            # Cost raises only once the limit is strictly exceeded; the call check
            # uses `_n_calls + 1` so we stop before the limit-th+1 call would run.
            cost_exceeded = 0 < self.cost_limit < self._cost
            calls_exceeded = 0 < self.call_limit < self._n_calls + 1
            if cost_exceeded or calls_exceeded:
                raise RuntimeError(f"Global cost/call limit exceeded: ${self._cost:.4f} / {self._n_calls}")

    @property
    def cost(self) -> float:
        """Total accumulated cost across all recorded calls."""
        return self._cost

    @property
    def n_calls(self) -> int:
        """Total number of recorded model calls."""
        return self._n_calls
|
| 40 |
|
| 41 |
|
# Module-level shared instance: created once at import time so all models in
# this process accumulate into the same cost/call counters.
GLOBAL_MODEL_STATS = GlobalModelStats()
|
| 43 |
|
| 44 |
|
def get_model(input_model_name: str | None = None, config: dict | None = None) -> Model:
    """Get an initialized model object from any kind of user input or settings."""
    name = get_model_name(input_model_name, config)
    # Work on a deep copy so the caller's config dict is never mutated.
    settings = copy.deepcopy(config) if config is not None else {}
    settings["model_name"] = name

    model_class = get_model_class(name, settings.pop("model_class", ""))

    # Select cache control for Anthropic models by default
    looks_anthropic = any(token in name.lower() for token in ("anthropic", "sonnet", "opus", "claude"))
    if looks_anthropic and "set_cache_control" not in settings:
        settings["set_cache_control"] = "default_end"

    return model_class(**settings)
|
| 63 |
|
| 64 |
|
| 65 | def get_model_name(input_model_name: str | None = None, config: dict | None = None) -> str:
|
| 66 | """Get a model name from any kind of user input or settings."""
|
| 67 | if config is None:
|
| 68 | config = {}
|
| 69 | if input_model_name:
|
| 70 | return input_model_name
|
| 71 | if from_config := config.get("model_name"):
|
| 72 | return from_config
|
| 73 | if from_env := os.getenv("MSWEA_MODEL_NAME"):
|
| 74 | return from_env
|
| 75 | raise ValueError("No default model set. Please run `mini-extra config setup` to set one.")
|
| 76 |
|
| 77 |
|
# Shortcut names -> fully qualified import paths for the bundled model classes.
_MODEL_CLASS_MAPPING = {
    "litellm": "minisweagent.models.litellm_model.LitellmModel",
    "litellm_textbased": "minisweagent.models.litellm_textbased_model.LitellmTextbasedModel",
    "litellm_response": "minisweagent.models.litellm_response_model.LitellmResponseModel",
    "openrouter": "minisweagent.models.openrouter_model.OpenRouterModel",
    "openrouter_textbased": "minisweagent.models.openrouter_textbased_model.OpenRouterTextbasedModel",
    "openrouter_response": "minisweagent.models.openrouter_response_model.OpenRouterResponseModel",
    "portkey": "minisweagent.models.portkey_model.PortkeyModel",
    "portkey_response": "minisweagent.models.portkey_response_model.PortkeyResponseAPIModel",
    "requesty": "minisweagent.models.requesty_model.RequestyModel",
    "deterministic": "minisweagent.models.test_models.DeterministicModel",
}


def get_model_class(model_name: str, model_class: str = "") -> type:
    """Select the best model class.

    If a model_class is provided (as shortcut name, or as full import path,
    e.g., "anthropic" or "minisweagent.models.anthropic.AnthropicModel"),
    it takes precedence over the `model_name`.
    Otherwise, the model_name is used to select the best model class.

    Raises:
        ValueError: if `model_class` is given but cannot be resolved,
            imported, or found in its module.
    """
    if model_class:
        # Shortcut names resolve via the mapping; anything else is assumed
        # to already be a dotted "module.ClassName" import path.
        full_path = _MODEL_CLASS_MAPPING.get(model_class, model_class)
        try:
            module_name, class_name = full_path.rsplit(".", 1)
            module = importlib.import_module(module_name)
            return getattr(module, class_name)
        except (ValueError, ImportError, AttributeError) as e:
            msg = f"Unknown model class: {model_class} (resolved to {full_path}, available: {_MODEL_CLASS_MAPPING})"
            # Chain the underlying failure (PEP 3134) so the real import or
            # attribute error is not hidden from the user.
            raise ValueError(msg) from e

    # Default to LitellmModel
    from minisweagent.models.litellm_model import LitellmModel

    return LitellmModel
|
| 114 |
|