| 1 | """Cache control utilities are mostly for Anthropic models.
|
| 2 | They are used to explicitly set cache control points.
|
| 3 | """
|
| 4 |
|
| 5 | import copy
|
| 6 | import warnings
|
| 7 | from typing import Literal
|
| 8 |
|
| 9 |
|
| 10 | def _get_content_text(entry: dict) -> str | None:
|
| 11 | if entry["content"] is None:
|
| 12 | return None
|
| 13 | if isinstance(entry["content"], str):
|
| 14 | return entry["content"]
|
| 15 | assert len(entry["content"]) == 1, "Expected single message in content"
|
| 16 | return entry["content"][0]["text"]
|
| 17 |
|
| 18 |
|
| 19 | def _clear_cache_control(entry: dict) -> None:
|
| 20 | if isinstance(entry["content"], list):
|
| 21 | assert len(entry["content"]) == 1, "Expected single message in content"
|
| 22 | entry["content"][0].pop("cache_control", None)
|
| 23 | # Note: entry["content"] can be None for assistant messages with only tool_use
|
| 24 | entry.pop("cache_control", None)
|
| 25 |
|
| 26 |
|
| 27 | def _set_cache_control(entry: dict) -> None:
|
| 28 | # Handle None content (e.g., assistant messages with only tool_use)
|
| 29 | if entry["content"] is None:
|
| 30 | entry["cache_control"] = {"type": "ephemeral"}
|
| 31 | return
|
| 32 |
|
| 33 | if not isinstance(entry["content"], list):
|
| 34 | entry["content"] = [ # type: ignore
|
| 35 | {
|
| 36 | "type": "text",
|
| 37 | "text": _get_content_text(entry),
|
| 38 | "cache_control": {"type": "ephemeral"},
|
| 39 | }
|
| 40 | ]
|
| 41 | else:
|
| 42 | entry["content"][0]["cache_control"] = {"type": "ephemeral"}
|
| 43 | if entry["role"] == "tool":
|
| 44 | # Workaround for weird bug
|
| 45 | entry["content"][0].pop("cache_control", None)
|
| 46 | entry["cache_control"] = {"type": "ephemeral"}
|
| 47 |
|
| 48 |
|
| 49 | def set_cache_control(
|
| 50 | messages: list[dict], *, mode: Literal["default_end"] | None = "default_end", last_n_messages_offset: int = 0
|
| 51 | ) -> list[dict]:
|
| 52 | """This messages processor adds manual cache control marks to the messages."""
|
| 53 | if mode is None:
|
| 54 | return messages
|
| 55 | if mode != "default_end":
|
| 56 | raise ValueError(f"Invalid mode: {mode}")
|
| 57 | if last_n_messages_offset:
|
| 58 | warnings.warn("last_n_messages_offset is deprecated and will be removed in the future. It has no effect.")
|
| 59 |
|
| 60 | messages = copy.deepcopy(messages)
|
| 61 | new_messages = []
|
| 62 | for i_entry, entry in enumerate(reversed(messages)):
|
| 63 | _clear_cache_control(entry)
|
| 64 | if i_entry == 0:
|
| 65 | _set_cache_control(entry)
|
| 66 | new_messages.append(entry)
|
| 67 | return list(reversed(new_messages))
|
| 68 |
|