| 1 | import os
|
| 2 | from unittest.mock import patch
|
| 3 |
|
| 4 | import pytest
|
| 5 |
|
| 6 | from minisweagent.models import GlobalModelStats, get_model, get_model_class, get_model_name
|
| 7 | from minisweagent.models.test_models import DeterministicModel, make_output
|
| 8 |
|
| 9 |
|
class TestGetModelName:
    """Tests for resolving the model name from explicit input, config, and environment."""

    # Common config used across tests - model_name should be direct, not nested under "model"
    CONFIG_WITH_MODEL_NAME = {"model_name": "config-model"}
    # Literal error text expected from get_model_name when no model is configured anywhere.
    NO_MODEL_ERROR = "No default model set. Please run `mini-extra config setup` to set one."

    def test_input_model_name_takes_precedence(self):
        """Test that explicit input_model_name overrides all other sources."""
        with patch.dict(os.environ, {"MSWEA_MODEL_NAME": "env-model"}):
            assert get_model_name("input-model", self.CONFIG_WITH_MODEL_NAME) == "input-model"

    def test_config_takes_precedence_over_env(self):
        """Test that config takes precedence over environment variable."""
        with patch.dict(os.environ, {"MSWEA_MODEL_NAME": "env-model"}):
            assert get_model_name(None, self.CONFIG_WITH_MODEL_NAME) == "config-model"

    def test_env_var_fallback(self):
        """Test that environment variable is used when no config provided."""
        with patch.dict(os.environ, {"MSWEA_MODEL_NAME": "env-model"}):
            assert get_model_name(None, {}) == "env-model"

    def test_config_fallback(self):
        """Test that config model name is used when input and env are missing."""
        with patch.dict(os.environ, {}, clear=True):
            assert get_model_name(None, self.CONFIG_WITH_MODEL_NAME) == "config-model"

    def test_raises_error_when_no_model_configured(self):
        """Test that ValueError is raised when no model is configured anywhere.

        Both an empty config dict and a ``None`` config must raise.
        """
        import re

        # pytest.raises(match=...) applies re.search to the message, so the literal text
        # must be escaped; unescaped, each "." would match any character.
        expected = re.escape(self.NO_MODEL_ERROR)
        with patch.dict(os.environ, {}, clear=True):
            with pytest.raises(ValueError, match=expected):
                get_model_name(None, {})

            with pytest.raises(ValueError, match=expected):
                get_model_name(None, None)
|
| 46 |
|
| 47 |
|
class TestGetModelClass:
    """Tests for mapping a model name (and optional class key) to a model implementation."""

    @staticmethod
    def _assert_resolves_to_litellm(candidates):
        """Assert that each name in *candidates* maps to LitellmModel when no class key is given."""
        from minisweagent.models.litellm_model import LitellmModel

        for candidate in candidates:
            assert get_model_class(candidate) == LitellmModel

    def test_anthropic_model_selection(self):
        """Test that anthropic-related model names return LitellmModel by default."""
        TestGetModelClass._assert_resolves_to_litellm(
            ["anthropic", "sonnet", "opus", "claude-sonnet", "claude-opus"]
        )

    def test_litellm_model_fallback(self):
        """Test that non-anthropic model names return LitellmModel."""
        TestGetModelClass._assert_resolves_to_litellm(["gpt-4", "gpt-3.5-turbo", "llama2", "random-model"])

    def test_partial_matches(self):
        """Test that names merely containing provider keywords still resolve to LitellmModel."""
        TestGetModelClass._assert_resolves_to_litellm(
            [
                "my-anthropic-model",
                "sonnet-latest",
                "opus-v2",
                "gpt-anthropic-style",
                "totally-different",
            ]
        )

    def test_litellm_response_model_selection(self):
        """Test that passing the "litellm_response" class key selects LitellmResponseModel."""
        from minisweagent.models.litellm_response_model import LitellmResponseModel

        resolved = get_model_class("any-model", "litellm_response")
        assert resolved == LitellmResponseModel
