# conftest.py — shared pytest configuration and fixtures for the Mini SWE Agent test suite.
1
import json
import re
import threading
from pathlib import Path

import pytest

from minisweagent.models import GLOBAL_MODEL_STATS
9
 
10
 
11
def pytest_addoption(parser):
    """Register custom command line options with pytest."""
    fire_option = {
        "action": "store_true",
        "default": False,
        "help": "Run fire tests (real API calls that cost money)",
    }
    parser.addoption("--run-fire", **fire_option)
19
 
20
 
21
# Global lock for tests that modify global state - this works across threads
22
_global_stats_lock = threading.Lock()
23
 
24
 
25
@pytest.fixture
def reset_global_stats():
    """Reset global model stats and ensure exclusive access for tests that need it.

    This fixture should be used by any test that depends on global model stats
    to ensure thread safety and test isolation.
    """
    with _global_stats_lock:
        # Reset at start
        GLOBAL_MODEL_STATS._cost = 0.0  # noqa: protected-access
        GLOBAL_MODEL_STATS._n_calls = 0  # noqa: protected-access
        try:
            yield
        finally:
            # pytest throws test-body exceptions back into the generator at the
            # yield point; without the finally the trailing reset would be
            # skipped on failure, leaking stats into later tests.
            GLOBAL_MODEL_STATS._cost = 0.0  # noqa: protected-access
            GLOBAL_MODEL_STATS._n_calls = 0  # noqa: protected-access
40
 
41
 
42
def get_test_data(trajectory_name: str) -> dict[str, list[str]]:
    """Load test fixtures from a trajectory JSON file"""
    traj_file = Path(__file__).parent / "test_data" / f"{trajectory_name}.traj.json"
    trajectory = json.loads(traj_file.read_text())

    # Skip the system message (index 0) and the initial user message (index 1);
    # everything after that is the agent conversation proper.
    conversation = trajectory[2:]

    return {
        # Model responses are the assistant messages.
        "model_responses": [m["content"] for m in conversation if m["role"] == "assistant"],
        # Expected observations are the user messages.
        "expected_observations": [m["content"] for m in conversation if m["role"] == "user"],
    }
63
 
64
 
65
def normalize_outputs(s: str) -> str:
    """Strip leading/trailing whitespace and normalize internal whitespace"""
    # Drop everything between <args> and </args>, because this contains docker container ids
    without_args = re.sub(r"<args>(.*?)</args>", "", s, flags=re.DOTALL)
    # Drop lines mentioning "root root", because they tend to appear with times
    kept = [line for line in without_args.split("\n") if "root root" not in line]
    cleaned = "\n".join(kept).strip()
    return "\n".join(line.rstrip() for line in cleaned.split("\n"))
72
 
73
 
74
def assert_observations_match(expected_observations: list[str], messages: list[dict]) -> None:
    """Compare expected observations with actual observations from agent messages

    Args:
        expected_observations: List of expected observation strings
        messages: Agent conversation messages (list of message dicts with 'role' and 'content')
    """
    # User/exit messages (observations) are at indices 3, 5, 7, etc.
    # Validate the count before indexing so a truncated conversation fails with
    # a clear message instead of an IndexError. (Comparing the two lists after
    # building one from the other could never fail — they were always equal.)
    available = max(0, (len(messages) - 2) // 2)
    assert available >= len(expected_observations), (
        f"Expected {len(expected_observations)} observations, got {available}"
    )

    for i, expected_observation in enumerate(expected_observations):
        message = messages[3 + i * 2]
        assert message["role"] in ("user", "exit")

        normalized_actual = normalize_outputs(message["content"])
        normalized_expected = normalize_outputs(expected_observation)

        assert normalized_actual == normalized_expected, (
            f"Step {i + 1} observation mismatch:\nExpected: {repr(normalized_expected)}\nActual: {repr(normalized_actual)}"
        )
100
 
101
 
102
@pytest.fixture
def github_test_data():
    """Fixture providing the recorded GitHub-issue trajectory data."""
    data = get_test_data("github_issue")
    return data
106
 
107
 
108
@pytest.fixture
def local_test_data():
    """Fixture providing the recorded local-run trajectory data."""
    data = get_test_data("local")
    return data
112
 