| 1 | import logging
|
| 2 | import time
|
| 3 |
|
| 4 | import pytest
|
| 5 |
|
| 6 | import minisweagent.models
|
| 7 | from minisweagent.exceptions import FormatError
|
| 8 | from minisweagent.models.test_models import (
|
| 9 | DeterministicModel,
|
| 10 | DeterministicModelConfig,
|
| 11 | DeterministicResponseAPIToolcallModel,
|
| 12 | DeterministicResponseAPIToolcallModelConfig,
|
| 13 | DeterministicToolcallModel,
|
| 14 | DeterministicToolcallModelConfig,
|
| 15 | make_output,
|
| 16 | make_response_api_output,
|
| 17 | make_toolcall_output,
|
| 18 | )
|
| 19 |
|
| 20 |
|
def test_basic_functionality_and_cost_tracking(reset_global_stats):
    """Queued outputs are replayed in order and every query updates the global stats.

    No ``cost_per_call`` is passed, so the cost assertions pin the per-call
    cost that the model applies when none is configured (1.0).
    """
    hello_block = "```mswea_bash_command\necho hello\n```"
    world_block = "```mswea_bash_command\necho world\n```"
    scripted_outputs = [
        make_output(hello_block, [{"command": "echo hello"}]),
        make_output(world_block, [{"command": "echo world"}]),
    ]
    model = DeterministicModel(outputs=scripted_outputs)

    # First query: returns the first scripted output and registers one call.
    first_reply = model.query([{"role": "user", "content": "test"}])
    assert first_reply["content"] == hello_block
    assert first_reply["extra"]["actions"] == [{"command": "echo hello"}]
    assert minisweagent.models.GLOBAL_MODEL_STATS.n_calls == 1
    assert minisweagent.models.GLOBAL_MODEL_STATS.cost == 1.0

    # Second query: advances to the next scripted output; the stats accumulate.
    second_reply = model.query([{"role": "user", "content": "test"}])
    assert second_reply["content"] == world_block
    assert second_reply["extra"]["actions"] == [{"command": "echo world"}]
    assert minisweagent.models.GLOBAL_MODEL_STATS.n_calls == 2
    assert minisweagent.models.GLOBAL_MODEL_STATS.cost == 2.0
|
| 43 |
|
| 44 |
|
def test_custom_cost_and_multiple_models(reset_global_stats):
    """Models with different ``cost_per_call`` values accumulate into the same global stats."""
    r1_block = "```mswea_bash_command\necho r1\n```"
    r2_block = "```mswea_bash_command\necho r2\n```"
    cheaper_model = DeterministicModel(
        outputs=[make_output(r1_block, [{"command": "echo r1"}])],
        cost_per_call=2.5,
    )
    pricier_model = DeterministicModel(
        outputs=[make_output(r2_block, [{"command": "echo r2"}])],
        cost_per_call=3.0,
    )

    # Only the first model has been queried so far: the total equals its own cost.
    reply_one = cheaper_model.query([{"role": "user", "content": "test"}])
    assert reply_one["content"] == r1_block
    assert minisweagent.models.GLOBAL_MODEL_STATS.cost == 2.5

    # The second model's cost is added on top (2.5 + 3.0) and both calls are counted.
    reply_two = pricier_model.query([{"role": "user", "content": "test"}])
    assert reply_two["content"] == r2_block
    assert minisweagent.models.GLOBAL_MODEL_STATS.cost == 5.5
    assert minisweagent.models.GLOBAL_MODEL_STATS.n_calls == 2
|
| 62 |
|
| 63 |
|
def test_config_dataclass():
    """Custom values set on DeterministicModelConfig are stored and reach the model."""
    only_output = make_output("Test", [{"command": "test"}])
    config = DeterministicModelConfig(
        outputs=[only_output],
        model_name="custom",
        cost_per_call=5.0,
    )

    assert config.cost_per_call == 5.0
    assert config.model_name == "custom"

    # Rebuild a model from the dumped config and check that the cost carried over.
    dumped_kwargs = config.model_dump()
    rebuilt_model = DeterministicModel(**dumped_kwargs)
    assert rebuilt_model.config.cost_per_call == 5.0
