| 1 | import asyncio
|
| 2 | from typing import Any
|
| 3 |
|
| 4 | from pydantic import BaseModel
|
| 5 | from swerex.deployment.docker import DockerDeployment
|
| 6 | from swerex.runtime.abstract import Command as RexCommand
|
| 7 |
|
| 8 | from minisweagent.exceptions import Submitted
|
| 9 | from minisweagent.utils.serialize import recursive_merge
|
| 10 |
|
| 11 |
|
class SwerexDockerEnvironmentConfig(BaseModel):
    # Settings consumed by SwerexDockerEnvironment; `image` is handed to
    # DockerDeployment as its `image` argument.
    image: str
    cwd: str = "/"
    """Directory in which commands run when a call does not name its own."""
    timeout: int = 30
    """Command timeout applied when a call does not supply one."""
    deployment_extra_kwargs: dict[str, Any] = {}
    """Additional keyword arguments forwarded to DockerDeployment."""
|
| 20 |
|
| 21 |
|
| 22 | class SwerexDockerEnvironment:
|
| 23 | def __init__(self, **kwargs):
|
| 24 | """This class executes bash commands in a Docker container using SWE-ReX for sandboxing."""
|
| 25 | self.config = SwerexDockerEnvironmentConfig(**kwargs)
|
| 26 | self.deployment = DockerDeployment(image=self.config.image, **self.config.deployment_extra_kwargs)
|
| 27 | asyncio.run(self.deployment.start())
|
| 28 |
|
| 29 | def execute(self, action: dict, cwd: str = "", *, timeout: int | None = None) -> dict[str, Any]:
|
| 30 | """Execute a command in the environment and return the raw output."""
|
| 31 | command = action.get("command", "")
|
| 32 | try:
|
| 33 | result = asyncio.run(
|
| 34 | self.deployment.runtime.execute(
|
| 35 | RexCommand(
|
| 36 | command=command,
|
| 37 | shell=True,
|
| 38 | check=False,
|
| 39 | cwd=cwd or self.config.cwd,
|
| 40 | timeout=timeout or self.config.timeout,
|
| 41 | merge_output_streams=True,
|
| 42 | )
|
| 43 | )
|
| 44 | )
|
| 45 | output = {"output": result.stdout, "returncode": result.exit_code, "exception_info": ""}
|
| 46 | except Exception as e:
|
| 47 | output = {
|
| 48 | "output": str(e) if str(e) else "",
|
| 49 | "returncode": -1,
|
| 50 | "exception_info": f"An error occurred while executing the command: {e}",
|
| 51 | "extra": {"exception_type": type(e).__name__, "exception": str(e)},
|
| 52 | }
|
| 53 | self._check_finished(output)
|
| 54 | return output
|
| 55 |
|
| 56 | def _check_finished(self, output: dict):
|
| 57 | """Raises Submitted if the output indicates task completion."""
|
| 58 | lines = output.get("output", "").lstrip().splitlines(keepends=True)
|
| 59 | if lines and lines[0].strip() == "COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT" and output["returncode"] == 0:
|
| 60 | submission = "".join(lines[1:])
|
| 61 | raise Submitted(
|
| 62 | {
|
| 63 | "role": "exit",
|
| 64 | "content": submission,
|
| 65 | "extra": {"exit_status": "Submitted", "submission": submission},
|
| 66 | }
|
| 67 | )
|
| 68 |
|
| 69 | def get_template_vars(self, **kwargs) -> dict[str, Any]:
|
| 70 | return recursive_merge(self.config.model_dump(), kwargs)
|
| 71 |
|
| 72 | def serialize(self) -> dict:
|
| 73 | return {
|
| 74 | "info": {
|
| 75 | "config": {
|
| 76 | "environment": self.config.model_dump(mode="json"),
|
| 77 | "environment_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
|
| 78 | }
|
| 79 | }
|
| 80 | }
|
| 81 |
|