Skip to content

Benchmarks API

benchmarks.wargames_bench.WargamesBench

Run a policy against all WargamesBench scenarios and return results.

Parameters:

Name Type Description Default
config Optional[BenchConfig]

Benchmark configuration. Use BenchConfig defaults for the canonical benchmark; override n_eval_episodes for quick CI runs.

None
Source code in benchmarks/wargames_bench.py
class WargamesBench:
    """Run a policy against all WargamesBench scenarios and return results.

    Parameters
    ----------
    config:
        Benchmark configuration.  Use :class:`BenchConfig` defaults for the
        canonical benchmark; override ``n_eval_episodes`` for quick CI runs.
    """

    def __init__(self, config: Optional[BenchConfig] = None) -> None:
        self.config = config or BenchConfig()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def run(
        self,
        policy: Optional[Any] = None,
        *,
        label: Optional[str] = None,
    ) -> BenchSummary:
        """Evaluate *policy* on all benchmark scenarios.

        Parameters
        ----------
        policy:
            Any callable that accepts a flat ``np.ndarray`` observation and
            returns an action array; **or** an SB3-style object with a
            ``predict(obs, deterministic=True)`` method; **or** ``None`` to
            use the built-in scripted baseline.
        label:
            Override :attr:`BenchConfig.baseline_label` for the leaderboard.

        Returns
        -------
        :class:`BenchSummary`
        """
        cfg = self.config
        effective_label = label or cfg.baseline_label
        # Build a config copy that carries the effective label so all summary
        # rendering paths (str, write_markdown, to_leaderboard_row) are consistent.
        summary_cfg = dataclasses.replace(cfg, baseline_label=effective_label)
        n = min(cfg.n_scenarios, len(BENCH_SCENARIOS))
        scenarios = [BenchScenario.from_dict(d) for d in BENCH_SCENARIOS[:n]]

        results: List[BenchResult] = []

        for scenario in scenarios:
            env = self._make_env(scenario)
            t0 = time.perf_counter()
            try:
                episodes = self._evaluate(policy, env, scenario)
            finally:
                # Release env resources even when an episode raises; previously
                # an exception inside _evaluate leaked the environment.
                env.close()
            elapsed = time.perf_counter() - t0
            stats = _aggregate_episodes(episodes)
            results.append(
                BenchResult(
                    scenario_name=scenario.name,
                    policy_label=effective_label,
                    elapsed_seconds=elapsed,
                    **stats,
                )
            )
            log.info(
                "Scenario %-25s  win_rate=%.1f%%  steps=%.0f",
                scenario.name,
                stats["win_rate"] * 100,
                stats["mean_steps"],
            )

        return BenchSummary(results=results, config=summary_cfg)

    # ------------------------------------------------------------------
    # Internal: environment factory
    # ------------------------------------------------------------------

    def _make_env(self, scenario: BenchScenario) -> Any:
        """Create an evaluation environment for *scenario*.

        Attempts to instantiate a real :class:`~envs.battalion_env.BattalionEnv`
        with procedurally generated terrain and fixed weather (when available).
        Falls back to :class:`_SyntheticEnv` when the real env cannot be
        constructed (e.g. missing optional dependencies in CI).

        .. note::
            :class:`~envs.battalion_env.BattalionEnv` is a 1v1 environment.
            The ``n_blue`` / ``n_red`` fields in a :class:`BenchScenario` have
            no effect on the real env; they only influence the entity-count of
            the :class:`_SyntheticEnv` fallback.  Weather *is* applied to the
            real env via :class:`~envs.sim.weather.WeatherConfig`.
        """
        try:
            from envs.battalion_env import BattalionEnv
            from envs.sim.terrain import TerrainMap
            from envs.sim.weather import WeatherCondition, WeatherConfig

            _WEATHER_MAP: Dict[str, WeatherCondition] = {
                "clear": WeatherCondition.CLEAR,
                "rain":  WeatherCondition.RAIN,
                "fog":   WeatherCondition.FOG,
                "snow":  WeatherCondition.SNOW,
            }
            # Unknown weather strings map to None, i.e. no fixed condition.
            fixed_condition = _WEATHER_MAP.get(scenario.weather.lower())
            weather_cfg = WeatherConfig(fixed_condition=fixed_condition)

            terrain = TerrainMap.generate_random(
                rng=np.random.default_rng(scenario.terrain_seed),
                width=scenario.map_width,
                height=scenario.map_height,
                rows=40,
                cols=40,
                num_hills=scenario.n_hills,
                num_forests=scenario.n_forests,
            )
            return BattalionEnv(
                terrain=terrain,
                randomize_terrain=False,
                map_width=terrain.width,
                map_height=terrain.height,
                enable_weather=True,
                weather_config=weather_cfg,
            )
        except Exception as exc:
            log.debug(
                "Could not build real BattalionEnv for scenario %r: %s — using synthetic.",
                scenario.name,
                exc,
                exc_info=True,
            )
        return _SyntheticEnv(
            n_entities=scenario.n_blue + scenario.n_red,
            seed=scenario.terrain_seed,
            ep_length=self.config.max_steps_per_episode,
        )

    # ------------------------------------------------------------------
    # Internal: episode runner
    # ------------------------------------------------------------------

    def _evaluate(
        self,
        policy: Optional[Any],
        env: Any,
        scenario: BenchScenario,
    ) -> List[Dict[str, Any]]:
        """Run ``config.n_eval_episodes`` episodes and return raw stats."""
        cfg = self.config
        use_predict = hasattr(policy, "predict")
        episodes: List[Dict[str, Any]] = []

        for ep_idx in range(cfg.n_eval_episodes):
            # Deterministic reset: base seed + episode offset for reproducibility.
            ep_seed = scenario.seed + ep_idx
            obs, _ = env.reset(seed=ep_seed)
            terminated = truncated = False
            steps = 0
            won: Optional[bool] = None
            # ``info`` is read after the loop; initialise it so a zero-step
            # episode (``max_steps_per_episode == 0``) cannot raise NameError.
            info: Dict[str, Any] = {}

            while not (terminated or truncated) and steps < cfg.max_steps_per_episode:
                if use_predict:
                    action, _ = policy.predict(obs, deterministic=True)
                elif callable(policy):
                    action = policy(obs)
                else:
                    action = _scripted_action(obs)

                obs, _reward, terminated, truncated, info = env.step(action)
                steps += 1

            # Outcome comes from the final step's info dict; ``won`` stays
            # None (undecided/draw) when neither side routed.
            if isinstance(info, dict):
                if info.get("red_routed"):
                    won = True
                elif info.get("blue_routed"):
                    won = False

            episodes.append({"won": won, "steps": steps})

        return episodes