|
| 75 |
|
| 76 |
|
def test_sleep_and_warning_commands(caplog):
    """Test special /sleep and /warning command handling.

    For each command the model is given two queued outputs: the command entry
    followed by a regular output. The assertions expect a single ``query()``
    call to process the command and return the regular output.
    """
    # Sleep command - processes sleep then returns actual output (counts as 1 call)
    model = DeterministicModel(
        outputs=[
            make_output("", [{"command": "/sleep 0.1"}]),
            make_output("```mswea_bash_command\necho after_sleep\n```", [{"command": "echo after_sleep"}]),
        ]
    )
    # Use the monotonic clock for the elapsed-time check: time.time() follows the
    # wall clock, which can be adjusted mid-test (e.g. by an NTP sync) and make the
    # measured duration wrong.
    start_time = time.monotonic()
    result = model.query([{"role": "user", "content": "test"}])
    # Take the measurement right after the query so later statements do not pad it.
    elapsed = time.monotonic() - start_time
    assert result["content"] == "```mswea_bash_command\necho after_sleep\n```"
    assert elapsed >= 0.1

    # Warning command - processes warning then returns actual output (counts as 1 call)
    model2 = DeterministicModel(
        outputs=[
            make_output("", [{"command": "/warning Test message"}]),
            make_output("```mswea_bash_command\necho after_warning\n```", [{"command": "echo after_warning"}]),
        ]
    )
    # Capture at WARNING level so the emitted message shows up in caplog.text.
    with caplog.at_level(logging.WARNING):
        result2 = model2.query([{"role": "user", "content": "test"}])
        assert result2["content"] == "```mswea_bash_command\necho after_warning\n```"
        assert "Test message" in caplog.text
|
| 102 |
|
| 103 |
|
def test_raise_exception():
    """An action of the form {"raise": FormatError()} makes query() raise FormatError."""
    raising_output = make_output("", [{"raise": FormatError()}])
    model = DeterministicModel(outputs=[raising_output])

    with pytest.raises(FormatError):
        model.query([{"role": "user", "content": "test"}])
|
| 109 |
|
| 110 |
|
def test_toolcall_model_basic(reset_global_stats):
    """DeterministicToolcallModel returns the scripted tool calls and actions and counts the call."""
    bash_call = {
        "id": "call_123",
        "type": "function",
        "function": {"name": "bash", "arguments": '{"command": "ls"}'},
    }
    tool_calls = [bash_call]
    actions = [{"command": "ls", "tool_call_id": "call_123"}]
    scripted_output = make_toolcall_output(None, tool_calls, actions)

    model = DeterministicToolcallModel(outputs=[scripted_output])
    reply = model.query([{"role": "user", "content": "list files"}])

    assert reply["tool_calls"] == tool_calls
    assert reply["extra"]["actions"] == actions
    assert minisweagent.models.GLOBAL_MODEL_STATS.n_calls == 1
|
| 126 |
|
| 127 |
|
def test_toolcall_model_format_observation(reset_global_stats):
    """Observations from DeterministicToolcallModel come back as ``tool`` role messages."""
    pwd_call = {
        "id": "call_456",
        "type": "function",
        "function": {"name": "bash", "arguments": '{"command": "pwd"}'},
    }
    tool_calls = [pwd_call]
    actions = [{"command": "pwd", "tool_call_id": "call_456"}]
    model = DeterministicToolcallModel(outputs=[make_toolcall_output(None, tool_calls, actions)])

    reply = model.query([{"role": "user", "content": "test"}])
    execution_outputs = [{"output": "/home/user", "returncode": 0, "exception_info": ""}]
    observation_messages = model.format_observation_messages(reply, execution_outputs)

    # One action was executed, so exactly one message is expected, tied to its tool call id.
    assert len(observation_messages) == 1
    tool_message = observation_messages[0]
    assert tool_message["role"] == "tool"
    assert tool_message["tool_call_id"] == "call_456"
    assert "/home/user" in tool_message["content"]
