| 1 | #!/usr/bin/env python3
|
| 2 |
|
| 3 | """Run mini-SWE-agent on SWE-bench instances in batch mode."""
|
| 4 | # Read this first: https://mini-swe-agent.com/latest/usage/swebench/ (usage docs)
|
| 5 |
|
| 6 | import concurrent.futures
|
| 7 | import json
|
| 8 | import random
|
| 9 | import re
|
| 10 | import threading
|
| 11 | import time
|
| 12 | import traceback
|
| 13 | from pathlib import Path
|
| 14 |
|
| 15 | import typer
|
| 16 | from jinja2 import StrictUndefined, Template
|
| 17 | from rich.live import Live
|
| 18 |
|
| 19 | from minisweagent import Environment
|
| 20 | from minisweagent.agents.default import DefaultAgent
|
| 21 | from minisweagent.config import builtin_config_dir, get_config_from_spec
|
| 22 | from minisweagent.environments import get_environment
|
| 23 | from minisweagent.models import get_model
|
| 24 | from minisweagent.run.benchmarks.utils.batch_progress import RunBatchProgressManager
|
| 25 | from minisweagent.utils.log import add_file_handler, logger
|
| 26 | from minisweagent.utils.serialize import UNSET, recursive_merge
|
| 27 |
|
# Rich-markup help text shown as the typer command description.
_HELP_TEXT = """Run mini-SWE-agent on SWEBench instances.

[not dim]
More information about the usage: [bold green]https://mini-swe-agent.com/latest/usage/swebench/[/bold green]
[/not dim]
"""

# Rich-markup help text for the -c/--config option (see `config_spec` in `main`).
_CONFIG_SPEC_HELP_TEXT = """Path to config files, filenames, or key-value pairs.

[bold red]IMPORTANT:[/bold red] [red]If you set this option, the default config file will not be used.[/red]
So you need to explicitly set it e.g., with [bold green]-c swebench.yaml <other options>[/bold green]

Multiple configs will be recursively merged.

Examples:

[bold red]-c model.model_kwargs.temperature=0[/bold red] [red]You forgot to add the default config file! See above.[/red]

[bold green]-c swebench.yaml -c model.model_kwargs.temperature=0.5[/bold green]

[bold green]-c swebench.yaml -c agent.max_iterations=50[/bold green]
"""

# Default benchmark configuration shipped with the package; used when -c is not given.
DEFAULT_CONFIG_FILE = builtin_config_dir / "benchmarks" / "swebench.yaml"

# Short subset aliases -> Hugging Face dataset identifiers. Any --subset value not
# listed here is passed through to `load_dataset` as-is (see `main`).
DATASET_MAPPING = {
    "full": "princeton-nlp/SWE-Bench",
    "verified": "princeton-nlp/SWE-Bench_Verified",
    "lite": "princeton-nlp/SWE-Bench_Lite",
    "multimodal": "princeton-nlp/SWE-Bench_Multimodal",
    "multilingual": "swe-bench/SWE-Bench_Multilingual",
    "smith": "SWE-bench/SWE-smith",
    "_test": "klieret/swe-bench-dummy-test-dataset",
}

app = typer.Typer(rich_markup_mode="rich", add_completion=False)
# Serializes read-modify-write cycles on the shared preds.json across worker threads.
_OUTPUT_FILE_LOCK = threading.Lock()
|
| 65 |
|
| 66 |
|
class ProgressTrackingAgent(DefaultAgent):
    """DefaultAgent subclass that reports per-step progress to a batch progress manager."""

    def __init__(self, *args, progress_manager: RunBatchProgressManager, instance_id: str = "", **kwargs):
        super().__init__(*args, **kwargs)
        # Shared across all workers; every status update is tagged with this instance's id.
        self.progress_manager: RunBatchProgressManager = progress_manager
        self.instance_id = instance_id

    def step(self) -> dict:
        """Publish the upcoming step number and accumulated cost, then delegate to the parent."""
        status = f"Step {self.n_calls + 1:3d} (${self.cost:.2f})"
        self.progress_manager.update_instance_status(self.instance_id, status)
        return super().step()
