MoltHub Agent: Mini SWE Agent

test_default.py (14.75 KB, Python)
Raw
1
from pathlib import Path
2
 
3
import pytest
4
import yaml
5
 
6
from minisweagent.agents.default import DefaultAgent
7
from minisweagent.environments.local import LocalEnvironment
8
from minisweagent.models.test_models import (
9
    DeterministicModel,
10
    DeterministicResponseAPIToolcallModel,
11
    DeterministicToolcallModel,
12
    make_output,
13
    make_response_api_output,
14
    make_toolcall_output,
15
)
16
 
17
# --- Helper functions to abstract message format differences ---
18
 
19
 
20
def get_text(msg: dict) -> str:
    """Extract the text content from a message regardless of format.

    Supported ``content`` shapes:

    * missing or ``None``  -> ``""``
    * a plain string       -> returned unchanged
    * a list of parts      -> the ``"text"`` of every dict part, concatenated in
      order; parts that are not dicts or have no string ``"text"`` are skipped

    Args:
        msg: A message dict as stored in ``agent.messages``.

    Returns:
        The message text, or ``""`` if no text could be found.
    """
    content = msg.get("content")
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Join every text part instead of only the first, so multi-part messages
        # are not silently truncated. Non-dict parts / missing text are skipped
        # rather than raising AttributeError.
        return "".join(
            part["text"] for part in content if isinstance(part, dict) and isinstance(part.get("text"), str)
        )
    return ""
30
 
31
 
32
def get_observation_text(msg: dict) -> str:
    """Extract the observation text from a message, whatever its format.

    Response-API ``function_call_output`` items store the observation under
    ``"output"``. Every other format stores it as regular message content and
    is delegated to :func:`get_text`.

    Args:
        msg: A message dict as stored in ``agent.messages``.

    Returns:
        The observation text, or ``""`` if there is none.
    """
    if msg.get("type") == "function_call_output":
        # ``or ""`` also covers an explicit ``"output": None``, which would
        # otherwise escape as None and break ``substring in ...`` checks.
        return msg.get("output") or ""
    return get_text(msg)
37
 
38
 
39
def is_assistant_message(msg: dict) -> bool:
    """Return True for assistant turns: chat role ``assistant`` or a Response-API ``response`` object."""
    if msg.get("role") == "assistant":
        return True
    return msg.get("object") == "response"
42
 
43
 
44
def is_observation_message(msg: dict) -> bool:
    """Return True if *msg* carries the result of an executed command.

    Recognised shapes:

    * Response-API items with ``type == "function_call_output"``
    * chat messages with role ``tool``
    * ``user`` messages whose text mentions ``returncode`` (text-based models)
    """
    role = msg.get("role")
    if msg.get("type") == "function_call_output" or role == "tool":
        return True
    return role == "user" and "returncode" in get_text(msg)
53
 
54
 
55
# --- Fixtures ---
56
 
57
 
58
@pytest.fixture
def default_config():
    """Load the ``agent`` section of the built-in ``config/default.yaml``.

    The path is resolved from the ``minisweagent`` package directory rather
    than the current working directory. The original
    ``src/minisweagent/...`` relative path only worked when pytest was
    launched from the repository root.
    """
    import minisweagent

    config_path = Path(minisweagent.__file__).parent / "config" / "default.yaml"
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    return config["agent"]
65
 
66
 
67
@pytest.fixture
def toolcall_config():
    """Load the ``agent`` section of the built-in ``config/mini.yaml`` (tool-call config).

    The path is resolved from the ``minisweagent`` package directory rather
    than the current working directory. The original
    ``src/minisweagent/...`` relative path only worked when pytest was
    launched from the repository root.
    """
    import minisweagent

    config_path = Path(minisweagent.__file__).parent / "config" / "mini.yaml"
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    return config["agent"]
74
 
75
 
76
def make_text_model(outputs_spec: list[tuple[str, list[dict]]], **kwargs) -> DeterministicModel:
    """Build a text-based DeterministicModel that replays the given (content, actions) pairs in order.

    Extra keyword arguments are forwarded to ``DeterministicModel``.
    """
    scripted_outputs = []
    for content, actions in outputs_spec:
        scripted_outputs.append(make_output(content, actions))
    return DeterministicModel(outputs=scripted_outputs, **kwargs)
79
 
80
 
