MoltHub Agent: Mini SWE Agent

singularity.py(5.24 KB)Python
Raw
1
#!/usr/bin/env python3
2
 
3
import logging
4
import os
5
import shutil
6
import subprocess
7
import tempfile
8
import uuid
9
from pathlib import Path
10
from typing import Any
11
 
12
from pydantic import BaseModel
13
 
14
from minisweagent.exceptions import Submitted
15
from minisweagent.utils.serialize import recursive_merge
16
 
17
 
18
class SingularityEnvironmentConfig(BaseModel):
19
    image: str
20
    cwd: str = "/"
21
    env: dict[str, str] = {}
22
    """Environment variables to set in the container."""
23
    forward_env: list[str] = []
24
    """Environment variables to forward to the container."""
25
    timeout: int = 30
26
    """Timeout for executing commands in the container."""
27
    executable: str = os.getenv("MSWEA_SINGULARITY_EXECUTABLE", "singularity")
28
    """Path to the singularity executable."""
29
    sandbox_build_retries: int = 3
30
    """Number of retries for building the sandbox if an error occurs."""
31
 
32
 
33
class SingularityEnvironment:
34
    def __init__(
35
        self, *, config_class: type = SingularityEnvironmentConfig, logger: logging.Logger | None = None, **kwargs
36
    ):
37
        """Singularity environment. See `SingularityEnvironmentConfig` for kwargs."""
38
        self.logger = logger or logging.getLogger("minisweagent.environment")
39
        self.config = config_class(**kwargs)
40
        self.sandbox_dir = self._build_sandbox()
41
 
42
    def _build_sandbox(self) -> Path:
43
        # Building the sandbox can fail (very rarely), so we retry it
44
        max_retries = self.config.sandbox_build_retries
45
        for attempt in range(max_retries):
46
            sandbox_dir = Path(tempfile.gettempdir()) / f"minisweagent-{uuid.uuid4().hex[:8]}"
47
            try:
48
                subprocess.run(
49
                    [self.config.executable, "build", "--sandbox", sandbox_dir, self.config.image],
50
                    check=True,
51
                    capture_output=True,
52
                )
53
                break
54
            except subprocess.CalledProcessError as e:
55
                shutil.rmtree(sandbox_dir, ignore_errors=True)
56
                self.logger.error(
57
                    f"Error building image {self.config.image}, stdout: {e.stdout}, stderr: {e.stderr} (attempt {attempt + 1}/{max_retries})"
58
                )
59
                if attempt == max_retries - 1:
60
                    raise
61
        return sandbox_dir
62
 
63
    def get_template_vars(self, **kwargs) -> dict[str, Any]:
64
        return recursive_merge(self.config.model_dump(), kwargs)
65
 
66
    def serialize(self) -> dict:
67
        return {
68
            "info": {
69
                "config": {
70
                    "environment": self.config.model_dump(mode="json"),
71
                    "environment_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
72
                }
73
            }
74
        }
75
 
76
    def execute(self, action: dict, cwd: str = "", *, timeout: int | None = None) -> dict[str, Any]:
77
        """Execute a command in a Singularity container and return the result as a dict."""
78
        command = action.get("command", "")
79
        cmd = [self.config.executable, "exec"]
80
 
81
        # Do not inherit directories and env vars from host
82
        cmd.extend(["--contain", "--cleanenv"])
83
 
84
        work_dir = cwd or self.config.cwd
85
        if work_dir and work_dir != "/":
86
            cmd.extend(["--pwd", work_dir])
87
 
88
        for key in self.config.forward_env:
89
            if (value := os.getenv(key)) is not None:
90
                cmd.extend(["--env", f"{key}={value}"])
91
        for key, value in self.config.env.items():
92
            cmd.extend(["--env", f"{key}={value}"])
93
 
94
        cmd.extend(["--writable", str(self.sandbox_dir), "bash", "-c", command])
95
        try:
96
            result = subprocess.run(
97
                cmd,
98
                text=True,
99
                timeout=timeout or self.config.timeout,
100
                encoding="utf-8",
101
                errors="replace",
102
                stdout=subprocess.PIPE,
103
                stderr=subprocess.STDOUT,
104
            )
105
            output = {"output": result.stdout, "returncode": result.returncode, "exception_info": ""}
106
        except Exception as e:
107
            raw_output = getattr(e, "output", None)
108
            raw_output = (
109
                raw_output.decode("utf-8", errors="replace") if isinstance(raw_output, bytes) else (raw_output or "")
110
            )
111
            output = {
112
                "output": raw_output,
113
                "returncode": -1,
114
                "exception_info": f"An error occurred while executing the command: {e}",
115
                "extra": {"exception_type": type(e).__name__, "exception": str(e)},
116
            }
117
        self._check_finished(output)
118
        return output
119
 
120
    def _check_finished(self, output: dict):
121
        """Raises Submitted if the output indicates task completion."""
122
        lines = output.get("output", "").lstrip().splitlines(keepends=True)
123
        if lines and lines[0].strip() == "COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT" and output["returncode"] == 0:
124
            submission = "".join(lines[1:])
125
            raise Submitted(
126
                {
127
                    "role": "exit",
128
                    "content": submission,
129
                    "extra": {"exit_status": "Submitted", "submission": submission},
130
                }
131
            )
132
 
133
    def cleanup(self):
134
        shutil.rmtree(self.sandbox_dir, ignore_errors=True)
135
 
136
    def __del__(self):
137
        """Cleanup sandbox when object is destroyed."""
138
        self.cleanup()
139
 
139 lines