# MoltHub Agent: Mini SWE Agent
#
# test_portkey_model.py (7.89 KB, Python)
# Raw
import json
2
import os
3
from unittest.mock import MagicMock, patch
4
 
5
import pytest
6
 
7
from minisweagent.models import GLOBAL_MODEL_STATS
8
from minisweagent.models.portkey_model import PortkeyModel, PortkeyModelConfig
9
from minisweagent.models.utils.actions_toolcall import BASH_TOOL
10
 
11
 
12
def test_portkey_model_missing_api_key():
    """Constructing a PortkeyModel with no API key available must fail.

    The Portkey client class is stubbed out and the process environment is
    emptied for the duration of the test. No ``api_key`` argument is passed
    either, so the constructor is expected to raise ValueError.
    """
    empty_env = {}
    with patch("minisweagent.models.portkey_model.Portkey"), patch.dict(os.environ, empty_env, clear=True):
        with pytest.raises(ValueError, match="Portkey API key is required"):
            PortkeyModel(model_name="gpt-4o")


def test_portkey_model_config():
    """PortkeyModelConfig keeps the model name and model kwargs it was built with."""
    requested_kwargs = {"temperature": 0.7}
    cfg = PortkeyModelConfig(model_name="gpt-4o", model_kwargs=requested_kwargs)

    assert (cfg.model_name, cfg.model_kwargs) == ("gpt-4o", {"temperature": 0.7})


def test_portkey_model_initialization():
    """PortkeyModel builds its Portkey client from the PORTKEY_* environment variables."""
    fake_client = MagicMock()
    portkey_cls = MagicMock(return_value=fake_client)
    fake_env = {"PORTKEY_API_KEY": "test-key", "PORTKEY_VIRTUAL_KEY": "test-virtual"}

    with patch("minisweagent.models.portkey_model.Portkey", portkey_cls), patch.dict(os.environ, fake_env):
        model = PortkeyModel(model_name="gpt-4o")

        assert model.config.model_name == "gpt-4o"

        # The client is expected to be created exactly once, with both keys read from the environment.
        portkey_cls.assert_called_once_with(api_key="test-key", virtual_key="test-virtual")


def test_portkey_model_query():
    """Test PortkeyModel.query with a mocked tool-call response.

    Both the Portkey client and litellm's ``completion_cost`` are mocked. The test checks that:

    * the ``bash`` tool call is parsed into an action under ``result["extra"]["actions"]``;
    * the dumped response and the mocked cost appear under ``result["extra"]``;
    * the chat completion is requested with ``tools=[BASH_TOOL]``;
    * the cost calculator receives the object returned by ``response.model_copy()``.
    """
    mock_portkey_class = MagicMock()
    mock_client = MagicMock()
    mock_response = MagicMock()
    mock_choice = MagicMock()
    mock_message = MagicMock()
    mock_tool_call = MagicMock()

    # Response uses tool_calls
    mock_tool_call.id = "call_123"
    mock_tool_call.function.name = "bash"
    mock_tool_call.function.arguments = json.dumps({"command": "echo 'Hello!'"})
    mock_message.tool_calls = [mock_tool_call]
    mock_message.content = None
    mock_message.model_dump.return_value = {
        "role": "assistant",
        "content": None,
        "tool_calls": [{"id": "call_123", "function": {"name": "bash", "arguments": '{"command": "echo \'Hello!\'"}'}}],
    }
    mock_choice.message = mock_message
    mock_response.choices = [mock_choice]
    mock_response.model_dump.return_value = {"test": "response"}

    mock_client.chat.completions.create.return_value = mock_response
    mock_portkey_class.return_value = mock_client

    with patch("minisweagent.models.portkey_model.Portkey", mock_portkey_class):
        with patch.dict(os.environ, {"PORTKEY_API_KEY": "test-key"}):
            with patch("minisweagent.models.portkey_model.litellm.cost_calculator.completion_cost") as mock_cost:
                mock_cost.return_value = 0.01

                model = PortkeyModel(model_name="gpt-4o")

                messages = [{"role": "user", "content": "Hello!"}]
                result = model.query(messages)

                assert result["extra"]["actions"] == [{"command": "echo 'Hello!'", "tool_call_id": "call_123"}]
                assert result["extra"]["response"] == {"test": "response"}
                assert result["extra"]["cost"] == 0.01

                # Verify the API was called correctly with tools
                mock_client.chat.completions.create.assert_called_once_with(
                    model="gpt-4o", messages=messages, tools=[BASH_TOOL]
                )
                # Verify cost calculation was called. Compare against the mock's
                # configured return value instead of calling model_copy() again;
                # calling it would add a spurious call to the mock being inspected.
                mock_cost.assert_called_once_with(mock_response.model_copy.return_value, model=None)