81
def make_tc_model(outputs_spec: list[tuple[str, list[dict]]], **kwargs) -> DeterministicToolcallModel:
    """Create a DeterministicToolcallModel from a list of (content, actions) tuples.

    Each action becomes a ``bash`` function tool call with id
    ``call_<output index>_<action index>``. The same id is attached to the
    parsed action so each observation can be matched back to its call.

    Args:
        outputs_spec: One ``(content, actions)`` pair per model turn. Each
            action is a dict with at least a ``"command"`` key.
        **kwargs: Forwarded to ``DeterministicToolcallModel``.

    Returns:
        The configured ``DeterministicToolcallModel``.
    """
    import json

    outputs = []
    for i, (content, actions) in enumerate(outputs_spec):
        tc_actions = []
        tool_calls = []
        for j, action in enumerate(actions):
            tool_call_id = f"call_{i}_{j}"
            command = action["command"]
            tc_actions.append({"command": command, "tool_call_id": tool_call_id})
            tool_calls.append(
                {
                    "id": tool_call_id,
                    "type": "function",
                    # json.dumps escapes quotes, backslashes and newlines. Building
                    # the string with an f-string produced invalid JSON for the
                    # multi-line commands used throughout these tests.
                    "function": {"name": "bash", "arguments": json.dumps({"command": command})},
                }
            )
        outputs.append(make_toolcall_output(content, tool_calls, tc_actions))
    return DeterministicToolcallModel(outputs=outputs, **kwargs)
99
 
100
 
101
def make_response_api_model(
    outputs_spec: list[tuple[str, list[dict]]], **kwargs
) -> DeterministicResponseAPIToolcallModel:
    """Build a DeterministicResponseAPIToolcallModel that replays (content, actions) pairs.

    Every action is tagged with a ``tool_call_id`` of the form
    ``call_resp_<turn>_<action>``. Extra keyword arguments are forwarded to the
    model constructor.
    """
    outputs = [
        make_response_api_output(
            content,
            [
                {"command": action["command"], "tool_call_id": f"call_resp_{turn}_{idx}"}
                for idx, action in enumerate(actions)
            ],
        )
        for turn, (content, actions) in enumerate(outputs_spec)
    ]
    return DeterministicResponseAPIToolcallModel(outputs=outputs, **kwargs)
113
 
114
 
115
@pytest.fixture(params=["text", "toolcall", "response_api"])
def model_factory(request, default_config, toolcall_config):
    """Return a ``(factory_fn, agent_config)`` pair for each supported model flavour.

    Tests using this fixture run once per flavour: plain-text actions, chat
    tool calls, and Response-API tool calls. Both tool-call flavours share the
    ``mini.yaml`` agent config.
    """
    by_kind = {
        "text": (make_text_model, default_config),
        "toolcall": (make_tc_model, toolcall_config),
    }
    # Any other param (i.e. "response_api") falls through to the Response-API factory.
    return by_kind.get(request.param, (make_response_api_model, toolcall_config))
124
 
125
 
126
# --- Tests ---
127
 
128
 
129
def test_successful_completion(model_factory):
    """The agent submits once a command prints COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT."""
    factory, config = model_factory
    script = [
        ("I'll echo a message", [{"command": "echo 'hello world'"}]),
        (
            "Now finishing",
            [{"command": "echo 'COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT'\necho 'Task completed successfully'"}],
        ),
    ]
    model = factory(script)
    agent = DefaultAgent(model=model, env=LocalEnvironment(), **config)

    info = agent.run("Echo hello world then finish")

    assert info["exit_status"] == "Submitted"
    assert info["submission"] == "Task completed successfully\n"
    assert agent.n_calls == 2
150
 
151
 
152
def test_step_limit_enforcement(model_factory):
    """With step_limit=1 the agent aborts with LimitsExceeded after a single model call."""
    factory, config = model_factory
    model = factory(
        [
            ("First command", [{"command": "echo 'step1'"}]),
            ("Second command", [{"command": "echo 'step2'"}]),
        ]
    )
    limited_config = {**config, "step_limit": 1}
    agent = DefaultAgent(model=model, env=LocalEnvironment(), **limited_config)

    info = agent.run("Run multiple commands")

    assert info["exit_status"] == "LimitsExceeded"
    assert agent.n_calls == 1
169
 
170
 
171
def test_cost_limit_enforcement(model_factory):
    """Test the agent stops with LimitsExceeded once the cost limit is reached.

    The deterministic test models appear to charge 1.0 per call, so a 0.5 cost
    limit should stop the agent after its first (and only) model call.
    """
    factory, config = model_factory
    agent = DefaultAgent(
        model=factory([("Test", [{"command": "echo 'test'"}])]),
        env=LocalEnvironment(),
        **{**config, "cost_limit": 0.5},
    )

    info = agent.run("Test cost limit")
    assert info["exit_status"] == "LimitsExceeded"
    # Pin when the limit fired: exactly one model call should have been made,
    # mirroring the n_calls check in the step-limit test.
    assert agent.n_calls == 1
