"""Top-level definitions for mini-swe-agent.

This module provides:

- path settings for the global config file and the package directory,
- the package version number,
- protocols describing the core components of mini-swe-agent.

Thanks to protocols and duck typing, you can ignore the protocols entirely
unless you want static type checking.
"""

# Package version string.
__version__ = "2.0.0a2"
|
| 12 |
|
| 13 | import os
|
| 14 | from pathlib import Path
|
| 15 | from typing import Any, Protocol
|
| 16 |
|
| 17 | import dotenv
|
| 18 | from platformdirs import user_config_dir
|
| 19 | from rich.console import Console
|
| 20 |
|
| 21 | from minisweagent.utils.log import logger
|
| 22 |
|
# Absolute path of the directory that contains this module.
package_dir = Path(__file__).resolve().parent

# Directory holding the global `.env`. A non-empty MSWEA_GLOBAL_CONFIG_DIR takes
# precedence; otherwise use the platform's per-user config directory.
global_config_dir = Path(os.environ.get("MSWEA_GLOBAL_CONFIG_DIR") or user_config_dir("mini-swe-agent"))
# `global_config_dir` is already a Path, so no re-wrapping is needed.
global_config_file = global_config_dir.joinpath(".env")
# Created eagerly at import time so later reads/writes of the config file have a home.
global_config_dir.mkdir(parents=True, exist_ok=True)
|
| 29 |
|
# Print a startup banner unless MSWEA_SILENT_STARTUP is set to a non-empty value.
# NOTE: this check runs *before* the global .env is loaded below, so setting
# MSWEA_SILENT_STARTUP inside that .env file does not suppress the banner.
if not os.getenv("MSWEA_SILENT_STARTUP"):
    # The three f-strings are implicitly concatenated into a single object.
    # Passing them as separate arguments (as a stray comma previously did) makes
    # Console.print join them with its default sep=" ", which indented the
    # "Loading global config" line by one space.
    Console().print(
        f"👋 This is [bold green]mini-swe-agent[/bold green] version [bold green]{__version__}[/bold green].\n"
        f"Check the [bold red]v2 migration guide[/] at [bold red]https://klieret.short.gy/mini-v2-migration[/]\n"
        f"Loading global config from [bold green]'{global_config_file}'[/bold green]"
    )

# Load variables from the global .env file into the process environment.
dotenv.load_dotenv(dotenv_path=global_config_file)
|
| 37 |
|
| 38 |
|
| 39 | # === Protocols ===
|
| 40 | # You can ignore them unless you want static type checking.
|
| 41 |
|
| 42 |
|
| 43 | class Model(Protocol):
|
| 44 | """Protocol for language models."""
|
| 45 |
|
| 46 | config: Any
|
| 47 |
|
| 48 | def query(self, messages: list[dict[str, str]], **kwargs) -> dict: ...
|
| 49 |
|
| 50 | def format_message(self, **kwargs) -> dict: ...
|
| 51 |
|
| 52 | def format_observation_messages(
|
| 53 | self, message: dict, outputs: list[dict], template_vars: dict | None = None
|
| 54 | ) -> list[dict]: ...
|
| 55 |
|
| 56 | def get_template_vars(self, **kwargs) -> dict[str, Any]: ...
|
| 57 |
|
| 58 | def serialize(self) -> dict: ...
|
| 59 |
|
| 60 |
|
class Environment(Protocol):
    """Structural type for execution environments.

    Any object that exposes a ``config`` attribute and the methods below
    satisfies this protocol; no inheritance is required.
    """

    config: Any

    def execute(self, action: dict, cwd: str = "") -> dict[str, Any]:
        """Execute ``action`` and return a result dict.

        ``cwd`` defaults to the empty string (presumably a working directory; confirm with implementations).
        """

    def get_template_vars(self, **kwargs) -> dict[str, Any]:
        """Return a mapping from template variable names to values."""

    def serialize(self) -> dict:
        """Return a dict representation of this environment."""
|
| 71 |
|
| 72 |
|
| 73 | class Agent(Protocol):
|
| 74 | """Protocol for agents."""
|
| 75 |
|
| 76 | config: Any
|
| 77 |
|
| 78 | def run(self, task: str, **kwargs) -> dict: ...
|
| 79 |
|
| 80 | def save(self, path: Path | None, *extra_dicts) -> dict: ...
|
| 81 |
|
| 82 |
|
# Public API of this module (also what `from minisweagent import *` exposes).
__all__ = [
    # Protocols for the core components
    "Agent",
    "Model",
    "Environment",
    # Package metadata and config paths
    "package_dir",
    "__version__",
    "global_config_file",
    "global_config_dir",
    # Re-exported from minisweagent.utils.log
    "logger",
]
|
| 93 |
|