| 1 | import json
|
| 2 | import re
|
| 3 | import threading
|
| 4 | from pathlib import Path
|
| 5 |
|
| 6 | import pytest
|
| 7 |
|
| 8 | from minisweagent.models import GLOBAL_MODEL_STATS
|
| 9 |
|
| 10 |
|
def pytest_addoption(parser):
    """Register custom command line options with pytest."""
    # Opt-in flag for tests that hit real model APIs (and therefore cost money).
    fire_option = {
        "action": "store_true",
        "default": False,
        "help": "Run fire tests (real API calls that cost money)",
    }
    parser.addoption("--run-fire", **fire_option)
|
| 19 |
|
| 20 |
|
# Global lock serializing tests that mutate GLOBAL_MODEL_STATS.
# A threading.Lock works across threads within this test process
# (NOTE(review): it does NOT coordinate across pytest-xdist worker
# processes — confirm if multi-process runs need isolation).
_global_stats_lock = threading.Lock()
|
| 23 |
|
| 24 |
|
@pytest.fixture
def reset_global_stats():
    """Reset global model stats and ensure exclusive access for tests that need it.

    This fixture should be used by any test that depends on global model stats
    to ensure thread safety and test isolation.

    Yields:
        None. The test runs while the global stats lock is held and the
        counters start from zero.
    """
    with _global_stats_lock:
        # Reset at start so the test observes a clean slate.
        GLOBAL_MODEL_STATS._cost = 0.0  # noqa: protected-access
        GLOBAL_MODEL_STATS._n_calls = 0  # noqa: protected-access
        try:
            yield
        finally:
            # Reset at end inside `finally` so cleanup also runs when the
            # test body raises — otherwise a failing test would leak its
            # stats into every subsequent test that reads them.
            GLOBAL_MODEL_STATS._cost = 0.0  # noqa: protected-access
            GLOBAL_MODEL_STATS._n_calls = 0  # noqa: protected-access
|
| 40 |
|
| 41 |
|
def get_test_data(trajectory_name: str) -> dict[str, list[str]]:
    """Load test fixtures from a trajectory JSON file.

    Returns a dict with the model responses (assistant messages) and the
    expected observations (user messages), skipping the system message at
    index 0 and the initial user message at index 1.
    """
    json_path = Path(__file__).parent / "test_data" / f"{trajectory_name}.traj.json"
    trajectory = json.loads(json_path.read_text())

    # Everything after the system prompt (0) and the task statement (1)
    # alternates between assistant actions and user observations.
    conversation = trajectory[2:]
    return {
        "model_responses": [m["content"] for m in conversation if m["role"] == "assistant"],
        "expected_observations": [m["content"] for m in conversation if m["role"] == "user"],
    }
|
| 63 |
|
| 64 |
|
def normalize_outputs(s: str) -> str:
    """Strip leading/trailing whitespace and normalize internal whitespace."""
    # Drop everything between <args> and </args> (may span newlines), because
    # this contains docker container ids.
    without_args = re.sub(r"<args>(.*?)</args>", "", s, flags=re.DOTALL)
    # Drop lines mentioning "root root" because they tend to carry timestamps.
    kept_lines = [line for line in without_args.split("\n") if "root root" not in line]
    body = "\n".join(kept_lines).strip()
    # Trailing whitespace on each remaining line is irrelevant to the comparison.
    return "\n".join(line.rstrip() for line in body.split("\n"))
|
| 72 |
|
| 73 |
|
def assert_observations_match(expected_observations: list[str], messages: list[dict]) -> None:
    """Compare expected observations with actual observations from agent messages.

    Args:
        expected_observations: List of expected observation strings
        messages: Agent conversation messages (list of message dicts with 'role' and 'content')

    Raises:
        AssertionError: if the agent produced fewer observations than expected,
            a message has an unexpected role, or an observation does not match
            after normalization.
    """
    # Extract actual observations from agent messages.
    # User/exit messages (observations) are at indices 3, 5, 7, etc.
    actual_observations = []
    for i in range(len(expected_observations)):
        user_message_index = 3 + (i * 2)
        # Explicit bounds check: without it a short `messages` list raises an
        # opaque IndexError. (The previous post-loop length assertion was
        # tautological — actual_observations was always built with exactly
        # len(expected_observations) entries, so it could never fail.)
        assert user_message_index < len(messages), (
            f"Expected {len(expected_observations)} observations, "
            f"got {max(0, (len(messages) - 2) // 2)}"
        )
        assert messages[user_message_index]["role"] in ("user", "exit")
        actual_observations.append(messages[user_message_index]["content"])

    for i, (expected_observation, actual_observation) in enumerate(zip(expected_observations, actual_observations)):
        normalized_actual = normalize_outputs(actual_observation)
        normalized_expected = normalize_outputs(expected_observation)

        assert normalized_actual == normalized_expected, (
            f"Step {i + 1} observation mismatch:\nExpected: {repr(normalized_expected)}\nActual: {repr(normalized_actual)}"
        )
|
| 100 |
|
| 101 |
|
@pytest.fixture
def github_test_data():
    """Provide the GitHub-issue trajectory fixtures."""
    fixtures = get_test_data("github_issue")
    return fixtures
|
| 106 |
|
| 107 |
|
@pytest.fixture
def local_test_data():
    """Provide the local trajectory fixtures."""
    fixtures = get_test_data("local")
    return fixtures
|
| 112 |
|