|
| 145 |
|
| 146 |
|
def test_toolcall_config():
    """Custom values set on DeterministicToolcallModelConfig are stored and reach the model."""
    empty_output = make_toolcall_output(None, [], [])
    config = DeterministicToolcallModelConfig(
        outputs=[empty_output],
        model_name="custom_toolcall",
        cost_per_call=2.0,
    )

    assert config.cost_per_call == 2.0
    assert config.model_name == "custom_toolcall"

    # Rebuild a model from the dumped config and check that the cost carried over.
    dumped_kwargs = config.model_dump()
    rebuilt_model = DeterministicToolcallModel(**dumped_kwargs)
    assert rebuilt_model.config.cost_per_call == 2.0
|
| 158 |
|
| 159 |
|
def test_response_api_model_basic(reset_global_stats):
    """DeterministicResponseAPIToolcallModel returns a ``response`` object with message and call items."""
    actions = [{"command": "ls", "tool_call_id": "call_resp_123"}]
    scripted_output = make_response_api_output("I'll list files", actions)
    model = DeterministicResponseAPIToolcallModel(outputs=[scripted_output])

    reply = model.query([{"role": "user", "content": "list files"}])

    assert reply["object"] == "response"
    assert reply["extra"]["actions"] == actions
    assert minisweagent.models.GLOBAL_MODEL_STATS.n_calls == 1

    # Output structure: a text message followed by the function call for the action.
    output_items = reply["output"]
    assert len(output_items) == 2  # message + function_call
    assert output_items[0]["type"] == "message"
    assert output_items[1]["type"] == "function_call"
    assert output_items[1]["call_id"] == "call_resp_123"
|
| 177 |
|
| 178 |
|
def test_response_api_model_format_observation(reset_global_stats):
    """Observations from DeterministicResponseAPIToolcallModel are ``function_call_output`` items."""
    actions = [{"command": "pwd", "tool_call_id": "call_resp_456"}]
    scripted_output = make_response_api_output(None, actions)
    model = DeterministicResponseAPIToolcallModel(outputs=[scripted_output])

    reply = model.query([{"role": "user", "content": "test"}])
    execution_outputs = [{"output": "/home/user", "returncode": 0, "exception_info": ""}]
    observation_messages = model.format_observation_messages(reply, execution_outputs)

    # One action was executed, so exactly one item is expected, tied to its call id.
    assert len(observation_messages) == 1
    call_output = observation_messages[0]
    assert call_output["type"] == "function_call_output"
    assert call_output["call_id"] == "call_resp_456"
    assert "/home/user" in call_output["output"]
|
| 193 |
|
| 194 |
|
def test_response_api_model_format_message():
    """format_message() wraps plain text content in a Responses API ``input_text`` part."""
    # No outputs are needed: only message formatting is exercised, query() is never called.
    model = DeterministicResponseAPIToolcallModel(outputs=[])

    formatted = model.format_message(role="user", content="Hello")

    assert formatted["type"] == "message"
    assert formatted["role"] == "user"
    assert formatted["content"] == [{"type": "input_text", "text": "Hello"}]
|
| 203 |
|
| 204 |
|
def test_response_api_config():
    """Custom values set on DeterministicResponseAPIToolcallModelConfig are stored and reach the model."""
    empty_output = make_response_api_output(None, [])
    config = DeterministicResponseAPIToolcallModelConfig(
        outputs=[empty_output],
        model_name="custom_response_api",
        cost_per_call=3.0,
    )

    assert config.cost_per_call == 3.0
    assert config.model_name == "custom_response_api"

    # Rebuild a model from the dumped config and check that the cost carried over.
    dumped_kwargs = config.model_dump()
    rebuilt_model = DeterministicResponseAPIToolcallModel(**dumped_kwargs)
    assert rebuilt_model.config.cost_per_call == 3.0
|
| 216 |
|