run(policy=None, *, label=None)

Evaluate policy on all benchmark scenarios.

Parameters:

Name Type Description Default
policy Optional[Any]

Any callable that accepts a flat np.ndarray observation and returns an action array; or an SB3-style object with a predict(obs, deterministic=True) method; or None to use the built-in scripted baseline.

None
label Optional[str]

Override BenchConfig.baseline_label for the leaderboard.

None

Returns:

Type Description
BenchSummary
Source code in benchmarks/wargames_bench.py
def run(
    self,
    policy: Optional[Any] = None,
    *,
    label: Optional[str] = None,
) -> BenchSummary:
    """Evaluate *policy* on every canonical benchmark scenario.

    Parameters
    ----------
    policy:
        A callable mapping a flat ``np.ndarray`` observation to an action
        array, an SB3-style object exposing
        ``predict(obs, deterministic=True)``, or ``None`` for the built-in
        scripted baseline.
    label:
        Optional leaderboard label overriding ``BenchConfig.baseline_label``.

    Returns
    -------
    :class:`BenchSummary`
    """
    cfg = self.config
    shown_label = label or cfg.baseline_label
    # Carry the effective label on a config copy so every rendering path
    # (str, write_markdown, to_leaderboard_row) reports the same name.
    cfg_for_summary = dataclasses.replace(cfg, baseline_label=shown_label)
    limit = min(cfg.n_scenarios, len(BENCH_SCENARIOS))

    bench_results: List[BenchResult] = []

    for entry in BENCH_SCENARIOS[:limit]:
        scenario = BenchScenario.from_dict(entry)
        env = self._make_env(scenario)
        started = time.perf_counter()
        raw_episodes = self._evaluate(policy, env, scenario)
        wall_time = time.perf_counter() - started
        agg = _aggregate_episodes(raw_episodes)
        row = BenchResult(
            scenario_name=scenario.name,
            policy_label=shown_label,
            elapsed_seconds=wall_time,
            **agg,
        )
        bench_results.append(row)
        env.close()
        log.info(
            "Scenario %-25s  win_rate=%.1f%%  steps=%.0f",
            scenario.name,
            agg["win_rate"] * 100,
            agg["mean_steps"],
        )

    return BenchSummary(results=bench_results, config=cfg_for_summary)

