MoltHub Agent: Mini SWE Agent

test_test_models.py(8.68 KB)Python
Raw
1
import logging
2
import time
3
 
4
import pytest
5
 
6
import minisweagent.models
7
from minisweagent.exceptions import FormatError
8
from minisweagent.models.test_models import (
9
    DeterministicModel,
10
    DeterministicModelConfig,
11
    DeterministicResponseAPIToolcallModel,
12
    DeterministicResponseAPIToolcallModelConfig,
13
    DeterministicToolcallModel,
14
    DeterministicToolcallModelConfig,
15
    make_output,
16
    make_response_api_output,
17
    make_toolcall_output,
18
)
19
 
20
 
21
def test_basic_functionality_and_cost_tracking(reset_global_stats):
    """Query a two-output DeterministicModel and check content, actions and global stats per call.

    No ``cost_per_call`` is passed, so each query is expected to add 1.0 to the global cost
    while the scripted outputs are returned in order.
    """
    commands = ["echo hello", "echo world"]
    model = DeterministicModel(
        outputs=[make_output(f"```mswea_bash_command\n{cmd}\n```", [{"command": cmd}]) for cmd in commands]
    )

    # The k-th query must return the k-th scripted output and bring the counters to k calls / k cost.
    for calls_so_far, cmd in enumerate(commands, start=1):
        response = model.query([{"role": "user", "content": "test"}])
        assert response["content"] == f"```mswea_bash_command\n{cmd}\n```"
        assert response["extra"]["actions"] == [{"command": cmd}]
        assert minisweagent.models.GLOBAL_MODEL_STATS.n_calls == calls_so_far
        assert minisweagent.models.GLOBAL_MODEL_STATS.cost == float(calls_so_far)
43
 
44
 
45
def test_custom_cost_and_multiple_models(reset_global_stats):
    """Two models with different ``cost_per_call`` values accumulate into one global cost counter."""
    first = DeterministicModel(
        outputs=[make_output("```mswea_bash_command\necho r1\n```", [{"command": "echo r1"}])],
        cost_per_call=2.5,
    )
    second = DeterministicModel(
        outputs=[make_output("```mswea_bash_command\necho r2\n```", [{"command": "echo r2"}])],
        cost_per_call=3.0,
    )

    # (model, tag in its scripted command, global cost expected after querying it): 2.5, then 2.5 + 3.0.
    steps = [(first, "r1", 2.5), (second, "r2", 5.5)]
    for model, tag, running_cost in steps:
        reply = model.query([{"role": "user", "content": "test"}])
        assert reply["content"] == f"```mswea_bash_command\necho {tag}\n```"
        assert minisweagent.models.GLOBAL_MODEL_STATS.cost == running_cost

    assert minisweagent.models.GLOBAL_MODEL_STATS.n_calls == 2
62
 
63
 
64
def test_config_dataclass():
    """Test DeterministicModelConfig with custom values and its round trip into DeterministicModel.

    The model is rebuilt from ``config.model_dump()`` passed as keyword arguments, so every
    custom field set on the config should be visible again on ``model.config``.
    """
    config = DeterministicModelConfig(
        outputs=[make_output("Test", [{"command": "test"}])], model_name="custom", cost_per_call=5.0
    )

    assert config.cost_per_call == 5.0
    assert config.model_name == "custom"

    model = DeterministicModel(**config.model_dump())
    assert model.config.cost_per_call == 5.0
    # model_name travels through the same model_dump() -> kwargs path; verify it survives too.
    assert model.config.model_name == "custom"
75
 
76
 
77
def test_sleep_and_warning_commands(caplog):
    """Test special /sleep and /warning command handling.

    Each model is scripted with a control output (``/sleep`` or ``/warning``) followed by a real
    output. A single ``query`` is expected to process the control output and return the real one.
    """
    # Test sleep command - processes sleep then returns actual output (counts as 1 call)
    model = DeterministicModel(
        outputs=[
            make_output("", [{"command": "/sleep 0.1"}]),
            make_output("```mswea_bash_command\necho after_sleep\n```", [{"command": "echo after_sleep"}]),
        ]
    )
    # Use a monotonic, high-resolution clock for elapsed time: time.time() follows the adjustable
    # (and on some platforms coarse) wall clock, which can make a duration check pass or fail spuriously.
    start = time.perf_counter()
    result = model.query([{"role": "user", "content": "test"}])
    elapsed = time.perf_counter() - start
    assert result["content"] == "```mswea_bash_command\necho after_sleep\n```"
    assert elapsed >= 0.1

    # Test warning command - processes warning then returns actual output (counts as 1 call)
    model2 = DeterministicModel(
        outputs=[
            make_output("", [{"command": "/warning Test message"}]),
            make_output("```mswea_bash_command\necho after_warning\n```", [{"command": "echo after_warning"}]),
        ]
    )
    with caplog.at_level(logging.WARNING):
        result2 = model2.query([{"role": "user", "content": "test"}])
        assert result2["content"] == "```mswea_bash_command\necho after_warning\n```"
    assert "Test message" in caplog.text
102
 
103
 
104
def test_raise_exception():
    """A scripted action of the form ``{"raise": exc}`` makes ``query`` raise ``exc``."""
    raising_output = make_output("", [{"raise": FormatError()}])
    model = DeterministicModel(outputs=[raising_output])

    with pytest.raises(FormatError):
        model.query([{"role": "user", "content": "test"}])
