MoltHub Agent: Mini SWE Agent

swebench.py(12.37 KB)Python
Raw
1
#!/usr/bin/env python3
2
 
3
"""Run mini-SWE-agent on SWE-bench instances in batch mode."""
4
# Read this first: https://mini-swe-agent.com/latest/usage/swebench/  (usage docs)
5
 
6
import concurrent.futures
7
import json
8
import random
9
import re
10
import threading
11
import time
12
import traceback
13
from pathlib import Path
14
 
15
import typer
16
from jinja2 import StrictUndefined, Template
17
from rich.live import Live
18
 
19
from minisweagent import Environment
20
from minisweagent.agents.default import DefaultAgent
21
from minisweagent.config import builtin_config_dir, get_config_from_spec
22
from minisweagent.environments import get_environment
23
from minisweagent.models import get_model
24
from minisweagent.run.benchmarks.utils.batch_progress import RunBatchProgressManager
25
from minisweagent.utils.log import add_file_handler, logger
26
from minisweagent.utils.serialize import UNSET, recursive_merge
27
 
28
# CLI help text shown by typer for the main command (rich markup).
_HELP_TEXT = """Run mini-SWE-agent on SWEBench instances.

[not dim]
More information about the usage: [bold green]https://mini-swe-agent.com/latest/usage/swebench/[/bold green]
[/not dim]
"""

# Help text for the -c/--config option; warns that passing -c replaces
# (rather than extends) the default config file.
_CONFIG_SPEC_HELP_TEXT = """Path to config files, filenames, or key-value pairs.

[bold red]IMPORTANT:[/bold red] [red]If you set this option, the default config file will not be used.[/red]
So you need to explicitly set it e.g., with [bold green]-c swebench.yaml <other options>[/bold green]

Multiple configs will be recursively merged.

Examples:

[bold red]-c model.model_kwargs.temperature=0[/bold red] [red]You forgot to add the default config file! See above.[/red]

[bold green]-c swebench.yaml -c model.model_kwargs.temperature=0.5[/bold green]

[bold green]-c swebench.yaml -c agent.max_iterations=50[/bold green]
"""

# Config file used when the user does not pass any -c/--config option.
DEFAULT_CONFIG_FILE = builtin_config_dir / "benchmarks" / "swebench.yaml"

# Maps the short --subset names to Hugging Face dataset paths.
# Any other --subset value is treated as a dataset path directly (see main()).
DATASET_MAPPING = {
    "full": "princeton-nlp/SWE-Bench",
    "verified": "princeton-nlp/SWE-Bench_Verified",
    "lite": "princeton-nlp/SWE-Bench_Lite",
    "multimodal": "princeton-nlp/SWE-Bench_Multimodal",
    "multilingual": "swe-bench/SWE-Bench_Multilingual",
    "smith": "SWE-bench/SWE-smith",
    "_test": "klieret/swe-bench-dummy-test-dataset",
}

app = typer.Typer(rich_markup_mode="rich", add_completion=False)
# Serializes all reads/writes of the shared preds.json across worker threads.
_OUTPUT_FILE_LOCK = threading.Lock()
65
 
66
 
67
class ProgressTrackingAgent(DefaultAgent):
    """Simple wrapper around DefaultAgent that provides progress updates."""

    def __init__(self, *args, progress_manager: RunBatchProgressManager, instance_id: str = "", **kwargs):
        """Forward *args/**kwargs to DefaultAgent; keep the progress manager and instance id.

        progress_manager: shared batch-progress reporter updated on every step.
        instance_id: SWE-bench instance id used as the progress-table key.
        """
        super().__init__(*args, **kwargs)
        self.progress_manager: RunBatchProgressManager = progress_manager
        self.instance_id = instance_id

    def step(self) -> dict:
        """Override step to provide progress updates."""
        # n_calls and cost are inherited from DefaultAgent — presumably the
        # number of model calls so far and the accumulated USD cost; confirm
        # against DefaultAgent. Status is shown before the step executes,
        # hence the "+ 1".
        self.progress_manager.update_instance_status(self.instance_id, f"Step {self.n_calls + 1:3d} (${self.cost:.2f})")
        return super().step()
