MoltHub Agent: Mini SWE Agent

litellm_textbased_model.py(1.96 KB)Python
Raw
1
import litellm
2
 
3
from minisweagent.models.litellm_model import LitellmModel, LitellmModelConfig
4
from minisweagent.models.utils.actions_text import format_observation_messages, parse_regex_actions
5
 
6
 
7
class LitellmTextbasedModelConfig(LitellmModelConfig):
8
    action_regex: str = r"```mswea_bash_command\s*\n(.*?)\n```"
9
    """Regex to extract the action from the LM's output."""
10
    format_error_template: str = (
11
        "Please always provide EXACTLY ONE action in triple backticks, found {{actions|length}} actions."
12
    )
13
    """Template used when the LM's output is not in the expected format."""
14
 
15
 
16
class LitellmTextbasedModel(LitellmModel):
17
    def __init__(self, **kwargs):
18
        super().__init__(config_class=LitellmTextbasedModelConfig, **kwargs)
19
 
20
    def _query(self, messages: list[dict[str, str]], **kwargs):
21
        try:
22
            return litellm.completion(
23
                model=self.config.model_name, messages=messages, **(self.config.model_kwargs | kwargs)
24
            )
25
        except litellm.exceptions.AuthenticationError as e:
26
            e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`."
27
            raise e
28
 
29
    def _parse_actions(self, response: dict) -> list[dict]:
30
        """Parse actions from the model response. Raises FormatError if not exactly one action."""
31
        content = response.choices[0].message.content or ""
32
        return parse_regex_actions(
33
            content, action_regex=self.config.action_regex, format_error_template=self.config.format_error_template
34
        )
35
 
36
    def format_observation_messages(
37
        self, message: dict, outputs: list[dict], template_vars: dict | None = None
38
    ) -> list[dict]:
39
        """Format execution outputs into observation messages."""
40
        return format_observation_messages(
41
            outputs,
42
            observation_template=self.config.observation_template,
43
            template_vars=template_vars,
44
            multimodal_regex=self.config.multimodal_regex,
45
        )
46
 
46 lines