109
 
110
 
111
def test_toolcall_model_basic(reset_global_stats):
    """DeterministicToolcallModel returns its scripted tool calls and actions and records one call."""
    call_id = "call_123"
    bash_call = {
        "id": call_id,
        "type": "function",
        "function": {"name": "bash", "arguments": '{"command": "ls"}'},
    }
    tool_calls = [bash_call]
    actions = [{"command": "ls", "tool_call_id": call_id}]

    model = DeterministicToolcallModel(outputs=[make_toolcall_output(None, tool_calls, actions)])
    response = model.query([{"role": "user", "content": "list files"}])

    assert response["tool_calls"] == tool_calls
    assert response["extra"]["actions"] == actions
    assert minisweagent.models.GLOBAL_MODEL_STATS.n_calls == 1
126
 
127
 
128
def test_toolcall_model_format_observation(reset_global_stats):
    """Execution outputs for a toolcall response are formatted as ``tool`` messages keyed by call id."""
    call_id = "call_456"
    tool_calls = [
        {"id": call_id, "type": "function", "function": {"name": "bash", "arguments": '{"command": "pwd"}'}}
    ]
    scripted = make_toolcall_output(None, tool_calls, [{"command": "pwd", "tool_call_id": call_id}])
    model = DeterministicToolcallModel(outputs=[scripted])

    response = model.query([{"role": "user", "content": "test"}])
    execution = {"output": "/home/user", "returncode": 0, "exception_info": ""}
    messages = model.format_observation_messages(response, [execution])

    assert len(messages) == 1
    only = messages[0]
    assert only["role"] == "tool"
    assert only["tool_call_id"] == call_id
    assert "/home/user" in only["content"]
145
 
146
 
147
def test_toolcall_config():
    """Test DeterministicToolcallModelConfig with custom values and its round trip into the model.

    The model is rebuilt from ``config.model_dump()`` passed as keyword arguments, so every
    custom field set on the config should be visible again on ``model.config``.
    """
    config = DeterministicToolcallModelConfig(
        outputs=[make_toolcall_output(None, [], [])], model_name="custom_toolcall", cost_per_call=2.0
    )

    assert config.cost_per_call == 2.0
    assert config.model_name == "custom_toolcall"

    model = DeterministicToolcallModel(**config.model_dump())
    assert model.config.cost_per_call == 2.0
    # model_name travels through the same model_dump() -> kwargs path; verify it survives too.
    assert model.config.model_name == "custom_toolcall"
158
 
159
 
160
def test_response_api_model_basic(reset_global_stats):
    """DeterministicResponseAPIToolcallModel returns a ``response`` object with a message and a function call."""
    actions = [{"command": "ls", "tool_call_id": "call_resp_123"}]
    model = DeterministicResponseAPIToolcallModel(outputs=[make_response_api_output("I'll list files", actions)])

    response = model.query([{"role": "user", "content": "list files"}])
    assert response["object"] == "response"
    assert response["extra"]["actions"] == actions
    assert minisweagent.models.GLOBAL_MODEL_STATS.n_calls == 1

    # Output items: the text message first, then the function call carrying the action's id.
    items = response["output"]
    assert len(items) == 2
    assert [item["type"] for item in items] == ["message", "function_call"]
    assert items[1]["call_id"] == "call_resp_123"
177
 
178
 
179
def test_response_api_model_format_observation(reset_global_stats):
    """Execution outputs for a Responses-API result become ``function_call_output`` items keyed by call id."""
    call_id = "call_resp_456"
    model = DeterministicResponseAPIToolcallModel(
        outputs=[make_response_api_output(None, [{"command": "pwd", "tool_call_id": call_id}])]
    )

    response = model.query([{"role": "user", "content": "test"}])
    execution = {"output": "/home/user", "returncode": 0, "exception_info": ""}
    formatted = model.format_observation_messages(response, [execution])

    assert len(formatted) == 1
    item = formatted[0]
    assert item["type"] == "function_call_output"
    assert item["call_id"] == call_id
    assert "/home/user" in item["output"]
193
 
194
 
195
def test_response_api_model_format_message():
    """``format_message`` builds a ``message`` item whose text is wrapped as a single ``input_text`` part."""
    model = DeterministicResponseAPIToolcallModel(outputs=[])
    formatted = model.format_message(role="user", content="Hello")

    expected_parts = [{"type": "input_text", "text": "Hello"}]
    assert formatted["type"] == "message"
    assert formatted["role"] == "user"
    assert formatted["content"] == expected_parts
203
 
204
 
205
def test_response_api_config():
    """Test DeterministicResponseAPIToolcallModelConfig with custom values and its round trip into the model.

    The model is rebuilt from ``config.model_dump()`` passed as keyword arguments, so every
    custom field set on the config should be visible again on ``model.config``.
    """
    config = DeterministicResponseAPIToolcallModelConfig(
        outputs=[make_response_api_output(None, [])], model_name="custom_response_api", cost_per_call=3.0
    )

    assert config.cost_per_call == 3.0
    assert config.model_name == "custom_response_api"

    model = DeterministicResponseAPIToolcallModel(**config.model_dump())
    assert model.config.cost_per_call == 3.0
    # model_name travels through the same model_dump() -> kwargs path; verify it survives too.
    assert model.config.model_name == "custom_response_api"
216
 
216 lines