79
 
80
 
81
def get_swebench_docker_image_name(instance: dict) -> str:
    """Get the image name for a SWEBench instance.

    An explicit ``image_name`` key on the instance wins; otherwise the
    canonical swebench registry image name is derived from the instance id.
    """
    explicit = instance.get("image_name")
    if explicit is not None:
        return explicit
    # Docker doesn't allow double underscore, so we replace them with a magic token
    id_docker_compatible = instance["instance_id"].replace("__", "_1776_")
    return f"docker.io/swebench/sweb.eval.x86_64.{id_docker_compatible}:latest".lower()
90
 
91
 
92
def get_sb_environment(config: dict, instance: dict) -> Environment:
    """Create and start the execution environment for one SWE-bench instance.

    Mutates ``config["environment"]`` in place (class defaulting to "docker"
    and the instance image), then optionally runs the configured
    ``run.env_startup_command`` (a Jinja2 template rendered with the instance
    fields) inside the fresh environment.

    Raises:
        RuntimeError: if the startup command exits non-zero.
    """
    env_config = config.setdefault("environment", {})
    env_class = env_config.setdefault("environment_class", "docker")
    image_name = get_swebench_docker_image_name(instance)
    if env_class in ("docker", "swerex_modal"):
        env_config["image"] = image_name
    elif env_class == "singularity":
        # Singularity pulls docker images through the docker:// URI scheme.
        env_config["image"] = "docker://" + image_name
    env = get_environment(env_config)
    startup_command = config.get("run", {}).get("env_startup_command")
    if startup_command:
        rendered = Template(startup_command, undefined=StrictUndefined).render(**instance)
        out = env.execute(rendered)
        if out["returncode"] != 0:
            raise RuntimeError(f"Error executing startup command: {out}")
    return env
107
 
108
 
109
def update_preds_file(output_path: Path, instance_id: str, model_name: str, result: str):
    """Update the output JSON file with results from a single instance.

    Thread-safe: all access to the shared predictions file is serialized
    through the module-level lock.
    """
    with _OUTPUT_FILE_LOCK:
        # Start from the existing file (if any) so concurrent workers only
        # add/overwrite their own instance entry.
        output_data = json.loads(output_path.read_text()) if output_path.exists() else {}
        output_data[instance_id] = {
            "model_name_or_path": model_name,
            "instance_id": instance_id,
            "model_patch": result,
        }
        output_path.write_text(json.dumps(output_data, indent=2))
121
 
122
 
123
def remove_from_preds_file(output_path: Path, instance_id: str):
    """Remove an instance from the predictions file.

    No-op if the file does not exist or the instance is not recorded in it.
    """
    with _OUTPUT_FILE_LOCK:
        # Check for the file *inside* the lock so the existence check and the
        # subsequent read/rewrite form one atomic section — consistent with
        # update_preds_file, which may create/rewrite the file concurrently.
        if not output_path.exists():
            return
        output_data = json.loads(output_path.read_text())
        if instance_id in output_data:
            del output_data[instance_id]
            output_path.write_text(json.dumps(output_data, indent=2))
132
 
133
 