benchmarks.wargames_bench.BenchScenario dataclass

Descriptor for one WargamesBench evaluation scenario.

Attributes:

Name Type Description
name str

Unique scenario identifier. Used as the leaderboard row key.

n_blue int

Number of blue (agent) units.

n_red int

Number of red (opponent) units.

weather str

Weather string — "clear", "rain", "fog", or "snow".

terrain_seed int

Seed used to generate the procedural terrain. Fixed per scenario.

seed int

Episode-level RNG seed used for env.reset(seed=…).

n_hills int

Number of terrain hills to generate.

n_forests int

Number of terrain forest patches to generate.

map_width float

Terrain map width in metres.

map_height float

Terrain map height in metres.

Source code in benchmarks/wargames_bench.py
@dataclass
class BenchScenario:
    """Descriptor for one WargamesBench evaluation scenario.

    Attributes
    ----------
    name:
        Unique scenario identifier.  Used as the leaderboard row key.
    n_blue:
        Number of blue (agent) units.
    n_red:
        Number of red (opponent) units.
    weather:
        Weather string — ``"clear"``, ``"rain"``, ``"fog"``, or ``"snow"``.
    terrain_seed:
        Seed used to generate the procedural terrain.  Fixed per scenario.
    seed:
        Episode-level RNG seed used for ``env.reset(seed=…)``.
    n_hills:
        Number of terrain hills to generate.
    n_forests:
        Number of terrain forest patches to generate.
    map_width:
        Terrain map width in metres.
    map_height:
        Terrain map height in metres.
    """

    name: str
    n_blue: int = 8
    n_red: int = 8
    weather: str = "clear"
    terrain_seed: int = 42
    seed: int = 0
    n_hills: int = 4
    n_forests: int = 3
    map_width: float = 10_000.0
    map_height: float = 10_000.0

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "BenchScenario":
        """Build a :class:`BenchScenario` from a registry entry dict.

        ``"name"`` is required; every other field falls back to the dataclass
        default.  Defaults are read from the dataclass fields themselves, so
        they are defined in exactly one place and cannot drift from the field
        declarations.  Unknown keys in *d* are ignored.

        Raises
        ------
        KeyError
            If *d* has no ``"name"`` entry.
        """
        optional = {f.name for f in dataclasses.fields(cls)} - {"name"}
        overrides = {k: v for k, v in d.items() if k in optional}
        return cls(name=d["name"], **overrides)

from_dict(d) classmethod

Build a BenchScenario from a registry entry dict.

Source code in benchmarks/wargames_bench.py
@classmethod
def from_dict(cls, d: Dict[str, Any]) -> "BenchScenario":
    """Build a :class:`BenchScenario` from a registry entry dict."""
    # Fallbacks mirror the dataclass field defaults; any key missing from
    # the registry entry takes its fallback value instead.
    fallbacks: Dict[str, Any] = {
        "n_blue": 8,
        "n_red": 8,
        "weather": "clear",
        "terrain_seed": 42,
        "seed": 0,
        "n_hills": 4,
        "n_forests": 3,
        "map_width": 10_000.0,
        "map_height": 10_000.0,
    }
    kwargs = {key: d.get(key, default) for key, default in fallbacks.items()}
    return cls(name=d["name"], **kwargs)

benchmarks.wargames_bench.BenchConfig dataclass

Configuration for a WargamesBench run.

Attributes:

Name Type Description
n_eval_episodes int

Episodes per scenario. Use ≥ 100 for reproducible ± 2 % win rates.

n_scenarios int

Number of canonical scenarios to evaluate (≤ 20).