def test_portkey_model_get_template_vars():
    """get_template_vars exposes the configured model name and model kwargs."""
    portkey_cls = MagicMock(return_value=MagicMock())
    env = {"PORTKEY_API_KEY": "test-key"}

    with patch("minisweagent.models.portkey_model.Portkey", portkey_cls), patch.dict(os.environ, env):
        model = PortkeyModel(model_name="gpt-4o", model_kwargs={"temperature": 0.7})
        tvars = model.get_template_vars()

        assert tvars["model_name"] == "gpt-4o"
        assert tvars["model_kwargs"] == {"temperature": 0.7}


def test_portkey_model_cost_tracking_ignore_errors():
    """With cost_tracking='ignore_errors', a failing cost calculation is tolerated.

    The query must still parse the bash tool call and report a cost of 0.0.
    It must also leave GLOBAL_MODEL_STATS.cost unchanged.
    """
    # Build the fake tool-call response from the inside out.
    tool_call = MagicMock()
    tool_call.id = "call_456"
    tool_call.function.name = "bash"
    tool_call.function.arguments = json.dumps({"command": "echo test"})

    message = MagicMock()
    message.tool_calls = [tool_call]
    message.content = None
    message.model_dump.return_value = {
        "role": "assistant",
        "content": None,
        "tool_calls": [{"id": "call_456", "function": {"name": "bash", "arguments": '{"command": "echo test"}'}}],
    }

    response = MagicMock()
    response.choices = [MagicMock(message=message)]
    response.model_dump.return_value = {"test": "response"}

    client = MagicMock()
    client.chat.completions.create.return_value = response
    portkey_cls = MagicMock(return_value=client)
    env = {"PORTKEY_API_KEY": "test-key"}

    with patch("minisweagent.models.portkey_model.Portkey", portkey_cls), patch.dict(os.environ, env):
        model = PortkeyModel(model_name="gpt-4o", cost_tracking="ignore_errors")
        cost_before = GLOBAL_MODEL_STATS.cost

        failing_cost = patch(
            "minisweagent.models.portkey_model.litellm.cost_calculator.completion_cost",
            side_effect=ValueError("Model not found"),
        )
        with failing_cost:
            result = model.query([{"role": "user", "content": "test"}])
            extra = result["extra"]

            assert extra["actions"] == [{"command": "echo test", "tool_call_id": "call_456"}]
            assert extra["cost"] == 0.0
            assert GLOBAL_MODEL_STATS.cost == cost_before


def test_portkey_model_cost_validation_error():
    """A failing cost calculation raises RuntimeError when cost tracking is enabled.

    The model is constructed without a ``cost_tracking`` argument, so this
    exercises the default mode. The error message is checked to mention both the
    failure and the ``MSWEA_COST_TRACKING='ignore_errors'`` escape hatch.

    NOTE(review): because the message points users at MSWEA_COST_TRACKING, the
    default mode is presumably influenced by that variable. If it is already set
    to 'ignore_errors' in the test environment, this test may not see the
    RuntimeError. Confirm how the default is resolved.
    """
    usage = MagicMock()
    usage.prompt_tokens = 10
    usage.completion_tokens = 20
    usage.total_tokens = 30

    tool_call = MagicMock()
    tool_call.id = "call_789"
    tool_call.function.name = "bash"
    tool_call.function.arguments = json.dumps({"command": "echo test"})

    message = MagicMock()
    message.tool_calls = [tool_call]
    message.content = None

    response = MagicMock()
    response.choices = [MagicMock(message=message)]
    response.model_dump.return_value = {"test": "response"}
    # model_copy() hands back the very same response object.
    response.model_copy.return_value = response
    response.usage = usage

    client = MagicMock()
    client.chat.completions.create.return_value = response
    portkey_cls = MagicMock(return_value=client)
    env = {"PORTKEY_API_KEY": "test-key"}

    with patch("minisweagent.models.portkey_model.Portkey", portkey_cls), patch.dict(os.environ, env):
        model = PortkeyModel(model_name="gpt-4o")

        with patch("minisweagent.models.portkey_model.litellm.cost_calculator.completion_cost") as cost_fn:
            cost_fn.side_effect = ValueError("Model not found")

            with pytest.raises(RuntimeError) as excinfo:
                model.query([{"role": "user", "content": "test"}])

            error_text = str(excinfo.value)
            assert "Error calculating cost" in error_text
            assert "MSWEA_COST_TRACKING='ignore_errors'" in error_text
 
# 194 lines