| 1 | import logging
|
| 2 | import time
|
| 3 | from typing import Any
|
| 4 |
|
| 5 | from pydantic import BaseModel
|
| 6 |
|
| 7 | from minisweagent.models import GLOBAL_MODEL_STATS
|
| 8 | from minisweagent.models.utils.actions_text import format_observation_messages
|
| 9 | from minisweagent.models.utils.actions_toolcall import format_toolcall_observation_messages
|
| 10 | from minisweagent.models.utils.actions_toolcall_response import (
|
| 11 | format_toolcall_observation_messages as format_response_api_observation_messages,
|
| 12 | )
|
| 13 | from minisweagent.models.utils.openai_multimodal import expand_multimodal_content
|
| 14 |
|
| 15 |
|
def make_output(content: str, actions: list[dict], cost: float = 1.0) -> dict:
    """Build a plain-text assistant message suitable for DeterministicModel.

    Args:
        content: Text of the assistant response.
        actions: Parsed action dicts, e.g. ``[{"command": "echo hello"}]``.
        cost: Cost recorded in the message's ``extra`` (defaults to 1.0).

    Returns:
        Dict with keys ``role`` (always ``"assistant"``), ``content`` and ``extra``;
        ``extra`` holds ``actions``, ``cost`` and a ``timestamp`` from ``time.time()``.
    """
    extra = dict(actions=actions, cost=cost, timestamp=time.time())
    return dict(role="assistant", content=content, extra=extra)