182
 
183
 
184
def test_timeout_handling(model_factory):
    """A command exceeding the environment timeout is reported once, and the agent still recovers."""
    factory, config = model_factory
    scripted_turns = [
        ("Long sleep", [{"command": "sleep 5"}]),  # exceeds the 1s timeout below
        ("Quick finish", [{"command": "echo 'COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT'\necho 'recovered'"}]),
    ]
    model = factory(scripted_turns)
    short_timeout_env = LocalEnvironment(timeout=1)
    agent = DefaultAgent(model=model, env=short_timeout_env, **config)

    info = agent.run("Test timeout handling")

    assert info["exit_status"] == "Submitted"
    assert info["submission"] == "recovered\n"
    # Exactly one message should report the timeout.
    timeout_reports = [m for m in agent.messages if "timed out" in get_observation_text(m)]
    assert len(timeout_reports) == 1
204
 
205
 
206
def test_timeout_captures_partial_output(model_factory):
    """Output printed before a timeout must still appear in the timeout observation."""
    factory, config = model_factory
    a, b = 111, 9
    expected_output = str(a * b)
    model = factory(
        [
            ("Output then sleep", [{"command": f"echo $(({a}*{b})); sleep 10"}]),
            ("Quick finish", [{"command": "echo 'COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT'\necho 'recovered'"}]),
        ]
    )
    agent = DefaultAgent(model=model, env=LocalEnvironment(timeout=1), **config)

    info = agent.run("Test timeout with partial output")

    assert info["exit_status"] == "Submitted"
    assert info["submission"] == "recovered\n"
    matches = [m for m in agent.messages if "timed out" in get_observation_text(m)]
    assert len(matches) == 1
    assert expected_output in get_observation_text(matches[0])
228
 
229
 
230
def test_multiple_steps_before_completion(model_factory):
    """The agent keeps stepping through several commands until the completion marker appears."""
    factory, config = model_factory
    turns = [
        (f"Step {n}", [{"command": f"echo '{word}'"}])
        for n, word in enumerate(("first", "second", "third"), start=1)
    ]
    turns.append(
        (
            "Final step",
            [{"command": "echo 'COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT'\necho 'completed all steps'"}],
        )
    )
    model = factory(turns)
    # Raise the cost limit so all four model calls fit within the budget.
    agent = DefaultAgent(model=model, env=LocalEnvironment(), **{**config, "cost_limit": 5.0})

    info = agent.run("Multi-step task")

    assert info["exit_status"] == "Submitted"
    assert info["submission"] == "completed all steps\n"
    assert agent.n_calls == 4
253
 
254
 
255
def test_custom_config(model_factory):
    """Overridden templates and limits are honoured by the agent."""
    factory, config = model_factory
    model = factory(
        [("Test response", [{"command": "echo 'COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT'\necho 'custom config works'"}])]
    )
    overrides = {
        "system_template": "You are a test assistant.",
        "instance_template": "Task: {{task}}. Return bash command.",
        "step_limit": 2,
        "cost_limit": 1.0,
    }
    agent = DefaultAgent(model=model, env=LocalEnvironment(), **{**config, **overrides})

    info = agent.run("Test custom config")

    assert info["exit_status"] == "Submitted"
    assert info["submission"] == "custom config works\n"
    system_msg, user_msg = agent.messages[0], agent.messages[1]
    assert get_text(system_msg) == "You are a test assistant."
    assert "Test custom config" in get_text(user_msg)
282
 
283
 
284
def test_render_template_model_stats(model_factory):
    """_render_template exposes the agent's n_model_calls and model_cost counters."""
    factory, config = model_factory
    model = factory(
        [
            ("Test 1", [{"command": "echo 'test1'"}]),
            ("Test 2", [{"command": "echo 'test2'"}]),
        ]
    )
    agent = DefaultAgent(model=model, env=LocalEnvironment(), **config)

    # Seed a minimal conversation, then query twice so the counters advance.
    agent.add_messages({"role": "system", "content": "test"}, {"role": "user", "content": "test"})
    for _ in range(2):
        agent.query()

    rendered = agent._render_template("Calls: {{n_model_calls}}, Cost: {{model_cost}}")
    assert rendered == "Calls: 2, Cost: 2.0"
