| 1 | from unittest.mock import MagicMock, patch
|
| 2 |
|
| 3 | import pytest
|
| 4 |
|
| 5 | from minisweagent.exceptions import FormatError
|
| 6 | from minisweagent.models.litellm_model import LitellmModel, LitellmModelConfig
|
| 7 | from minisweagent.models.utils.actions_toolcall import BASH_TOOL
|
| 8 |
|
| 9 |
|
class TestLitellmModelConfig:
    """Checks on the default values of ``LitellmModelConfig``."""

    def test_default_format_error_template(self):
        """A config created with only ``model_name`` has ``"{{ error }}"`` as its format-error template."""
        config = LitellmModelConfig(model_name="test")
        assert config.format_error_template == "{{ error }}"
|
| 13 |
|
| 14 |
|
| 15 | def _mock_litellm_response(tool_calls):
|
| 16 | mock_response = MagicMock()
|
| 17 | mock_response.choices = [MagicMock()]
|
| 18 | mock_response.choices[0].message.tool_calls = tool_calls
|
| 19 | mock_response.choices[0].message.model_dump.return_value = {"role": "assistant", "content": None}
|
| 20 | mock_response.model_dump.return_value = {}
|
| 21 | return mock_response
|
| 22 |
|
| 23 |
|
class TestLitellmModel:
    """Tests for ``LitellmModel``: querying, tool-call parsing and observation formatting.

    The query tests patch ``litellm.completion`` and the litellm cost
    calculator on the model module. ``@patch`` decorators are applied bottom-up,
    so the cost mock arrives first and the completion mock second.
    """

    @staticmethod
    def _bash_tool_call(arguments, call_id):
        """Return a ``MagicMock`` tool call named ``bash`` with the given JSON arguments and id."""
        call = MagicMock()
        call.id = call_id
        # ``name`` must be assigned after construction: ``MagicMock(name=...)``
        # would name the mock itself instead of setting the attribute.
        call.function.name = "bash"
        call.function.arguments = arguments
        return call

    @patch("minisweagent.models.litellm_model.litellm.completion")
    @patch("minisweagent.models.litellm_model.litellm.cost_calculator.completion_cost")
    def test_query_includes_bash_tool(self, mock_cost, mock_completion):
        """The single completion call made by ``query`` receives ``tools == [BASH_TOOL]``."""
        mock_cost.return_value = 0.001
        mock_completion.return_value = _mock_litellm_response(
            [self._bash_tool_call('{"command": "echo test"}', "call_1")]
        )

        llm = LitellmModel(model_name="gpt-4")
        llm.query([{"role": "user", "content": "test"}])

        mock_completion.assert_called_once()
        passed_tools = mock_completion.call_args.kwargs["tools"]
        assert passed_tools == [BASH_TOOL]

    @patch("minisweagent.models.litellm_model.litellm.completion")
    @patch("minisweagent.models.litellm_model.litellm.cost_calculator.completion_cost")
    def test_parse_actions_valid_tool_call(self, mock_cost, mock_completion):
        """A ``bash`` tool call appears in ``extra["actions"]`` with its command and tool call id."""
        mock_cost.return_value = 0.001
        mock_completion.return_value = _mock_litellm_response(
            [self._bash_tool_call('{"command": "ls -la"}', "call_abc")]
        )

        llm = LitellmModel(model_name="gpt-4")
        response = llm.query([{"role": "user", "content": "list files"}])

        expected_actions = [{"command": "ls -la", "tool_call_id": "call_abc"}]
        assert response["extra"]["actions"] == expected_actions

    @patch("minisweagent.models.litellm_model.litellm.completion")
    @patch("minisweagent.models.litellm_model.litellm.cost_calculator.completion_cost")
    def test_parse_actions_no_tool_calls_raises(self, mock_cost, mock_completion):
        """``query`` raises ``FormatError`` when the response message has ``tool_calls`` set to ``None``."""
        mock_cost.return_value = 0.001
        mock_completion.return_value = _mock_litellm_response(None)

        # Construct the model outside the ``raises`` block so only ``query`` is under test.
        llm = LitellmModel(model_name="gpt-4")
        with pytest.raises(FormatError):
            llm.query([{"role": "user", "content": "test"}])

    def test_format_observation_messages(self):
        """One action with one output yields a single ``tool`` message rendered from the template."""
        llm = LitellmModel(model_name="gpt-4", observation_template="{{ output.output }}")
        action = {"command": "echo test", "tool_call_id": "call_1"}

        messages = llm.format_observation_messages(
            {"extra": {"actions": [action]}},
            [{"output": "test output", "returncode": 0}],
        )

        assert len(messages) == 1
        observation = messages[0]
        assert observation["role"] == "tool"
        assert observation["tool_call_id"] == "call_1"
        assert observation["content"] == "test output"

    def test_format_observation_messages_no_actions(self):
        """A message with no ``actions`` key and no outputs formats to an empty list."""
        llm = LitellmModel(model_name="gpt-4")
        assert llm.format_observation_messages({"extra": {}}, []) == []
|
| 79 |
|