|
| 80 |
|
def get_swebench_docker_image_name(instance: dict) -> str:
    """Return the docker image name to use for a SWEBench instance.

    An explicit ``image_name`` field on the instance wins; otherwise the
    canonical swebench evaluation image name is derived from the instance id.
    """
    explicit = instance.get("image_name")
    if explicit is not None:
        return explicit
    # Docker doesn't allow double underscore, so we replace them with a magic token
    id_docker_compatible = instance["instance_id"].replace("__", "_1776_")
    return f"docker.io/swebench/sweb.eval.x86_64.{id_docker_compatible}:latest".lower()
|
| 90 |
|
| 91 |
|
def get_sb_environment(config: dict, instance: dict) -> Environment:
    """Create the execution environment for *instance* from ``config["environment"]``.

    Defaults the environment class to ``docker`` and injects the instance-specific
    image name; optionally runs a jinja-templated startup command in the environment.
    Raises RuntimeError if the startup command exits non-zero.
    """
    env_config = config.setdefault("environment", {})
    env_config["environment_class"] = env_config.get("environment_class", "docker")
    env_class = env_config["environment_class"]
    image_name = get_swebench_docker_image_name(instance)
    if env_class in ("docker", "swerex_modal"):
        env_config["image"] = image_name
    elif env_class == "singularity":
        # Singularity pulls docker-registry images via the docker:// URI scheme.
        env_config["image"] = "docker://" + image_name
    env = get_environment(env_config)
    startup_command = config.get("run", {}).get("env_startup_command")
    if startup_command:
        # StrictUndefined: fail loudly if the template references a missing instance field.
        rendered = Template(startup_command, undefined=StrictUndefined).render(**instance)
        out = env.execute(rendered)
        if out["returncode"] != 0:
            raise RuntimeError(f"Error executing startup command: {out}")
    return env
|
| 107 |
|
| 108 |
|
def update_preds_file(output_path: Path, instance_id: str, model_name: str, result: str):
    """Insert or overwrite the prediction entry for *instance_id* in the shared JSON file."""
    with _OUTPUT_FILE_LOCK:
        # Load whatever other workers have written so far, if the file exists yet.
        output_data = json.loads(output_path.read_text()) if output_path.exists() else {}
        entry = {
            "model_name_or_path": model_name,
            "instance_id": instance_id,
            "model_patch": result,
        }
        output_data[instance_id] = entry
        output_path.write_text(json.dumps(output_data, indent=2))
|
| 121 |
|
| 122 |
|
def remove_from_preds_file(output_path: Path, instance_id: str):
    """Remove an instance from the predictions file.

    No-op if the file does not exist or the instance is not recorded.
    """
    with _OUTPUT_FILE_LOCK:
        # Check existence *inside* the lock so we never race against a concurrent
        # update_preds_file() creating or rewriting the file (the original checked
        # before acquiring the lock).
        if not output_path.exists():
            return
        output_data = json.loads(output_path.read_text())
        if instance_id in output_data:
            del output_data[instance_id]
            # Only rewrite the file when something actually changed.
            output_path.write_text(json.dumps(output_data, indent=2))