|
| 78 |
|
| 79 |
|
class TestGetModel:
    """Tests for ``get_model``: config isolation, model-class dispatch, and API-key plumbing."""

    def test_config_deep_copy(self):
        """Test that get_model deep-copies the config instead of sharing the caller's dict.

        The fake model class below mutates the nested ``model_kwargs`` dict it receives.
        If ``get_model`` made only a shallow copy, that mutation would leak back into
        ``original_config``. Only a deep copy keeps the caller's data intact.
        """
        original_config = {"model_kwargs": {"api_key": "original"}, "outputs": [make_output("test", [])]}
        received_kwargs = {}

        def mutating_model(**kwargs):
            received_kwargs.update(kwargs)
            # Simulate a model implementation that mutates the nested config it was handed.
            kwargs.get("model_kwargs", {})["api_key"] = "mutated-by-model"
            return DeterministicModel(outputs=[make_output("test", [])], model_name="test")

        with patch("minisweagent.models.get_model_class") as mock_get_class:
            mock_get_class.return_value = mutating_model
            get_model("test-model", original_config)

        # The resolved name and the nested kwargs must reach the model class...
        assert received_kwargs["model_name"] == "test-model"
        assert "model_kwargs" in received_kwargs
        # ...without changing the caller's dict at the top level or at nested levels.
        assert original_config["model_kwargs"]["api_key"] == "original"
        assert "model_name" not in original_config

    def test_integration_with_compatible_model(self):
        """Test get_model works end-to-end with a model that handles extra kwargs."""
        with patch("minisweagent.models.get_model_class") as mock_get_class:
            hello_output = make_output("hello", [])

            def compatible_model(**kwargs):
                # Filter to only what DeterministicModel accepts, provide defaults
                config_args = {k: v for k, v in kwargs.items() if k in ["outputs", "model_name"]}
                if "outputs" not in config_args:
                    config_args["outputs"] = [make_output("default", [])]
                return DeterministicModel(**config_args)

            mock_get_class.return_value = compatible_model
            model = get_model("test-model", {"outputs": [hello_output]})
            assert isinstance(model, DeterministicModel)
            assert model.config.outputs == [hello_output]
            assert model.config.model_name == "test-model"

    def test_config_api_key_used_when_no_env_var(self):
        """Test that config api_key is used when env var is not set."""
        with patch.dict(os.environ, {}, clear=True):
            config = {"model_kwargs": {"api_key": "config-key"}, "model_class": "litellm"}
            model = get_model("test-model", config)

            # LitellmModel stores the api_key in model_kwargs
            assert model.config.model_kwargs["api_key"] == "config-key"

    def test_no_api_key_when_none_provided(self):
        """Test that no api_key is set when neither env var nor config provide one."""
        with patch.dict(os.environ, {}, clear=True):
            config = {"model_class": "litellm"}
            model = get_model("test-model", config)

            # LitellmModel should not have api_key when none provided
            model_kwargs = getattr(model.config, "model_kwargs", {})
            assert "api_key" not in model_kwargs

    def test_get_deterministic_model(self):
        """Test that get_model can instantiate DeterministicModel via model_class parameter."""
        outputs = [make_output("hello", []), make_output("world", [])]
        config = {"outputs": outputs, "cost_per_call": 2.0}
        model = get_model("test-model", config | {"model_class": "deterministic"})

        assert isinstance(model, DeterministicModel)
        assert model.config.outputs == outputs
        assert model.config.cost_per_call == 2.0
        assert model.config.model_name == "test-model"
|
| 140 |
|
| 141 |
|
class TestGlobalModelStats:
    """Tests for the global cost/call limit banner GlobalModelStats prints on construction."""

    @staticmethod
    def _startup_stdout(capsys, env):
        """Build GlobalModelStats with os.environ replaced by exactly *env*; return captured stdout."""
        with patch.dict(os.environ, env, clear=True):
            GlobalModelStats()
            return capsys.readouterr().out

    def test_prints_cost_limit_when_set(self, capsys):
        """Only MSWEA_GLOBAL_COST_LIMIT set: the banner shows that cost and a call limit of 0."""
        out = TestGlobalModelStats._startup_stdout(capsys, {"MSWEA_GLOBAL_COST_LIMIT": "5.5"})
        assert "Global cost/call limit: $5.5000 / 0" in out

    def test_prints_call_limit_when_set(self, capsys):
        """Only MSWEA_GLOBAL_CALL_LIMIT set: the banner shows a zero cost and that call limit."""
        out = TestGlobalModelStats._startup_stdout(capsys, {"MSWEA_GLOBAL_CALL_LIMIT": "10"})
        assert "Global cost/call limit: $0.0000 / 10" in out

    def test_prints_both_limits_when_both_set(self, capsys):
        """Both limit variables set: the banner reports both values."""
        env = {"MSWEA_GLOBAL_COST_LIMIT": "2.5", "MSWEA_GLOBAL_CALL_LIMIT": "5"}
        out = TestGlobalModelStats._startup_stdout(capsys, env)
        assert "Global cost/call limit: $2.5000 / 5" in out

    def test_no_print_when_silent_startup_set(self, capsys):
        """MSWEA_SILENT_STARTUP suppresses the banner even when limits are configured."""
        env = {
            "MSWEA_GLOBAL_COST_LIMIT": "5.0",
            "MSWEA_GLOBAL_CALL_LIMIT": "10",
            "MSWEA_SILENT_STARTUP": "1",
        }
        out = TestGlobalModelStats._startup_stdout(capsys, env)
        assert "Global cost/call limit" not in out

    def test_no_print_when_no_limits_set(self, capsys):
        """No limit variables at all: no banner is printed."""
        out = TestGlobalModelStats._startup_stdout(capsys, {})
        assert "Global cost/call limit" not in out

    def test_no_print_when_limits_are_zero(self, capsys):
        """Limits explicitly set to zero count as unset: no banner is printed."""
        env = {"MSWEA_GLOBAL_COST_LIMIT": "0", "MSWEA_GLOBAL_CALL_LIMIT": "0"}
        out = TestGlobalModelStats._startup_stdout(capsys, env)
        assert "Global cost/call limit" not in out
|
| 188 |
|