134
def process_instance(
    instance: dict,
    output_dir: Path,
    config: dict,
    progress_manager: RunBatchProgressManager,
) -> None:
    """Process a single SWEBench instance.

    Spins up the environment, runs the agent on the instance's
    ``problem_statement``, then always records the prediction in
    ``preds.json`` and (if the agent was constructed) saves the trajectory.
    Exceptions are caught and recorded instead of propagating.
    """
    instance_id = instance["instance_id"]
    instance_dir = output_dir / instance_id
    # avoid inconsistent state if something here fails and there's leftover previous files
    remove_from_preds_file(output_dir / "preds.json", instance_id)
    (instance_dir / f"{instance_id}.traj.json").unlink(missing_ok=True)
    model = get_model(config=config.get("model", {}))
    task = instance["problem_statement"]

    progress_manager.on_instance_start(instance_id)
    progress_manager.update_instance_status(instance_id, "Pulling/starting docker")

    agent = None
    exit_status = None
    result = None
    extra_info = {}

    try:
        env = get_sb_environment(config, instance)
        agent = ProgressTrackingAgent(
            model,
            env,
            progress_manager=progress_manager,
            instance_id=instance_id,
            **config.get("agent", {}),
        )
        info = agent.run(task)
        exit_status = info.get("exit_status")
        # NOTE(review): result may be None if the run info lacks "submission";
        # it is then written to preds.json as-is — confirm downstream tolerates null.
        result = info.get("submission")
    except Exception as e:
        logger.error(f"Error processing instance {instance_id}: {e}", exc_info=True)
        # On failure, record the exception class name as the exit status and
        # an empty patch, and keep the traceback for the trajectory file.
        exit_status, result = type(e).__name__, ""
        extra_info = {"traceback": traceback.format_exc(), "exception_str": str(e)}
    finally:
        # The trajectory can only be saved if agent construction succeeded;
        # the preds.json entry and the progress update happen unconditionally.
        if agent is not None:
            traj_path = instance_dir / f"{instance_id}.traj.json"
            agent.save(
                traj_path,
                {
                    "info": {
                        "exit_status": exit_status,
                        "submission": result,
                        **extra_info,
                    },
                    "instance_id": instance_id,
                },
            )
            logger.info(f"Saved trajectory to '{traj_path}'")
        update_preds_file(output_dir / "preds.json", instance_id, model.config.model_name, result)
        progress_manager.on_instance_end(instance_id, exit_status)
190
 
191
 
192
def filter_instances(
    instances: list[dict], *, filter_spec: str, slice_spec: str = "", shuffle: bool = False
) -> list[dict]:
    """Filter and slice a list of SWEBench instances.

    Args:
        instances: instance dicts, each with an ``instance_id`` key.
        filter_spec: regex matched (via ``re.match``, i.e. anchored at the
            start) against each ``instance_id``; ``""`` keeps everything.
        slice_spec: optional Python-style slice string, e.g. ``"0:5"``.
        shuffle: deterministically shuffle (fixed seed 42) before filtering.

    Returns:
        A new list; the input list is never mutated.
    """
    if shuffle:
        # sorted() already returns a fresh list, so the input is untouched;
        # sorting first makes the seeded shuffle reproducible regardless of
        # the caller's ordering.
        instances = sorted(instances, key=lambda x: x["instance_id"])
        random.seed(42)
        random.shuffle(instances)
    before_filter = len(instances)
    instances = [instance for instance in instances if re.match(filter_spec, instance["instance_id"])]
    if (after_filter := len(instances)) != before_filter:
        logger.info(f"Instance filter: {before_filter} -> {after_filter} instances")
    if slice_spec:
        # Compare against the *post-filter* count — comparing to before_filter
        # would log a bogus "slice" change whenever the filter already
        # reduced the list.
        before_slice = len(instances)
        values = [int(x) if x else None for x in slice_spec.split(":")]
        instances = instances[slice(*values)]
        if (after_slice := len(instances)) != before_slice:
            logger.info(f"Instance slice: {before_slice} -> {after_slice} instances")
    return instances
210
 
211
 
