| 1 | import json
|
| 2 | import tempfile
|
| 3 | from pathlib import Path
|
| 4 | from unittest.mock import Mock, patch
|
| 5 |
|
| 6 | import litellm
|
| 7 | import pytest
|
| 8 |
|
| 9 | from minisweagent.models import GLOBAL_MODEL_STATS
|
| 10 | from minisweagent.models.litellm_textbased_model import LitellmTextbasedModel
|
| 11 |
|
| 12 |
|
def test_authentication_error_enhanced_message():
    """Test that AuthenticationError gets enhanced with config set instruction.

    ``litellm.completion`` is patched to raise a real ``AuthenticationError``.
    The exception escaping ``model._query`` must be the same exception type,
    with a message that points the user at ``mini-extra config set``.
    """
    model = LitellmTextbasedModel(model_name="gpt-4")

    with patch("litellm.completion") as mock_completion:
        # Raise a fresh, real AuthenticationError on every call to the patched completion.
        def side_effect(*args, **kwargs):
            raise litellm.exceptions.AuthenticationError("Invalid API key", llm_provider="openai", model="gpt-4")

        mock_completion.side_effect = side_effect

        with pytest.raises(litellm.exceptions.AuthenticationError) as exc_info:
            model._query([{"role": "user", "content": "test"}])

        # Check that the error message was enhanced
        assert "You can permanently set your API key with `mini-extra config set KEY VALUE`." in str(exc_info.value)
|
| 33 |
|
| 34 |
|
def test_model_registry_loading():
    """Test that custom model registry is loaded and registered when provided.

    Writes a one-entry cost table to a temporary JSON file and constructs the
    model with ``litellm_model_registry`` pointing at it. It then asserts that
    ``litellm.utils.register_model`` received exactly the parsed dict, once.
    """
    model_costs = {
        "my-custom-model": {
            "max_tokens": 4096,
            "input_cost_per_token": 0.0001,
            "output_cost_per_token": 0.0002,
            "litellm_provider": "openai",
            "mode": "chat",
        }
    }

    # delete=False: the file must survive closing the handle, because the
    # model reads it by path only after this `with` block has exited.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
        json.dump(model_costs, f)
        registry_path = Path(f.name)

    try:
        with patch("litellm.utils.register_model") as mock_register:
            _model = LitellmTextbasedModel(model_name="my-custom-model", litellm_model_registry=registry_path)

            # Verify register_model was called with the correct data
            mock_register.assert_called_once_with(model_costs)
    finally:
        # Cleanup runs whether or not the assertion failed; pytest reports any failure itself.
        registry_path.unlink()
|
| 62 |
|
| 63 |
|
def test_model_registry_none():
    """Test that no registry loading occurs when litellm_model_registry is None.

    Both ``litellm.register_model`` and ``litellm.utils.register_model`` are
    patched. ``test_model_registry_loading`` observes registration through
    ``litellm.utils.register_model``. Patching only the top-level alias would
    let this negative assertion pass even if a registration happened.
    """
    with patch("litellm.register_model") as mock_register, patch("litellm.utils.register_model") as mock_utils_register:
        _model = LitellmTextbasedModel(model_name="gpt-4", litellm_model_registry=None)

        # Verify register_model was not called through either alias
        mock_register.assert_not_called()
        mock_utils_register.assert_not_called()
|
| 71 |
|
| 72 |
|
def test_model_registry_not_provided():
    """Test that no registry loading occurs when litellm_model_registry is not provided.

    Both ``litellm.register_model`` and ``litellm.utils.register_model`` are
    patched. ``test_model_registry_loading`` observes registration through
    ``litellm.utils.register_model``. Patching only the top-level alias would
    let this negative assertion pass even if a registration happened.
    """
    with patch("litellm.register_model") as mock_register, patch("litellm.utils.register_model") as mock_utils_register:
        _model = LitellmTextbasedModel(model_name="gpt-4o")

        # Verify register_model was not called through either alias
        mock_register.assert_not_called()
        mock_utils_register.assert_not_called()
|
| 80 |
|
| 81 |
|
def test_litellm_model_cost_tracking_ignore_errors():
    """Check that a failing cost lookup is tolerated under cost_tracking='ignore_errors'.

    ``completion_cost`` is patched to raise ``ValueError``. The query must still
    return the assistant message with its parsed bash action, and it must leave
    ``GLOBAL_MODEL_STATS.cost`` exactly where it was.
    """
    model = LitellmTextbasedModel(model_name="gpt-4o", cost_tracking="ignore_errors")
    cost_before = GLOBAL_MODEL_STATS.cost

    reply = "```mswea_bash_command\necho test\n```"
    assistant_message = Mock(content=reply)
    assistant_message.model_dump.return_value = {"role": "assistant", "content": reply}
    fake_response = Mock(choices=[Mock(message=assistant_message)])
    fake_response.model_dump.return_value = {"test": "response"}

    with patch("litellm.completion", return_value=fake_response):
        with patch("litellm.cost_calculator.completion_cost", side_effect=ValueError("Model not found")):
            result = model.query([{"role": "user", "content": "test"}])

    assert result["content"] == reply
    assert result["extra"]["actions"] == [{"command": "echo test"}]
    assert GLOBAL_MODEL_STATS.cost == cost_before
|
| 107 |
|
| 108 |
|
def test_litellm_model_cost_validation_zero_cost():
    """Check that a computed cost of 0.0 makes ``query`` raise when cost tracking is enabled.

    The model is built without a ``cost_tracking`` override. The resulting
    RuntimeError must name the offending value and mention the
    ``MSWEA_COST_TRACKING='ignore_errors'`` escape hatch.
    """
    model = LitellmTextbasedModel(model_name="gpt-4o")

    fake_response = Mock(choices=[Mock(message=Mock(content="Test response"))])
    fake_response.model_dump.return_value = {"test": "response"}

    with patch("litellm.completion", return_value=fake_response), patch(
        "litellm.cost_calculator.completion_cost", return_value=0.0
    ):
        with pytest.raises(RuntimeError) as exc_info:
            model.query([{"role": "user", "content": "test"}])

        error_text = str(exc_info.value)
        assert "Cost must be > 0.0, got 0.0" in error_text
        assert "MSWEA_COST_TRACKING='ignore_errors'" in error_text
|
| 127 |
|