| 1 | """Test that cache control is actually applied when using anthropic models through get_model()."""
|
| 2 |
|
| 3 | import copy
|
| 4 | from unittest.mock import MagicMock, patch
|
| 5 |
|
| 6 | import pytest
|
| 7 |
|
| 8 | from minisweagent.models import get_model
|
| 9 |
|
| 10 |
|
| 11 | def _mock_litellm_completion(response_content="```mswea_bash_command\necho test\n```"):
|
| 12 | """Helper to create consistent litellm mocks. Response must include bash block for parse_action."""
|
| 13 | mock_response = MagicMock()
|
| 14 | mock_response.choices = [MagicMock()]
|
| 15 | mock_response.choices[0].message.content = response_content
|
| 16 | mock_response.model_dump.return_value = {"mock": "response"}
|
| 17 | return mock_response
|
| 18 |
|
| 19 |
|
def test_sonnet_4_cache_control_integration():
    """Test that get_model('sonnet-4') results in cache control being applied when querying."""
    messages = [
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing well!"},
        {"role": "user", "content": "Can you help me with coding?"},
    ]

    with patch("minisweagent.models.litellm_model.litellm.completion") as mock_completion:
        mock_completion.return_value = _mock_litellm_completion("```mswea_bash_command\necho 'I can help!'\n```")

        with patch("minisweagent.models.litellm_model.litellm.cost_calculator.completion_cost") as mock_cost:
            mock_cost.return_value = 0.001

            # This is the key test: get_model with an anthropic model name should enable cache control
            model = get_model("sonnet-4")
            # Query with a deep copy so any in-place cache-control mutation cannot
            # corrupt the local fixture (consistent with the other tests in this file).
            model.query(copy.deepcopy(messages))

            # Verify that cache control was applied to the messages sent to the API
            mock_completion.assert_called_once()
            sent_messages = mock_completion.call_args.kwargs["messages"]

            # Only the last message should have cache control applied
            assert len(sent_messages) == 3

            # First two messages should be passed through unchanged
            assert sent_messages[0]["content"] == "Hello, how are you?"
            assert sent_messages[1]["content"] == "I'm doing well!"

            # Last message should be rewritten into block form carrying cache control
            last_message = sent_messages[2]
            assert isinstance(last_message["content"], list)
            assert last_message["content"][0]["cache_control"] == {"type": "ephemeral"}
            assert last_message["content"][0]["type"] == "text"
            assert last_message["content"][0]["text"] == "Can you help me with coding?"
|
| 58 |
|
| 59 |
|
@pytest.mark.parametrize(
    "model_name",
    [
        "sonnet-4",
        "claude-sonnet",
        "anthropic/claude",
        "opus-latest",
    ],
)
def test_get_model_anthropic_applies_cache_control(model_name):
    """Test that using get_model with anthropic model names results in cache control being applied."""
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi there!"},
        {"role": "user", "content": "Help me code."},
    ]

    with patch("minisweagent.models.litellm_model.litellm.completion") as completion_mock:
        completion_mock.return_value = _mock_litellm_completion("```mswea_bash_command\necho 'help code'\n```")

        with patch("minisweagent.models.litellm_model.litellm.cost_calculator.completion_cost") as cost_mock:
            cost_mock.return_value = 0.001

            # get_model should auto-configure cache control for anthropic-style names;
            # query with a deep copy to avoid mutating the local fixture.
            get_model(model_name).query(copy.deepcopy(conversation))

            completion_mock.assert_called_once()
            sent = completion_mock.call_args.kwargs["messages"]

            # Exactly the original four messages must be forwarded to litellm
            assert len(sent) == 4, f"Expected 4 messages for {model_name}"

            # Everything except the last message stays an untouched plain string
            assert sent[0]["content"] == "You are a helpful assistant.", (
                f"System message content should be preserved for {model_name}"
            )
            assert sent[1]["content"] == "Hello!", (
                f"First user message content should be preserved for {model_name}"
            )
            assert sent[2]["content"] == "Hi there!", (
                f"Assistant message content should be preserved for {model_name}"
            )

            # The final message is rewritten into a single text block carrying cache control
            final = sent[3]
            assert isinstance(final["content"], list), f"Last message should have list content for {model_name}"
            assert len(final["content"]) == 1, f"Last message should have single content item for {model_name}"

            block = final["content"][0]
            assert block["type"] == "text", f"Content should be text type for {model_name}"
            assert block["cache_control"] == {"type": "ephemeral"}, f"Cache control missing for {model_name}"
            assert block["text"] == "Help me code.", f"Text content should be preserved for {model_name}"
|
| 120 |
|
| 121 |
|
@pytest.mark.parametrize(
    "model_name",
    [
        "gpt-4",
        "gpt-3.5-turbo",
        "llama2",
    ],
)
def test_get_model_non_anthropic_no_cache_control(model_name):
    """Test that non-anthropic models don't get cache control applied."""
    original = [
        {"role": "user", "content": "Hello!"},
    ]

    with patch("minisweagent.models.litellm_model.litellm.completion") as completion_mock:
        completion_mock.return_value = _mock_litellm_completion("```mswea_bash_command\necho hello\n```")

        with patch("minisweagent.models.litellm_model.litellm.cost_calculator.completion_cost") as cost_mock:
            cost_mock.return_value = 0.001

            # For these model names get_model should NOT auto-configure cache control;
            # query with a deep copy so the local fixture cannot be mutated.
            get_model(model_name).query(copy.deepcopy(original))

            completion_mock.assert_called_once()
            forwarded = completion_mock.call_args.kwargs["messages"]

            # The single user message must arrive untransformed: still a plain
            # string, with no cache_control key injected.
            first = forwarded[0]
            assert first["role"] == "user"
            assert first["content"] == "Hello!", f"Content should remain as string for {model_name}"
            assert "cache_control" not in first, f"No cache_control should be present for {model_name}"
|
| 160 |
|