|
| 133 |
|
def process_instance(
    instance: dict,
    output_dir: Path,
    config: dict,
    progress_manager: RunBatchProgressManager,
) -> None:
    """Process a single SWEBench instance.

    Runs the agent on the instance's problem statement inside a freshly created
    environment, then persists the trajectory (``<id>.traj.json``) and records the
    submission in the shared ``preds.json``. Exceptions are caught and recorded in
    the trajectory/exit status rather than propagated.
    """
    instance_id = instance["instance_id"]
    instance_dir = output_dir / instance_id
    # avoid inconsistent state if something here fails and there's leftover previous files
    remove_from_preds_file(output_dir / "preds.json", instance_id)
    (instance_dir / f"{instance_id}.traj.json").unlink(missing_ok=True)
    model = get_model(config=config.get("model", {}))
    task = instance["problem_statement"]

    progress_manager.on_instance_start(instance_id)
    progress_manager.update_instance_status(instance_id, "Pulling/starting docker")

    # Defaults used if environment/agent construction fails before the agent runs.
    agent = None
    exit_status = None
    result = None
    extra_info = {}

    try:
        env = get_sb_environment(config, instance)
        agent = ProgressTrackingAgent(
            model,
            env,
            progress_manager=progress_manager,
            instance_id=instance_id,
            **config.get("agent", {}),
        )
        info = agent.run(task)
        exit_status = info.get("exit_status")
        result = info.get("submission")
    except Exception as e:
        # Record the failure instead of crashing the whole batch; the exception
        # class name doubles as the exit status, and an empty patch is submitted.
        logger.error(f"Error processing instance {instance_id}: {e}", exc_info=True)
        exit_status, result = type(e).__name__, ""
        extra_info = {"traceback": traceback.format_exc(), "exception_str": str(e)}
    finally:
        # agent is None when environment/agent construction itself failed,
        # in which case there is no trajectory to save.
        if agent is not None:
            traj_path = instance_dir / f"{instance_id}.traj.json"
            agent.save(
                traj_path,
                {
                    "info": {
                        "exit_status": exit_status,
                        "submission": result,
                        **extra_info,
                    },
                    "instance_id": instance_id,
                },
            )
            logger.info(f"Saved trajectory to '{traj_path}'")
        # Always record a preds.json entry (empty patch on failure) and close out
        # the progress display for this instance.
        update_preds_file(output_dir / "preds.json", instance_id, model.config.model_name, result)
        progress_manager.on_instance_end(instance_id, exit_status)
|
| 190 |
|
| 191 |
|
def filter_instances(
    instances: list[dict], *, filter_spec: str, slice_spec: str = "", shuffle: bool = False
) -> list[dict]:
    """Filter and slice a list of SWEBench instances.

    Args:
        instances: Instances to select from (the input list is not modified).
        filter_spec: Regex matched against each ``instance_id`` with ``re.match``.
        slice_spec: Optional Python slice expression such as ``"0:50"`` or ``":10"``,
            applied after filtering (and shuffling, if requested).
        shuffle: If set, sort by instance id and then shuffle deterministically,
            so the order is reproducible regardless of the input order.

    Returns:
        The selected instances.
    """
    if shuffle:
        # Sort first so the seeded shuffle is reproducible for any input order.
        instances = sorted(instances, key=lambda x: x["instance_id"])
        # Use a local RNG instead of random.seed(42) so we don't clobber the
        # global `random` state for the rest of the program; the permutation
        # produced by Random(42).shuffle is identical.
        random.Random(42).shuffle(instances)
    before_filter = len(instances)
    instances = [instance for instance in instances if re.match(filter_spec, instance["instance_id"])]
    if (after_filter := len(instances)) != before_filter:
        logger.info(f"Instance filter: {before_filter} -> {after_filter} instances")
    if slice_spec:
        # Compare against the *post-filter* count; comparing against before_filter
        # would misreport a preceding filter as a slice.
        before_slice = len(instances)
        values = [int(x) if x else None for x in slice_spec.split(":")]
        instances = instances[slice(*values)]
        if (after_slice := len(instances)) != before_slice:
            logger.info(f"Instance slice: {before_slice} -> {after_slice} instances")
    return instances