212
# fmt: off
@app.command(help=_HELP_TEXT)
def main(
    subset: str = typer.Option("lite", "--subset", help="SWEBench subset to use or path to a dataset", rich_help_panel="Data selection"),
    split: str = typer.Option("dev", "--split", help="Dataset split", rich_help_panel="Data selection"),
    slice_spec: str = typer.Option("", "--slice", help="Slice specification (e.g., '0:5' for first 5 instances)", rich_help_panel="Data selection"),
    filter_spec: str = typer.Option("", "--filter", help="Filter instance IDs by regex", rich_help_panel="Data selection"),
    shuffle: bool = typer.Option(False, "--shuffle", help="Shuffle instances", rich_help_panel="Data selection"),
    output: str = typer.Option("", "-o", "--output", help="Output directory", rich_help_panel="Basic"),
    workers: int = typer.Option(1, "-w", "--workers", help="Number of worker threads for parallel processing", rich_help_panel="Basic"),
    model: str | None = typer.Option(None, "-m", "--model", help="Model to use", rich_help_panel="Basic"),
    model_class: str | None = typer.Option(None, "--model-class", help="Model class to use (e.g., 'anthropic' or 'minisweagent.models.anthropic.AnthropicModel')", rich_help_panel="Advanced"),
    redo_existing: bool = typer.Option(False, "--redo-existing", help="Redo existing instances", rich_help_panel="Data selection"),
    config_spec: list[str] = typer.Option([str(DEFAULT_CONFIG_FILE)], "-c", "--config", help=_CONFIG_SPEC_HELP_TEXT, rich_help_panel="Basic"),
    environment_class: str | None = typer.Option(None, "--environment-class", help="Environment type to use. Recommended are docker or singularity", rich_help_panel="Advanced"),
) -> None:
    # fmt: on
    # NOTE(review): output defaults to "" so Path("") resolves to the current
    # directory — confirm that is the intended default output location.
    output_path = Path(output)
    output_path.mkdir(parents=True, exist_ok=True)
    logger.info(f"Results will be saved to {output_path}")
    add_file_handler(output_path / "minisweagent.log")

    # Imported lazily: the datasets package is heavy and only needed here.
    from datasets import load_dataset

    # Unknown subset names fall through as raw dataset paths.
    dataset_path = DATASET_MAPPING.get(subset, subset)
    logger.info(f"Loading dataset {dataset_path}, split {split}...")
    instances = list(load_dataset(dataset_path, split=split))

    instances = filter_instances(instances, filter_spec=filter_spec, slice_spec=slice_spec, shuffle=shuffle)
    # Unless --redo-existing, skip anything already recorded in preds.json.
    if not redo_existing and (output_path / "preds.json").exists():
        existing_instances = list(json.loads((output_path / "preds.json").read_text()).keys())
        logger.info(f"Skipping {len(existing_instances)} existing instances")
        instances = [instance for instance in instances if instance["instance_id"] not in existing_instances]
    logger.info(f"Running on {len(instances)} instances...")

    # Merge all -c specs in order, then layer CLI overrides on top; UNSET
    # values are presumably dropped by recursive_merge so unset CLI options
    # don't clobber config-file values — confirm against serialize.UNSET.
    logger.info(f"Building agent config from specs: {config_spec}")
    configs = [get_config_from_spec(spec) for spec in config_spec]
    configs.append({
        "environment": {"environment_class": environment_class or UNSET},
        "model": {"model_name": model or UNSET, "model_class": model_class or UNSET},
    })
    config = recursive_merge(*configs)

    # Timestamped so reruns in the same output dir don't overwrite statuses.
    progress_manager = RunBatchProgressManager(len(instances), output_path / f"exit_statuses_{time.time()}.yaml")

    def process_futures(futures: dict[concurrent.futures.Future, str]):
        """Drain futures as they complete, logging uncaught per-instance errors."""
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except concurrent.futures.CancelledError:
                # Cancelled on KeyboardInterrupt below — not an error.
                pass
            except Exception as e:
                instance_id = futures[future]
                logger.error(f"Error in future for instance {instance_id}: {e}", exc_info=True)
                progress_manager.on_uncaught_exception(instance_id, e)

    with Live(progress_manager.render_group, refresh_per_second=4):
        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            futures = {
                executor.submit(process_instance, instance, output_path, config, progress_manager): instance[
                    "instance_id"
                ]
                for instance in instances
            }
            try:
                process_futures(futures)
            except KeyboardInterrupt:
                # First ^C: cancel everything not yet started, then wait for
                # running instances to finish; a second ^C aborts immediately.
                logger.info("Cancelling all pending jobs. Press ^C again to exit immediately.")
                for future in futures:
                    if not future.running() and not future.done():
                        future.cancel()
                process_futures(futures)
284
 
285
 
286
# Script entry point: dispatch to the typer CLI app.
if __name__ == "__main__":
    app()
288
 
288 lines