| 1 | import re
|
| 2 | from unittest.mock import patch
|
| 3 |
|
| 4 | from minisweagent.models.test_models import DeterministicModel, make_output
|
| 5 | from minisweagent.run.mini import DEFAULT_CONFIG_FILE, main
|
| 6 | from tests.conftest import assert_observations_match
|
| 7 |
|
| 8 |
|
def _make_model_from_fixture(text_outputs: list[str], cost_per_call: float = 1.0, **kwargs) -> DeterministicModel:
    """Build a DeterministicModel that replays raw assistant texts from a trajectory fixture.

    Each entry of ``text_outputs`` becomes one scripted output, charged ``cost_per_call``.
    The first fenced block tagged ``mswea_bash_command`` in a text is turned into a single
    ``{"command": ...}`` action. Texts without such a block get an empty action list.

    Args:
        text_outputs: Raw model response texts, in the order they should be replayed.
        cost_per_call: Cost attached to every scripted output and passed to the model.
        **kwargs: Extra keyword arguments forwarded unchanged to ``DeterministicModel``.

    Returns:
        A ``DeterministicModel`` preloaded with one output per input text.
    """
    # DOTALL lets the lazy group span multiple lines. Only the first fenced block is used.
    command_pattern = re.compile(r"```mswea_bash_command\s*\n(.*?)\n```", re.DOTALL)

    scripted = []
    for raw in text_outputs:
        found = command_pattern.search(raw)
        actions = [{"command": found.group(1)}] if found else []
        scripted.append(make_output(raw, actions, cost=cost_per_call))

    return DeterministicModel(outputs=scripted, cost_per_call=cost_per_call, **kwargs)
|
| 21 |
|
| 22 |
|
def test_local_end_to_end(local_test_data):
    """Run ``main`` end to end against the real local environment with a scripted model.

    Setup:
        - The LiteLLM model class is patched to return a ``DeterministicModel`` that
          replays the fixture's recorded responses.
        - First-time configuration is patched out.
        - Both interactive prompt sessions and ``input()`` are stubbed to return ``""``.

    Checks:
        - The final message count.
        - The observations recorded in the messages.
        - The number of model calls.
    """
    responses = local_test_data["model_responses"]
    observations = local_test_data["expected_observations"]
    n_steps = len(responses)

    def _empty_answer(*_args, **_kwargs):
        # Stands in for every interactive prompt: always submits an empty reply.
        return ""

    with (
        patch("minisweagent.run.mini.configure_if_first_time"),
        patch("minisweagent.models.litellm_model.LitellmModel") as litellm_cls,
        patch("minisweagent.agents.interactive._prompt_session.prompt", side_effect=_empty_answer),
        patch("minisweagent.agents.interactive._multiline_prompt_session.prompt", side_effect=_empty_answer),
        patch("builtins.input", return_value=""),  # stubs input() for the LimitsExceeded handling path
    ):
        litellm_cls.return_value = _make_model_from_fixture(responses)
        agent = main(
            model_name="tardis",
            config_spec=[str(DEFAULT_CONFIG_FILE)],
            yolo=True,
            task="Blah blah blah",
            output=None,
            cost_limit=10,
            model_class=None,
        )  # type: ignore

    assert agent is not None
    history = agent.messages

    # Layout: system + initial user message, then one (assistant, user) pair per step.
    wanted_count = 2 * (n_steps + 1)
    assert len(history) == wanted_count, f"Expected {wanted_count} messages, got {len(history)}"

    assert_observations_match(observations, history)

    assert agent.n_calls == n_steps, f"Expected {n_steps} steps, got {agent.n_calls}"
|
| 58 |
|