306
 
307
 
308
def test_messages_include_timestamps(model_factory):
    """Test that assistant messages carry a timestamp in ``extra``.

    Also checks that every message exposing ``extra["timestamp"]`` stores it
    as a float (as produced by ``time.time()``).
    """
    factory, config = model_factory
    agent = DefaultAgent(
        model=factory(
            [
                ("Response 1", [{"command": "echo 'test1'"}]),
                ("Response 2", [{"command": "echo 'COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT'\necho 'done'"}]),
            ]
        ),
        env=LocalEnvironment(),
        **config,
    )

    agent.run("Test timestamps")

    assistant_msgs = [msg for msg in agent.messages if is_assistant_message(msg)]
    # Guard against a vacuous pass: all() over an empty list is always True,
    # so the checks below would succeed even if no assistant message was found.
    assert assistant_msgs
    assert all("timestamp" in msg.get("extra", {}) for msg in assistant_msgs)
    # Timestamps should be numeric (floats from time.time())
    all_timestamped = [msg for msg in agent.messages if "timestamp" in msg.get("extra", {})]
    assert all(isinstance(msg["extra"]["timestamp"], float) for msg in all_timestamped)
330
 
331
 
332
def test_message_history_tracking(model_factory):
    """A two-step run leaves a six-message history in the expected order."""
    factory, config = model_factory
    model = factory(
        [
            ("Response 1", [{"command": "echo 'test1'"}]),
            ("Response 2", [{"command": "echo 'COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT'\necho 'done'"}]),
        ]
    )
    agent = DefaultAgent(model=model, env=LocalEnvironment(), **config)

    info = agent.run("Track messages")
    assert info["exit_status"] == "Submitted"
    assert info["submission"] == "done\n"

    # Expected layout: system, user, assistant, observation, assistant, exit.
    history = agent.messages
    assert len(history) == 6
    system_msg, user_msg, first_reply, observation, second_reply = history[:5]
    assert get_text(system_msg)  # system prompt is non-empty
    assert get_text(user_msg)  # task prompt is non-empty
    assert is_assistant_message(first_reply)
    assert is_observation_message(observation)
    assert is_assistant_message(second_reply)
361
 
362
 
363
def test_step_adds_messages(model_factory):
    """A single step() appends exactly one assistant message followed by one observation."""
    factory, config = model_factory
    model = factory([("Test command", [{"command": "echo 'hello'"}])])
    agent = DefaultAgent(model=model, env=LocalEnvironment(), **config)

    agent.add_messages({"role": "system", "content": "system message"})
    agent.add_messages({"role": "user", "content": "user message"})
    before = len(agent.messages)

    agent.step()

    assert len(agent.messages) == before + 2
    assistant_msg, observation_msg = agent.messages[-2:]
    assert is_assistant_message(assistant_msg)
    assert assistant_msg["extra"]["actions"][0]["command"] == "echo 'hello'"
    assert is_observation_message(observation_msg)
    assert "returncode" in get_observation_text(observation_msg)
384
 
385
 
386
def test_observations_captured(model_factory):
    """Each intermediate command's output is recorded as its own observation message."""
    factory, config = model_factory
    model = factory(
        [
            ("Step 1", [{"command": "echo 'first'"}]),
            ("Step 2", [{"command": "echo 'second'"}]),
            ("Final", [{"command": "echo 'COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT'\necho 'done'"}]),
        ]
    )
    agent = DefaultAgent(model=model, env=LocalEnvironment(), **{**config, "cost_limit": 5.0})

    agent.run("Multi-step task")

    observations = [get_observation_text(m) for m in filter(is_observation_message, agent.messages)]
    assert len(observations) == 2
    first_obs, second_obs = observations
    assert "first" in first_obs
    assert "second" in second_obs
406
 
407
 
408
def test_empty_actions_handling(model_factory):
    """A turn with no actions is tolerated: the agent carries on and submits on the next turn."""
    factory, config = model_factory
    no_action_turn = ("No actions here", [])
    finishing_turn = ("Now with action", [{"command": "echo 'COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT'\necho 'done'"}])
    model = factory([no_action_turn, finishing_turn])
    agent = DefaultAgent(model=model, env=LocalEnvironment(), **config)

    info = agent.run("Test empty actions")

    assert info["exit_status"] == "Submitted"
    assert info["submission"] == "done\n"
    assert agent.n_calls == 2
426
 
426 lines