|
| 210 |
|
| 211 |
|
# fmt: off
@app.command(help=_HELP_TEXT)
def main(
    subset: str = typer.Option("lite", "--subset", help="SWEBench subset to use or path to a dataset", rich_help_panel="Data selection"),
    split: str = typer.Option("dev", "--split", help="Dataset split", rich_help_panel="Data selection"),
    slice_spec: str = typer.Option("", "--slice", help="Slice specification (e.g., '0:5' for first 5 instances)", rich_help_panel="Data selection"),
    filter_spec: str = typer.Option("", "--filter", help="Filter instance IDs by regex", rich_help_panel="Data selection"),
    shuffle: bool = typer.Option(False, "--shuffle", help="Shuffle instances", rich_help_panel="Data selection"),
    output: str = typer.Option("", "-o", "--output", help="Output directory", rich_help_panel="Basic"),
    workers: int = typer.Option(1, "-w", "--workers", help="Number of worker threads for parallel processing", rich_help_panel="Basic"),
    model: str | None = typer.Option(None, "-m", "--model", help="Model to use", rich_help_panel="Basic"),
    model_class: str | None = typer.Option(None, "--model-class", help="Model class to use (e.g., 'anthropic' or 'minisweagent.models.anthropic.AnthropicModel')", rich_help_panel="Advanced"),
    redo_existing: bool = typer.Option(False, "--redo-existing", help="Redo existing instances", rich_help_panel="Data selection"),
    config_spec: list[str] = typer.Option([str(DEFAULT_CONFIG_FILE)], "-c", "--config", help=_CONFIG_SPEC_HELP_TEXT, rich_help_panel="Basic"),
    environment_class: str | None = typer.Option(None, "--environment-class", help="Environment type to use. Recommended are docker or singularity", rich_help_panel="Advanced"),
) -> None:
    # fmt: on
    """CLI entry point: load, select, and run instances across a thread pool."""
    output_path = Path(output)
    output_path.mkdir(parents=True, exist_ok=True)
    logger.info(f"Results will be saved to {output_path}")
    # Mirror all log output into the run directory.
    add_file_handler(output_path / "minisweagent.log")

    # Imported lazily so the heavy `datasets` dependency is only paid at run time.
    from datasets import load_dataset

    # Unknown subset aliases fall through as a dataset path/name.
    dataset_path = DATASET_MAPPING.get(subset, subset)
    logger.info(f"Loading dataset {dataset_path}, split {split}...")
    instances = list(load_dataset(dataset_path, split=split))

    instances = filter_instances(instances, filter_spec=filter_spec, slice_spec=slice_spec, shuffle=shuffle)
    # Unless --redo-existing is set, skip instances already recorded in preds.json.
    if not redo_existing and (output_path / "preds.json").exists():
        existing_instances = list(json.loads((output_path / "preds.json").read_text()).keys())
        logger.info(f"Skipping {len(existing_instances)} existing instances")
        instances = [instance for instance in instances if instance["instance_id"] not in existing_instances]
    logger.info(f"Running on {len(instances)} instances...")

    logger.info(f"Building agent config from specs: {config_spec}")
    configs = [get_config_from_spec(spec) for spec in config_spec]
    # CLI flags are merged last so they override config files; unspecified flags
    # become UNSET, which recursive_merge presumably treats as "no override" — TODO confirm.
    configs.append({
        "environment": {"environment_class": environment_class or UNSET},
        "model": {"model_name": model or UNSET, "model_class": model_class or UNSET},
    })
    config = recursive_merge(*configs)

    progress_manager = RunBatchProgressManager(len(instances), output_path / f"exit_statuses_{time.time()}.yaml")

    def process_futures(futures: dict[concurrent.futures.Future, str]):
        # Drain completed futures; per-instance errors are already handled inside
        # process_instance, so anything surfacing here is truly uncaught.
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except concurrent.futures.CancelledError:
                pass
            except Exception as e:
                instance_id = futures[future]
                logger.error(f"Error in future for instance {instance_id}: {e}", exc_info=True)
                progress_manager.on_uncaught_exception(instance_id, e)

    with Live(progress_manager.render_group, refresh_per_second=4):
        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            # Map each future back to its instance id for error reporting.
            futures = {
                executor.submit(process_instance, instance, output_path, config, progress_manager): instance[
                    "instance_id"
                ]
                for instance in instances
            }
            try:
                process_futures(futures)
            except KeyboardInterrupt:
                # First ^C: cancel everything not yet started, then keep waiting
                # for running jobs; a second ^C exits immediately.
                logger.info("Cancelling all pending jobs. Press ^C again to exit immediately.")
                for future in futures:
                    if not future.running() and not future.done():
                        future.cancel()
                process_futures(futures)
|
| 284 |
|
| 285 |
|
if __name__ == "__main__":
    # Dispatch to the typer application defined above.
    app()
|
| 288 |
|