max_steps_per_episode int

Hard episode-length cap.

baseline_label str

Human-readable label for the evaluated policy in the leaderboard.

report_path Optional[str]

Where to write the Markdown leaderboard report. None → default.

win_rate_tolerance float

Maximum allowed win-rate deviation between runs with the same seed. Used by BenchSummary.is_reproducible.

Source code in benchmarks/wargames_bench.py
@dataclass
class BenchConfig:
    """Tunable settings for one WargamesBench run.

    Attributes
    ----------
    n_eval_episodes:
        Episodes evaluated per scenario; keep at 100 or more for win rates
        that reproduce to within roughly two percentage points.
    n_scenarios:
        How many of the canonical scenarios (at most 20) to evaluate.
    max_steps_per_episode:
        Hard upper bound on episode length.
    baseline_label:
        Leaderboard display name for the policy under evaluation.
    report_path:
        Destination of the Markdown leaderboard report; ``None`` selects the
        default location.
    win_rate_tolerance:
        Largest win-rate difference tolerated between two same-seed runs by
        :meth:`BenchSummary.is_reproducible`.
    """

    n_eval_episodes: int = 100
    n_scenarios: int = 20
    max_steps_per_episode: int = 500
    baseline_label: str = "scripted_baseline"
    report_path: Optional[str] = None
    win_rate_tolerance: float = 0.02

benchmarks.wargames_bench.BenchResult dataclass

Win-rate statistics for one (policy × scenario) evaluation.

Attributes:

Name Type Description
scenario_name str

Name of the evaluated scenario.

policy_label str

Human-readable label for the policy.

win_rate float

Fraction of episodes won by the blue (agent) policy.

mean_steps float

Mean episode length.

std_steps float

Standard deviation of episode length.

n_episodes int

Number of episodes evaluated.

elapsed_seconds float

Wall-clock seconds for this evaluation.

Source code in benchmarks/wargames_bench.py
@dataclass
class BenchResult:
    """Win-rate statistics for one (policy × scenario) evaluation.

    Attributes
    ----------
    scenario_name:
        Scenario this row was measured on.
    policy_label:
        Display name of the evaluated policy.
    win_rate:
        Fraction of episodes the blue (agent) side won.
    mean_steps:
        Average episode length in steps.
    std_steps:
        Standard deviation of the episode length.
    n_episodes:
        How many episodes were evaluated.
    elapsed_seconds:
        Wall-clock time this evaluation took, in seconds.
    """

    scenario_name: str
    policy_label: str
    win_rate: float
    mean_steps: float
    std_steps: float
    n_episodes: int
    elapsed_seconds: float = 0.0

benchmarks.wargames_bench.BenchSummary dataclass

Aggregated WargamesBench results across all evaluated scenarios.

Attributes:

Name Type Description
results List[BenchResult]

Per-scenario :class:BenchResult objects.

config BenchConfig

:class:BenchConfig used for this run.