|
| 29 |
|
| 30 |
|
| 31 | def make_toolcall_output(content: str | None, tool_calls: list[dict], actions: list[dict]) -> dict:
|
| 32 | """Helper to create a toolcall output dict for DeterministicToolcallModel.
|
| 33 |
|
| 34 | Args:
|
| 35 | content: Optional text content (can be None for tool-only responses)
|
| 36 | tool_calls: List of tool call dicts in OpenAI format
|
| 37 | actions: List of parsed action dicts, e.g., [{"command": "echo hello", "tool_call_id": "call_123"}]
|
| 38 | """
|
| 39 | return {
|
| 40 | "role": "assistant",
|
| 41 | "content": content,
|
| 42 | "tool_calls": tool_calls,
|
| 43 | "extra": {"actions": actions, "cost": 1.0, "timestamp": time.time()},
|
| 44 | }
|
| 45 |
|
| 46 |
|
| 47 | def make_response_api_output(content: str | None, actions: list[dict]) -> dict:
|
| 48 | """Helper to create an output dict for DeterministicResponseAPIToolcallModel.
|
| 49 |
|
| 50 | Args:
|
| 51 | content: Optional text content (can be None for tool-only responses)
|
| 52 | actions: List of action dicts with 'command' and 'tool_call_id' keys
|
| 53 | """
|
| 54 | output_items = []
|
| 55 | if content:
|
| 56 | output_items.append(
|
| 57 | {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": content}]}
|
| 58 | )
|
| 59 | for action in actions:
|
| 60 | output_items.append(
|
| 61 | {
|
| 62 | "type": "function_call",
|
| 63 | "call_id": action["tool_call_id"],
|
| 64 | "name": "bash",
|
| 65 | "arguments": f'{{"command": "{action["command"]}"}}',
|
| 66 | }
|
| 67 | )
|
| 68 | return {
|
| 69 | "object": "response",
|
| 70 | "output": output_items,
|
| 71 | "extra": {"actions": actions, "cost": 1.0, "timestamp": time.time()},
|
| 72 | }
|
| 73 |
|
| 74 |
|
| 75 | def _process_test_actions(actions: list[dict]) -> bool:
|
| 76 | """Process special test actions. Returns True if the query should be retried."""
|
| 77 | for action in actions:
|
| 78 | if "raise" in action:
|
| 79 | raise action["raise"]
|
| 80 | cmd = action.get("command", "")
|
| 81 | if cmd.startswith("/sleep "):
|
| 82 | time.sleep(float(cmd.split("/sleep ")[1]))
|
| 83 | return True
|
| 84 | if cmd.startswith("/warning"):
|
| 85 | logging.warning(cmd.split("/warning")[1])
|
| 86 | return True
|
| 87 | return False
|
| 88 |
|
| 89 |
|
class DeterministicModelConfig(BaseModel):
    """Configuration for ``DeterministicModel``: the scripted outputs plus rendering options."""

    outputs: list[dict]
    """List of exact output messages to return in sequence. Each dict should have 'role', 'content', and 'extra' (with 'actions')."""
    # Identifier for this model; not read directly by the model's methods, but included in
    # model_dump() output (used by get_template_vars and serialize).
    model_name: str = "deterministic"
    # Amount passed to GLOBAL_MODEL_STATS.add each time query returns an output
    # (outputs skipped for a retry are not charged).
    cost_per_call: float = 1.0
    # Uses Jinja-style syntax ({% if %}, {{ }}); presumably rendered with Jinja2 by the
    # observation-formatting helper — confirm.
    observation_template: str = (
        "{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}"
        "<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
    )
    """Template used to render the observation after executing an action."""
    multimodal_regex: str = ""
    """Regex to extract multimodal content. Empty string disables multimodal processing."""
|
| 102 |
|
| 103 |
|
class DeterministicModel:
    """Test model that replays a fixed, pre-scripted sequence of plain-text outputs.

    Each ``query`` call returns the next entry of ``config.outputs``. An entry whose
    ``extra["actions"]`` triggers a retry in ``_process_test_actions`` (``/sleep``,
    ``/warning``) is consumed without being returned, and the following entry is used.
    """

    def __init__(self, **kwargs):
        """Initialize with a list of output messages to return in sequence.

        Args:
            **kwargs: Fields of ``DeterministicModelConfig`` (``outputs`` is required).
        """
        self.config = DeterministicModelConfig(**kwargs)
        # Index of the most recently consumed entry in config.outputs; -1 means none yet.
        self.current_index = -1

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        """Return the next scripted output, skipping entries that request a retry.

        Args:
            messages: Conversation so far (unused; outputs are pre-scripted).
            **kwargs: Unused.

        Returns:
            The next dict from ``config.outputs`` (the stored object itself, not a copy).

        Raises:
            IndexError: If every scripted output has already been consumed.
            Any exception object placed under an action's ``"raise"`` key.
        """
        # Loop instead of recursing so a long run of retry entries cannot hit the recursion limit.
        while True:
            self.current_index += 1
            if self.current_index >= len(self.config.outputs):
                raise IndexError(
                    f"{self.__class__.__name__} ran out of scripted outputs: requested output "
                    f"#{self.current_index + 1}, but only {len(self.config.outputs)} are configured"
                )
            output = self.config.outputs[self.current_index]
            if not _process_test_actions(output.get("extra", {}).get("actions", [])):
                break
        GLOBAL_MODEL_STATS.add(self.config.cost_per_call)
        return output

    def format_message(self, **kwargs) -> dict:
        """Build a message dict from keyword arguments via ``expand_multimodal_content``.

        The helper is called with ``config.multimodal_regex`` as its pattern.
        """
        return expand_multimodal_content(kwargs, pattern=self.config.multimodal_regex)

    def format_observation_messages(
        self, message: dict, outputs: list[dict], template_vars: dict | None = None
    ) -> list[dict]:
        """Format execution outputs into observation messages.

        ``message`` is accepted for interface parity with the tool-call models but is unused here.
        """
        # Resolves to the module-level helper imported at file top, not to this method.
        return format_observation_messages(
            outputs,
            observation_template=self.config.observation_template,
            template_vars=template_vars,
            multimodal_regex=self.config.multimodal_regex,
        )

    def get_template_vars(self, **kwargs) -> dict[str, Any]:
        """Return the config fields (including ``outputs``) as a plain dict."""
        return self.config.model_dump()

    def serialize(self) -> dict:
        """Return a JSON-friendly description of this model's config and fully-qualified type."""
        return {
            "info": {
                "config": {
                    "model": self.config.model_dump(mode="json"),
                    "model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
                },
            }
        }
|
| 144 |
|
| 145 |
|
class DeterministicToolcallModelConfig(BaseModel):
    """Configuration for ``DeterministicToolcallModel``: scripted tool-call outputs plus rendering options."""

    outputs: list[dict]
    """List of exact output messages with tool_calls to return in sequence."""
    # Identifier for this model; included in model_dump() output (get_template_vars, serialize).
    model_name: str = "deterministic_toolcall"
    # Amount passed to GLOBAL_MODEL_STATS.add each time query returns an output
    # (outputs skipped for a retry are not charged).
    cost_per_call: float = 1.0
    # Jinja-style syntax; presumably rendered by the tool-call observation helper — confirm.
    observation_template: str = (
        "{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}"
        "<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
    )
    """Template used to render the observation after executing an action."""
    multimodal_regex: str = ""
    """Regex to extract multimodal content. Empty string disables multimodal processing."""
|
| 158 |
|
| 159 |
|
class DeterministicToolcallModel:
    """Test model that replays a fixed sequence of OpenAI-style tool-call outputs.

    Each ``query`` call returns the next entry of ``config.outputs``. An entry whose
    ``extra["actions"]`` triggers a retry in ``_process_test_actions`` (``/sleep``,
    ``/warning``) is consumed without being returned, and the following entry is used.
    """

    def __init__(self, **kwargs):
        """Initialize with a list of toolcall output messages to return in sequence.

        Args:
            **kwargs: Fields of ``DeterministicToolcallModelConfig`` (``outputs`` is required).
        """
        self.config = DeterministicToolcallModelConfig(**kwargs)
        # Index of the most recently consumed entry in config.outputs; -1 means none yet.
        self.current_index = -1

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        """Return the next scripted output, skipping entries that request a retry.

        Args:
            messages: Conversation so far (unused; outputs are pre-scripted).
            **kwargs: Unused.

        Returns:
            The next dict from ``config.outputs`` (the stored object itself, not a copy).

        Raises:
            IndexError: If every scripted output has already been consumed.
            Any exception object placed under an action's ``"raise"`` key.
        """
        # Loop instead of recursing so a long run of retry entries cannot hit the recursion limit.
        while True:
            self.current_index += 1
            if self.current_index >= len(self.config.outputs):
                raise IndexError(
                    f"{self.__class__.__name__} ran out of scripted outputs: requested output "
                    f"#{self.current_index + 1}, but only {len(self.config.outputs)} are configured"
                )
            output = self.config.outputs[self.current_index]
            if not _process_test_actions(output.get("extra", {}).get("actions", [])):
                break
        GLOBAL_MODEL_STATS.add(self.config.cost_per_call)
        return output

    def format_message(self, **kwargs) -> dict:
        """Build a message dict from keyword arguments via ``expand_multimodal_content``.

        The helper is called with ``config.multimodal_regex`` as its pattern.
        """
        return expand_multimodal_content(kwargs, pattern=self.config.multimodal_regex)

    def format_observation_messages(
        self, message: dict, outputs: list[dict], template_vars: dict | None = None
    ) -> list[dict]:
        """Format execution outputs into tool result messages.

        The actions are read from ``message["extra"]["actions"]`` (an empty list if
        absent) and paired with ``outputs`` by the tool-call formatting helper.
        """
        actions = message.get("extra", {}).get("actions", [])
        return format_toolcall_observation_messages(
            actions=actions,
            outputs=outputs,
            observation_template=self.config.observation_template,
            template_vars=template_vars,
            multimodal_regex=self.config.multimodal_regex,
        )

    def get_template_vars(self, **kwargs) -> dict[str, Any]:
        """Return the config fields (including ``outputs``) as a plain dict."""
        return self.config.model_dump()

    def serialize(self) -> dict:
        """Return a JSON-friendly description of this model's config and fully-qualified type."""
        return {
            "info": {
                "config": {
                    "model": self.config.model_dump(mode="json"),
                    "model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
                },
            }
        }
|
| 202 |
|
| 203 |
|
class DeterministicResponseAPIToolcallModelConfig(BaseModel):
    """Configuration for ``DeterministicResponseAPIToolcallModel``: scripted Responses-API outputs plus rendering options."""

    outputs: list[dict]
    """List of exact Response API output messages to return in sequence."""
    # Identifier for this model; included in model_dump() output (get_template_vars, serialize).
    model_name: str = "deterministic_response_api_toolcall"
    # Amount passed to GLOBAL_MODEL_STATS.add each time query returns an output
    # (outputs skipped for a retry are not charged).
    cost_per_call: float = 1.0
    # Jinja-style syntax; presumably rendered by the Responses-API observation helper — confirm.
    observation_template: str = (
        "{% if output.exception_info %}<exception>{{output.exception_info}}</exception>\n{% endif %}"
        "<returncode>{{output.returncode}}</returncode>\n<output>\n{{output.output}}</output>"
    )
    """Template used to render the observation after executing an action."""
    multimodal_regex: str = ""
    """Regex to extract multimodal content. Empty string disables multimodal processing."""
|
| 216 |
|
| 217 |
|
class DeterministicResponseAPIToolcallModel:
    """Deterministic test model using OpenAI Responses API format.

    Each ``query`` call returns the next entry of ``config.outputs``. An entry whose
    ``extra["actions"]`` triggers a retry in ``_process_test_actions`` (``/sleep``,
    ``/warning``) is consumed without being returned, and the following entry is used.
    """

    def __init__(self, **kwargs):
        """Initialize with a list of Response API output messages to return in sequence.

        Args:
            **kwargs: Fields of ``DeterministicResponseAPIToolcallModelConfig`` (``outputs`` is required).
        """
        self.config = DeterministicResponseAPIToolcallModelConfig(**kwargs)
        # Index of the most recently consumed entry in config.outputs; -1 means none yet.
        self.current_index = -1

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        """Return the next scripted output, skipping entries that request a retry.

        Args:
            messages: Conversation so far (unused; outputs are pre-scripted).
            **kwargs: Unused.

        Returns:
            The next dict from ``config.outputs`` (the stored object itself, not a copy).

        Raises:
            IndexError: If every scripted output has already been consumed.
            Any exception object placed under an action's ``"raise"`` key.
        """
        # Loop instead of recursing so a long run of retry entries cannot hit the recursion limit.
        while True:
            self.current_index += 1
            if self.current_index >= len(self.config.outputs):
                raise IndexError(
                    f"{self.__class__.__name__} ran out of scripted outputs: requested output "
                    f"#{self.current_index + 1}, but only {len(self.config.outputs)} are configured"
                )
            output = self.config.outputs[self.current_index]
            if not _process_test_actions(output.get("extra", {}).get("actions", [])):
                break
        GLOBAL_MODEL_STATS.add(self.config.cost_per_call)
        return output

    def format_message(self, **kwargs) -> dict:
        """Format message in Responses API format.

        ``role`` defaults to ``"user"`` and ``content`` to ``""``. A string ``content`` is
        wrapped in a single ``input_text`` item; any other ``content`` is passed through
        unchanged. ``extra`` is attached only when it is truthy.
        """
        role = kwargs.get("role", "user")
        content = kwargs.get("content", "")
        extra = kwargs.get("extra")
        content_items = [{"type": "input_text", "text": content}] if isinstance(content, str) else content
        msg: dict = {"type": "message", "role": role, "content": content_items}
        if extra:
            msg["extra"] = extra
        return msg

    def format_observation_messages(
        self, message: dict, outputs: list[dict], template_vars: dict | None = None
    ) -> list[dict]:
        """Format execution outputs into function_call_output messages.

        The actions are read from ``message["extra"]["actions"]`` (an empty list if
        absent) and paired with ``outputs`` by the Responses-API formatting helper.
        """
        actions = message.get("extra", {}).get("actions", [])
        return format_response_api_observation_messages(
            actions=actions,
            outputs=outputs,
            observation_template=self.config.observation_template,
            template_vars=template_vars,
            multimodal_regex=self.config.multimodal_regex,
        )

    def get_template_vars(self, **kwargs) -> dict[str, Any]:
        """Return the config fields (including ``outputs``) as a plain dict."""
        return self.config.model_dump()

    def serialize(self) -> dict:
        """Return a JSON-friendly description of this model's config and fully-qualified type."""
        return {
            "info": {
                "config": {
                    "model": self.config.model_dump(mode="json"),
                    "model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
                },
            }
        }
|
| 270 |
|