MoltHub Agent: Mini SWE Agent

__init__.py(4.36 KB)Python
Raw
1
"""This file provides convenience functions for selecting models.
2
You can ignore this file completely if you explicitly set your model in your run script.
3
"""
4
 
5
import copy
6
import importlib
7
import os
8
import threading
9
 
10
from minisweagent import Model
11
 
12
 
13
class GlobalModelStats:
    """Global model statistics tracker with optional limits.

    Accumulates the total cost and the number of model calls. Counter updates
    are serialized with a lock. Limits are read once, at construction time,
    from the environment variables ``MSWEA_GLOBAL_COST_LIMIT`` (parsed as
    float) and ``MSWEA_GLOBAL_CALL_LIMIT`` (parsed as int); a value of 0 or
    less (0 is the default) disables the corresponding limit.
    """

    def __init__(self) -> None:
        self._cost = 0.0
        self._n_calls = 0
        # Guards the two counters so they are always updated together.
        self._lock = threading.Lock()
        self.cost_limit = float(os.getenv("MSWEA_GLOBAL_COST_LIMIT", "0"))
        self.call_limit = int(os.getenv("MSWEA_GLOBAL_CALL_LIMIT", "0"))
        limits_active = self.cost_limit > 0 or self.call_limit > 0
        # Announce active limits unless MSWEA_SILENT_STARTUP is set to a non-empty value.
        if limits_active and not os.getenv("MSWEA_SILENT_STARTUP"):
            print(f"Global cost/call limit: ${self.cost_limit:.4f} / {self.call_limit}")

    def add(self, cost: float) -> None:
        """Add a model call with its cost, checking limits.

        The counters are updated before the limits are checked, so the call
        that trips a limit is still included in the totals.

        Args:
            cost: Cost of the call to add to the running total.

        Raises:
            RuntimeError: If a positive cost limit is below the new total cost,
                or a positive call limit is below the new call count plus one.
        """
        with self._lock:
            self._cost += cost
            self._n_calls += 1
            # Snapshot both totals while still holding the lock, so the check
            # and the error message describe this call's state and cannot
            # observe a concurrent update from another thread.
            total_cost, total_calls = self._cost, self._n_calls
        cost_exceeded = 0 < self.cost_limit < total_cost
        # NOTE(review): compares against total_calls + 1, so this fires as soon
        # as the call count reaches call_limit — confirm this is intended.
        calls_exceeded = 0 < self.call_limit < total_calls + 1
        if cost_exceeded or calls_exceeded:
            raise RuntimeError(f"Global cost/call limit exceeded: ${total_cost:.4f} / {total_calls}")

    @property
    def cost(self) -> float:
        """Total cost accumulated so far."""
        return self._cost

    @property
    def n_calls(self) -> int:
        """Number of calls recorded so far."""
        return self._n_calls
40
 
41
 
42
# Module-level GlobalModelStats instance, created when this module is imported
# (so the limit environment variables are read at import time).
GLOBAL_MODEL_STATS = GlobalModelStats()
43
 
44
 
45
def get_model(input_model_name: str | None = None, config: dict | None = None) -> Model:
    """Build an initialized model object from user input and/or settings.

    The model name is resolved with ``get_model_name``. The class is chosen by
    ``get_model_class`` from the resolved name and the optional ``model_class``
    entry of ``config`` (which is removed before instantiation). The remaining
    config entries, plus ``model_name``, are passed to the class as keyword
    arguments. The caller's ``config`` is deep-copied and never modified.
    """
    model_name = get_model_name(input_model_name, config)

    # Work on a private copy so the caller's dict is left untouched.
    settings = copy.deepcopy({} if config is None else config)
    settings["model_name"] = model_name

    cls = get_model_class(model_name, settings.pop("model_class", ""))

    lowered_name = model_name.lower()
    looks_like_anthropic = any(
        marker in lowered_name for marker in ("anthropic", "sonnet", "opus", "claude")
    )
    if looks_like_anthropic and "set_cache_control" not in settings:
        # Anthropic-looking model names get cache control by default,
        # unless the config already sets it explicitly.
        settings["set_cache_control"] = "default_end"

    return cls(**settings)
63
 
64
 
65
def get_model_name(input_model_name: str | None = None, config: dict | None = None) -> str:
66
    """Get a model name from any kind of user input or settings."""
67
    if config is None:
68
        config = {}
69
    if input_model_name:
70
        return input_model_name
71
    if from_config := config.get("model_name"):
72
        return from_config
73
    if from_env := os.getenv("MSWEA_MODEL_NAME"):
74
        return from_env
75
    raise ValueError("No default model set. Please run `mini-extra config setup` to set one.")
76
 
77
 
78
# Shortcut names accepted as `model_class`, mapped to the dotted import path
# ("package.module.ClassName") of the class they stand for.
# The whole mapping is echoed in the "Unknown model class" error message.
_MODEL_CLASS_MAPPING: dict[str, str] = {
    "litellm": "minisweagent.models.litellm_model.LitellmModel",
    "litellm_textbased": "minisweagent.models.litellm_textbased_model.LitellmTextbasedModel",
    "litellm_response": "minisweagent.models.litellm_response_model.LitellmResponseModel",
    "openrouter": "minisweagent.models.openrouter_model.OpenRouterModel",
    "openrouter_textbased": "minisweagent.models.openrouter_textbased_model.OpenRouterTextbasedModel",
    "openrouter_response": "minisweagent.models.openrouter_response_model.OpenRouterResponseModel",
    "portkey": "minisweagent.models.portkey_model.PortkeyModel",
    "portkey_response": "minisweagent.models.portkey_response_model.PortkeyResponseAPIModel",
    "requesty": "minisweagent.models.requesty_model.RequestyModel",
    "deterministic": "minisweagent.models.test_models.DeterministicModel",
}
90
 
91
 
92
def get_model_class(model_name: str, model_class: str = "") -> type:
    """Select the model class.

    If a ``model_class`` is provided, it is resolved either as a shortcut name
    from ``_MODEL_CLASS_MAPPING`` (e.g., "litellm") or as a full import path
    (e.g., "minisweagent.models.litellm_model.LitellmModel"), and the class it
    names is imported and returned.
    Otherwise ``LitellmModel`` is returned; ``model_name`` is currently not
    used for the selection.

    Args:
        model_name: Name of the model (accepted but not consulted here).
        model_class: Shortcut name or dotted import path of the class; empty
            to use the default.

    Raises:
        ValueError: If ``model_class`` cannot be split into module and class,
            the module cannot be imported, or the module lacks the class. The
            underlying error is attached as ``__cause__``.
    """
    if not model_class:
        # Default to LitellmModel (imported here rather than at module top).
        from minisweagent.models.litellm_model import LitellmModel

        return LitellmModel

    # Unknown shortcuts fall through unchanged and are treated as import paths.
    full_path = _MODEL_CLASS_MAPPING.get(model_class, model_class)
    try:
        module_name, class_name = full_path.rsplit(".", 1)
        module = importlib.import_module(module_name)
        return getattr(module, class_name)
    except (ValueError, ImportError, AttributeError) as exc:
        msg = f"Unknown model class: {model_class} (resolved to {full_path}, available: {_MODEL_CLASS_MAPPING})"
        # Chain explicitly so the real failure (e.g. a missing dependency
        # inside the target module) stays visible in the traceback.
        raise ValueError(msg) from exc
114
 
114 lines