Source code in benchmarks/wargames_bench.py
@dataclass
class BenchSummary:
    """Aggregated WargamesBench results across all evaluated scenarios.

    Attributes
    ----------
    results:
        One :class:`BenchResult` per evaluated scenario.
    config:
        The :class:`BenchConfig` this run was executed with.
    """

    results: List[BenchResult]
    config: BenchConfig

    # ------------------------------------------------------------------
    # Aggregated statistics
    # ------------------------------------------------------------------

    @property
    def mean_win_rate(self) -> float:
        """Mean win rate across all evaluated scenarios (0.0 when empty)."""
        rates = [res.win_rate for res in self.results]
        return float(np.mean(rates)) if rates else 0.0

    @property
    def std_win_rate(self) -> float:
        """Standard deviation of scenario win rates (0.0 when empty)."""
        rates = [res.win_rate for res in self.results]
        return float(np.std(rates)) if rates else 0.0

    @property
    def total_episodes(self) -> int:
        """Total number of episodes across every scenario."""
        count = 0
        for res in self.results:
            count += res.n_episodes
        return count

    @property
    def total_elapsed_seconds(self) -> float:
        """Wall-clock seconds summed over every evaluation."""
        return sum(res.elapsed_seconds for res in self.results)

    def is_reproducible(self, other: "BenchSummary") -> bool:
        """Return ``True`` when *self* and *other* win rates agree within tolerance.

        The two runs must list the same scenario names in the same order;
        the comparison tolerance is ``self.config.win_rate_tolerance``
        (2 % by default).
        """
        if len(self.results) != len(other.results):
            return False
        tol = self.config.win_rate_tolerance
        return all(
            mine.scenario_name == theirs.scenario_name
            and abs(mine.win_rate - theirs.win_rate) <= tol
            for mine, theirs in zip(self.results, other.results)
        )

    # ------------------------------------------------------------------
    # Rendering
    # ------------------------------------------------------------------

    def __str__(self) -> str:
        rows = (
            "WargamesBench Summary",
            f"  Policy               : {self.config.baseline_label}",
            f"  Scenarios evaluated  : {len(self.results)}",
            f"  Total episodes       : {self.total_episodes}",
            f"  Mean win rate        : {self.mean_win_rate:.1%}",
            f"  Std win rate         : {self.std_win_rate:.1%}",
            f"  Elapsed              : {self.total_elapsed_seconds:.1f}s",
        )
        return "\n".join(rows)

    def to_leaderboard_row(self) -> Dict[str, Any]:
        """Return a dict suitable for appending to a leaderboard table."""
        row: Dict[str, Any] = {"policy": self.config.baseline_label}
        row["mean_win_rate"] = format(self.mean_win_rate, ".3f")
        row["std_win_rate"] = format(self.std_win_rate, ".3f")
        row["n_scenarios"] = len(self.results)
        row["total_episodes"] = self.total_episodes
        return row

    def write_markdown(self, path: Optional[str | Path] = None) -> Path:
        """Write a Markdown leaderboard report and return its path.

        Resolution order: explicit *path* argument, then
        ``config.report_path``, then ``docs/wargames_bench_leaderboard.md``
        under the repository root.
        """
        if path is not None:
            target = Path(path)
        elif self.config.report_path:
            target = Path(self.config.report_path)
        else:
            target = _REPO_ROOT / "docs" / "wargames_bench_leaderboard.md"
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(_render_summary_markdown(self), encoding="utf-8")
        return target

mean_win_rate property

Mean win rate across all evaluated scenarios.

std_win_rate property

Standard deviation of win rates across scenarios.

total_elapsed_seconds property

Total wall-clock seconds for all evaluations.

total_episodes property

Total number of episodes evaluated.

is_reproducible(other)

Return True when self and other win rates agree within tolerance.

Both runs must have identical scenario names in the same order. The tolerance is self.config.win_rate_tolerance (default 2 %).

Source code in benchmarks/wargames_bench.py
def is_reproducible(self, other: "BenchSummary") -> bool:
    """Return ``True`` when *self* and *other* win rates agree within tolerance.

    The two runs must list the same scenario names in the same order; the
    comparison tolerance is ``self.config.win_rate_tolerance`` (2 % by
    default).
    """
    mine, theirs = self.results, other.results
    if len(mine) != len(theirs):
        return False
    tol = self.config.win_rate_tolerance
    return all(
        a.scenario_name == b.scenario_name
        and abs(a.win_rate - b.win_rate) <= tol
        for a, b in zip(mine, theirs)
    )

to_leaderboard_row()

Return a dict suitable for appending to a leaderboard table.

Source code in benchmarks/wargames_bench.py
def to_leaderboard_row(self) -> Dict[str, Any]:
    """Return a dict suitable for appending to a leaderboard table."""
    row: Dict[str, Any] = {"policy": self.config.baseline_label}
    row["mean_win_rate"] = format(self.mean_win_rate, ".3f")
    row["std_win_rate"] = format(self.std_win_rate, ".3f")
    row["n_scenarios"] = len(self.results)
    row["total_episodes"] = self.total_episodes
    return row

write_markdown(path=None)

Write a Markdown leaderboard report to path.

Defaults to docs/wargames_bench_leaderboard.md.

Source code in benchmarks/wargames_bench.py
def write_markdown(self, path: Optional[str | Path] = None) -> Path:
    """Write a Markdown leaderboard report and return the output path.

    Resolution order: explicit *path* argument, then ``config.report_path``,
    then ``docs/wargames_bench_leaderboard.md`` under the repository root.
    """
    if path is not None:
        target = Path(path)
    elif self.config.report_path:
        target = Path(self.config.report_path)
    else:
        target = _REPO_ROOT / "docs" / "wargames_bench_leaderboard.md"
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(_render_summary_markdown(self), encoding="utf-8")
    return target

benchmarks.wargames_bench.BENCH_SCENARIOS = [{'name': 'sym_4v4_clear', 'n_blue': 4, 'n_red': 4, 'weather': 'clear', 'terrain_seed': 4201, 'seed': 9001}, {'name': 'sym_6v6_rain', 'n_blue': 6, 'n_red': 6, 'weather': 'rain', 'terrain_seed': 4202, 'seed': 9002}, {'name': 'sym_8v8_fog', 'n_blue': 8, 'n_red': 8, 'weather': 'fog', 'terrain_seed': 4203, 'seed': 9003}, {'name': 'sym_8v8_snow', 'n_blue': 8, 'n_red': 8, 'weather': 'snow', 'terrain_seed': 4204, 'seed': 9004}, {'name': 'sym_12v12_clear', 'n_blue': 12, 'n_red': 12, 'weather': 'clear', 'terrain_seed': 4205, 'seed': 9005}, {'name': 'asym_4v8_clear', 'n_blue': 4, 'n_red': 8, 'weather': 'clear', 'terrain_seed': 4206, 'seed': 9006}, {'name': 'asym_8v4_rain', 'n_blue': 8, 'n_red': 4, 'weather': 'rain', 'terrain_seed': 4207, 'seed': 9007}, {'name': 'asym_6v12_fog', 'n_blue': 6, 'n_red': 12, 'weather': 'fog', 'terrain_seed': 4208, 'seed': 9008}, {'name': 'asym_12v6_snow', 'n_blue': 12, 'n_red': 6, 'weather': 'snow', 'terrain_seed': 4209, 'seed': 9009}, {'name': 'hilly_6v6_clear', 'n_blue': 6, 'n_red': 6, 'weather': 'clear', 'terrain_seed': 5001, 'seed': 9010, 'n_hills': 8, 'n_forests': 2}, {'name': 'hilly_8v8_rain', 'n_blue': 8, 'n_red': 8, 'weather': 'rain', 'terrain_seed': 5002, 'seed': 9011, 'n_hills': 8, 'n_forests': 2}, {'name': 'forest_6v6_fog', 'n_blue': 6, 'n_red': 6, 'weather': 'fog', 'terrain_seed': 6001, 'seed': 9012, 'n_hills': 2, 'n_forests': 8}, {'name': 'forest_8v8_snow', 'n_blue': 8, 'n_red': 8, 'weather': 'snow', 'terrain_seed': 6002, 'seed': 9013, 'n_hills': 2, 'n_forests': 8}, {'name': 'large_16v16_clear', 'n_blue': 16, 'n_red': 16, 'weather': 'clear', 'terrain_seed': 7001, 'seed': 9014}, {'name': 'large_16v12_rain', 'n_blue': 16, 'n_red': 12, 'weather': 'rain', 'terrain_seed': 7002, 'seed': 9015}, {'name': 'skirmish_3v3_clear', 'n_blue': 3, 'n_red': 3, 'weather': 'clear', 'terrain_seed': 8001, 'seed': 9016}, {'name': 'skirmish_3v3_fog', 'n_blue': 3, 'n_red': 3, 'weather': 'fog', 'terrain_seed': 8002, 
'seed': 9017}, {'name': 'defense_6v8_clear', 'n_blue': 6, 'n_red': 8, 'weather': 'clear', 'terrain_seed': 8003, 'seed': 9018, 'n_hills': 6, 'n_forests': 4}, {'name': 'defense_4v6_snow', 'n_blue': 4, 'n_red': 6, 'weather': 'snow', 'terrain_seed': 8004, 'seed': 9019, 'n_hills': 6, 'n_forests': 4}, {'name': 'lowvis_8v8_fog', 'n_blue': 8, 'n_red': 8, 'weather': 'fog', 'terrain_seed': 8005, 'seed': 9020}] module-attribute