
Training API

The training package exposes the full training workflow as a clean programmatic Python API. All stable symbols are importable from the top-level training namespace.

Quick-start

from training import train, TrainingConfig

# Minimal run with defaults
model = train(total_timesteps=500_000, n_envs=4, enable_wandb=False)

# Full config object
config = TrainingConfig(
    total_timesteps=1_000_000,
    n_envs=8,
    curriculum_level=3,
    enable_self_play=True,
    wandb_project="my_project",
)
model = train(config)

Training runners

training.train.TrainingConfig dataclass

Configuration for a single PPO training run on envs.battalion_env.BattalionEnv.

All fields are optional; the defaults match configs/default.yaml. Instances can be passed directly to train or individual fields can be overridden via **kwargs at the call site.

Examples:

Minimal run with defaults:

from training import train, TrainingConfig
model = train(TrainingConfig(total_timesteps=500_000))

Override specific fields at call time without constructing a config:

model = train(total_timesteps=200_000, n_envs=4, enable_wandb=False)
Source code in training/train.py
@dataclass
class TrainingConfig:
    """Configuration for a single PPO training run on :class:`~envs.battalion_env.BattalionEnv`.

    All fields are optional; the defaults match ``configs/default.yaml``.
    Instances can be passed directly to :func:`train` or individual fields
    can be overridden via ``**kwargs`` at the call site.

    Examples
    --------
    Minimal run with defaults::

        from training import train, TrainingConfig
        model = train(TrainingConfig(total_timesteps=500_000))

    Override specific fields at call time without constructing a config::

        model = train(total_timesteps=200_000, n_envs=4, enable_wandb=False)
    """

    # ── PPO hyperparameters ───────────────────────────────────────────────
    total_timesteps: int = 1_000_000
    learning_rate: float = 3.0e-4
    n_steps: int = 2048
    batch_size: int = 64
    n_epochs: int = 10
    gamma: float = 0.99
    gae_lambda: float = 0.95
    clip_range: float = 0.2
    ent_coef: float = 0.01
    vf_coef: float = 0.5
    max_grad_norm: float = 0.5
    seed: int = 42
    device: str = "auto"

    # ── Environment ───────────────────────────────────────────────────────
    n_envs: int = 8
    curriculum_level: int = 5
    map_width: float = 1000.0
    map_height: float = 1000.0
    max_steps: int = 500
    randomize_terrain: bool = True
    hill_speed_factor: float = 0.5

    # ── Reward weights ────────────────────────────────────────────────────
    reward_delta_enemy_strength: float = 5.0
    reward_delta_own_strength: float = 5.0
    reward_survival_bonus: float = 0.0
    reward_win_bonus: float = 10.0
    reward_loss_penalty: float = -10.0
    reward_time_penalty: float = -0.01

    # ── Checkpointing and evaluation ──────────────────────────────────────
    checkpoint_dir: str = "checkpoints"
    checkpoint_freq: int = 100_000
    eval_freq: int = 50_000
    n_eval_episodes: int = 20
    eval_deterministic: bool = True

    # ── Artifact management ───────────────────────────────────────────────
    write_manifest: bool = True
    manifest_path: str = "checkpoints/manifest.jsonl"
    enable_naming_v2: bool = True
    keep_legacy_aliases: bool = True

    # ── W&B experiment tracking ───────────────────────────────────────────
    enable_wandb: bool = True
    wandb_project: str = "wargames_training"
    wandb_entity: Optional[str] = None
    wandb_tags: List[str] = field(default_factory=_default_wandb_tags)
    wandb_log_freq: int = 1000

    # ── Self-play (disabled by default) ───────────────────────────────────
    enable_self_play: bool = False
    self_play_pool_dir: str = "checkpoints/pool"
    self_play_pool_max_size: int = 10
    self_play_snapshot_freq: int = 50_000
    self_play_eval_freq: int = 50_000
    self_play_n_eval_episodes: int = 20
    self_play_use_latest_for_eval: bool = False

    # ── Elo evaluation (disabled by default) ─────────────────────────────
    elo_opponents: List[str] = field(default_factory=list)
    elo_registry_path: str = "checkpoints/elo_registry.json"
    elo_eval_freq: int = 50_000
    elo_n_eval_episodes: int = 20

    # ── Retention / pruning ───────────────────────────────────────────────
    keep_periodic: int = 5
    keep_self_play_snapshots: int = 10
    prune_on_run_end: bool = True

    # ── Logging ───────────────────────────────────────────────────────────
    verbose: int = 1
    log_dir: str = "logs"

training.train.train(config=None, *, extra_callbacks=None, resume=None, **override_kwargs)

Train a PPO policy on envs.battalion_env.BattalionEnv.

This is the programmatic entry-point for training, fully decoupled from Hydra/YAML so it can be called from any Python script or notebook.

Parameters:

- config (Optional[TrainingConfig], default: None): TrainingConfig instance. When None, a default config is used. Individual fields can be overridden via **override_kwargs.
- extra_callbacks (Optional[List[BaseCallback]], default: None): Additional SB3 stable_baselines3.common.callbacks.BaseCallback instances appended to the built-in callback list.
- resume (Optional[Union[str, Path]], default: None): Path to an existing .zip checkpoint to resume training from. When provided, the model weights and optimizer state are loaded before training begins (the .zip extension may be omitted).
- **override_kwargs (Any): Keyword arguments that override individual TrainingConfig fields. Any unrecognised key raises ValueError.

Returns:

- PPO: The trained model. Periodic checkpoints, the best model, and a manifest are written to config.checkpoint_dir during training.

Raises:

- ValueError: If any configuration value is invalid or an unrecognised **override_kwargs key is passed.
- FileNotFoundError: If resume points to a path that does not exist on disk.

Examples:

Quickstart with defaults:

from training import train
model = train(total_timesteps=200_000, n_envs=4, enable_wandb=False)

Full config:

from training import train, TrainingConfig
config = TrainingConfig(
    total_timesteps=1_000_000,
    n_envs=8,
    curriculum_level=3,
    enable_self_play=True,
)
model = train(config)
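The resume argument accepts the checkpoint path with or without the .zip suffix. Its resolution behaviour, restated here as a standalone sketch so it can be read without the full source below, is:

```python
from pathlib import Path


def resolve_resume(resume):
    """Resolve a checkpoint path, allowing the .zip extension to be omitted."""
    path = Path(resume)
    if path.exists():
        return path
    # Fall back to the same path with a .zip suffix appended.
    zip_path = Path(str(path) + ".zip")
    if zip_path.exists():
        return zip_path
    raise FileNotFoundError(
        f"Resume checkpoint not found: '{resume}'. "
        "Provide an existing .zip path or omit the extension."
    )
```

So train(resume="checkpoints/run1") and train(resume="checkpoints/run1.zip") load the same file, and a missing path fails fast before any environments are built.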
Source code in training/train.py
def train(
    config: Optional[TrainingConfig] = None,
    *,
    extra_callbacks: Optional[List[BaseCallback]] = None,
    resume: Optional[Union[str, Path]] = None,
    **override_kwargs: Any,
) -> PPO:
    """Train a PPO policy on :class:`~envs.battalion_env.BattalionEnv`.

    This is the programmatic entry-point for training, fully decoupled from
    Hydra/YAML so it can be called from any Python script or notebook.

    Parameters
    ----------
    config:
        :class:`TrainingConfig` instance.  When ``None`` a default config is
        used.  Individual fields can be overridden via ``**override_kwargs``.
    extra_callbacks:
        Additional SB3 :class:`~stable_baselines3.common.callbacks.BaseCallback`
        instances appended to the built-in callback list.
    resume:
        Path to an existing ``.zip`` checkpoint to resume training from.
        When provided, the model weights and optimizer state are loaded
        before training begins (the ``.zip`` extension may be omitted).
    **override_kwargs:
        Keyword arguments that override individual :class:`TrainingConfig`
        fields.  Any unrecognised key raises :exc:`ValueError`.

    Returns
    -------
    stable_baselines3.PPO
        The trained model.  Periodic checkpoints, best-model, and a manifest
        are written to *config.checkpoint_dir* during training.

    Raises
    ------
    ValueError
        If any configuration value is invalid or an unrecognised
        ``**override_kwarg`` key is passed.
    FileNotFoundError
        If *resume* points to a path that does not exist on disk.

    Examples
    --------
    Quickstart with defaults::

        from training import train
        model = train(total_timesteps=200_000, n_envs=4, enable_wandb=False)

    Full config::

        from training import train, TrainingConfig
        config = TrainingConfig(
            total_timesteps=1_000_000,
            n_envs=8,
            curriculum_level=3,
            enable_self_play=True,
        )
        model = train(config)
    """
    if config is None:
        config = TrainingConfig()

    # Apply per-call overrides.
    if override_kwargs:
        valid_fields = {f.name for f in dataclasses.fields(config)}
        unknown = set(override_kwargs) - valid_fields
        if unknown:
            raise ValueError(
                f"Unknown TrainingConfig fields: {', '.join(sorted(unknown))}. "
                f"Valid fields: {', '.join(sorted(valid_fields))}."
            )
        config = dataclasses.replace(config, **override_kwargs)

    # Validate critical parameters.
    if config.total_timesteps < 1:
        raise ValueError(
            f"total_timesteps must be >= 1, got {config.total_timesteps}."
        )
    if config.n_envs < 1:
        raise ValueError(f"n_envs must be >= 1, got {config.n_envs}.")
    if config.checkpoint_freq <= 0:
        raise ValueError(
            f"checkpoint_freq must be > 0, got {config.checkpoint_freq}."
        )
    if config.eval_freq <= 0:
        raise ValueError(f"eval_freq must be > 0, got {config.eval_freq}.")

    # Resolve paths relative to the current working directory.
    checkpoint_dir = Path(config.checkpoint_dir)
    log_dir = Path(config.log_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    log_dir.mkdir(parents=True, exist_ok=True)

    # Checkpoint manifest.
    manifest_path_resolved = Path(config.manifest_path)
    manifest: Optional[CheckpointManifest] = (
        CheckpointManifest(manifest_path_resolved) if config.write_manifest else None
    )

    # Deterministic config hash for traceability.
    config_dict = dataclasses.asdict(config)
    config_hash = hashlib.sha256(
        json.dumps(config_dict, sort_keys=True, separators=(",", ":")).encode("utf-8")
    ).hexdigest()

    periodic_prefix = checkpoint_name_prefix(
        seed=config.seed,
        curriculum_level=config.curriculum_level,
        enable_v2=config.enable_naming_v2,
    )

    # W&B initialisation (optional; failures are non-fatal).
    run: Any = None
    run_id: Optional[str] = None
    if config.enable_wandb:
        try:
            run = wandb.init(
                project=config.wandb_project,
                entity=config.wandb_entity,
                config=config_dict,
                tags=list(config.wandb_tags),
                sync_tensorboard=False,
                reinit=True,
            )
            run_id = (
                run.id
                if run is not None and hasattr(run, "id") and run.id
                else None
            )
            log.info("W&B run: %s", getattr(run, "url", "offline"))
        except Exception as exc:  # noqa: BLE001
            log.warning(
                "W&B initialisation failed (%s) — continuing without W&B.", exc
            )

    # Reward weights.
    reward_weights = RewardWeights(
        delta_enemy_strength=config.reward_delta_enemy_strength,
        delta_own_strength=config.reward_delta_own_strength,
        survival_bonus=config.reward_survival_bonus,
        win_bonus=config.reward_win_bonus,
        loss_penalty=config.reward_loss_penalty,
        time_penalty=config.reward_time_penalty,
    )

    env_kwargs: dict = dict(
        map_width=config.map_width,
        map_height=config.map_height,
        max_steps=config.max_steps,
        randomize_terrain=config.randomize_terrain,
        hill_speed_factor=config.hill_speed_factor,
        curriculum_level=config.curriculum_level,
        reward_weights=reward_weights,
    )

    vec_env = make_vec_env(
        BattalionEnv,
        n_envs=config.n_envs,
        seed=config.seed,
        env_kwargs=env_kwargs,
    )
    eval_env = make_vec_env(
        BattalionEnv,
        n_envs=1,
        seed=config.seed + 1000,
        env_kwargs=env_kwargs,
    )

    try:
        # Built-in callbacks.
        _checkpoint_cb = ManifestCheckpointCallback(
            save_freq=max(1, config.checkpoint_freq // config.n_envs),
            save_path=str(checkpoint_dir),
            name_prefix=periodic_prefix,
            manifest=manifest,
            seed=config.seed,
            curriculum_level=config.curriculum_level,
            run_id=run_id,
            config_hash=config_hash,
            verbose=config.verbose,
        )
        _eval_cb = ManifestEvalCallback(
            eval_env,
            best_model_save_path=str(checkpoint_dir / "best"),
            log_path=str(log_dir),
            eval_freq=max(1, config.eval_freq // config.n_envs),
            n_eval_episodes=config.n_eval_episodes,
            deterministic=config.eval_deterministic,
            manifest=manifest,
            seed=config.seed,
            curriculum_level=config.curriculum_level,
            run_id=run_id,
            config_hash=config_hash,
            enable_naming_v2=config.enable_naming_v2,
            verbose=config.verbose,
        )
        all_callbacks: list = [_checkpoint_cb, _eval_cb]

        if config.enable_wandb:
            all_callbacks.append(WandbCallback(log_freq=config.wandb_log_freq))
            all_callbacks.append(RewardBreakdownCallback(log_freq=config.wandb_log_freq))

        # Self-play callbacks (optional).
        if config.enable_self_play:
            _pool = OpponentPool(
                pool_dir=Path(config.self_play_pool_dir),
                max_size=config.self_play_pool_max_size,
            )
            _sp_cb = SelfPlayCallback(
                pool=_pool,
                snapshot_freq=config.self_play_snapshot_freq,
                vec_env=vec_env,
                verbose=config.verbose,
                manifest=manifest,
                seed=config.seed,
                curriculum_level=config.curriculum_level,
                run_id=run_id,
                config_hash=config_hash,
            )
            _wr_cb = WinRateVsPoolCallback(
                pool=_pool,
                eval_freq=config.self_play_eval_freq,
                n_eval_episodes=config.self_play_n_eval_episodes,
                deterministic=True,
                use_latest=config.self_play_use_latest_for_eval,
                verbose=config.verbose,
            )
            all_callbacks.extend([_sp_cb, _wr_cb])
            log.info(
                "Self-play enabled: pool_dir=%s, max_size=%d",
                config.self_play_pool_dir,
                config.self_play_pool_max_size,
            )

        # Elo callbacks (optional).
        if config.elo_opponents:
            _elo_registry = EloRegistry(path=Path(config.elo_registry_path))
            _elo_run_id = run_id or f"run_seed{config.seed}"
            _elo_cb = EloEvalCallback(
                opponents=list(config.elo_opponents),
                n_eval_episodes=config.elo_n_eval_episodes,
                registry=_elo_registry,
                agent_name=_elo_run_id,
                eval_freq=config.elo_eval_freq,
                env_kwargs=dict(env_kwargs),
                seed=config.seed,
                verbose=config.verbose,
            )
            all_callbacks.append(_elo_cb)

        # Merge extra caller-supplied callbacks.
        all_callbacks.extend(extra_callbacks or [])

        # Resolve resume checkpoint.
        resume_path: Optional[Path] = None
        if resume is not None:
            resume_path = Path(resume)
            if not resume_path.exists():
                zip_path = Path(str(resume_path) + ".zip")
                if zip_path.exists():
                    resume_path = zip_path
                else:
                    raise FileNotFoundError(
                        f"Resume checkpoint not found: '{resume}'. "
                        "Provide an existing .zip path or omit the extension."
                    )

        # Build or reload PPO model.
        if resume_path is not None:
            log.info("Resuming from checkpoint: %s", resume_path)
            model = PPO.load(
                str(resume_path),
                env=vec_env,
                device=config.device,
                custom_objects={
                    "learning_rate": config.learning_rate,
                    "clip_range": config.clip_range,
                },
            )
        else:
            model = PPO(
                BattalionMlpPolicy,
                vec_env,
                learning_rate=config.learning_rate,
                n_steps=config.n_steps,
                batch_size=config.batch_size,
                n_epochs=config.n_epochs,
                gamma=config.gamma,
                gae_lambda=config.gae_lambda,
                clip_range=config.clip_range,
                ent_coef=config.ent_coef,
                vf_coef=config.vf_coef,
                max_grad_norm=config.max_grad_norm,
                seed=config.seed,
                device=config.device,
                verbose=config.verbose,
            )
        log.info("PPO model ready. Training for %d timesteps.", config.total_timesteps)

        # Training loop.
        model.learn(
            total_timesteps=config.total_timesteps,
            callback=CallbackList(all_callbacks),
            progress_bar=False,
            reset_num_timesteps=resume_path is None,
        )

        # Save final checkpoint.
        final_stem = checkpoint_final_stem(
            seed=config.seed,
            curriculum_level=config.curriculum_level,
            enable_v2=config.enable_naming_v2,
        )
        final_path = checkpoint_dir / final_stem
        model.save(str(final_path))
        log.info("Saved final model to %s.zip", final_path)

        legacy_alias = checkpoint_dir / "ppo_battalion_final"
        if config.keep_legacy_aliases and final_path != legacy_alias:
            model.save(str(legacy_alias))

        # Register artifacts in the manifest.
        if manifest is not None:
            for periodic_zip in checkpoint_dir.glob(f"{periodic_prefix}_*_steps.zip"):
                manifest.register(
                    periodic_zip,
                    artifact_type="periodic",
                    seed=config.seed,
                    curriculum_level=config.curriculum_level,
                    run_id=run_id,
                    config_hash=config_hash,
                    step=parse_step_from_checkpoint_name(periodic_zip),
                )
            final_zip = final_path.with_suffix(".zip")
            if final_zip.exists():
                manifest.register(
                    final_zip,
                    artifact_type="final",
                    seed=config.seed,
                    curriculum_level=config.curriculum_level,
                    run_id=run_id,
                    config_hash=config_hash,
                    step=int(getattr(model, "num_timesteps", 0) or 0),
                )
            # Prune old checkpoints if requested.
            if config.prune_on_run_end and config.keep_periodic > 0:
                pruned = manifest.prune_periodic(
                    checkpoint_dir, periodic_prefix, keep_last=config.keep_periodic
                )
                if pruned:
                    log.info("Pruned %d old periodic checkpoint(s).", len(pruned))
            if config.enable_self_play and config.keep_self_play_snapshots > 0:
                pruned_sp = manifest.prune_self_play_snapshots(
                    Path(config.self_play_pool_dir),
                    keep_last=config.keep_self_play_snapshots,
                )
                if pruned_sp:
                    log.info("Pruned %d old self-play snapshot(s).", len(pruned_sp))

        # Upload final artifact to W&B and close the run.
        if run is not None:
            artifact = wandb.Artifact(name="ppo_battalion_final", type="model")
            zip_str = str(final_path) + ".zip"
            if Path(zip_str).exists():
                artifact.add_file(zip_str)
                run.log_artifact(artifact)
            run.finish()
            run = None  # prevent double-finish in finally

    finally:
        vec_env.close()
        eval_env.close()
        # Ensure W&B run is finished even when an exception interrupts training.
        if run is not None:
            try:
                run.finish()
            except Exception as exc:  # noqa: BLE001
                log.warning("W&B run.finish() raised an error during cleanup: %s", exc)

    return model
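Note that checkpoint_freq and eval_freq are expressed in total environment steps, but SB3 invokes callbacks once per vectorised step, which advances n_envs environment steps at a time. That is why the source above divides by n_envs when wiring the callbacks (save_freq=max(1, config.checkpoint_freq // config.n_envs)). The conversion in isolation:

```python
def per_callback_save_freq(checkpoint_freq: int, n_envs: int) -> int:
    """Convert a frequency in environment steps to SB3 callback steps.

    Each callback step advances n_envs environment steps, so integer
    division keeps the effective interval close to checkpoint_freq env
    steps; max(1, ...) guards against a zero frequency when
    checkpoint_freq < n_envs.
    """
    return max(1, checkpoint_freq // n_envs)


print(per_callback_save_freq(100_000, 8))  # 12500
print(per_callback_save_freq(4, 8))        # 1 (clamped)
```

With the defaults (checkpoint_freq=100_000, n_envs=8), a checkpoint is therefore written every 12,500 callback steps, i.e. roughly every 100,000 environment steps.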

Callbacks

training.train.WandbCallback

Bases: BaseCallback

Logs SB3 training metrics to an active W&B run.

Emits episode-level rollout statistics (mean reward and episode length) every log_freq environment steps, and policy-update losses (if available from the SB3 logger) at the end of each rollout.

Parameters:

- log_freq (int, default: 1000): How often (in environment steps) to log rollout statistics.
- verbose (int, default: 0): Verbosity level (0 = silent, 1 = info).
Source code in training/train.py
class WandbCallback(BaseCallback):
    """Logs SB3 training metrics to an active W&B run.

    Emits episode-level rollout statistics (mean reward and episode length)
    every ``log_freq`` environment steps, and policy-update losses (if
    available from the SB3 logger) at the end of each rollout.

    Parameters
    ----------
    log_freq:
        How often (in environment steps) to log rollout statistics.
    verbose:
        Verbosity level (0 = silent, 1 = info).
    """

    def __init__(self, log_freq: int = 1000, verbose: int = 0) -> None:
        super().__init__(verbose)
        self.log_freq = log_freq

    def _on_step(self) -> bool:
        if self.num_timesteps % self.log_freq == 0 and len(self.model.ep_info_buffer) > 0:
            ep_infos = list(self.model.ep_info_buffer)
            mean_reward = float(np.mean([ep["r"] for ep in ep_infos]))
            mean_length = float(np.mean([ep["l"] for ep in ep_infos]))
            wandb.log(
                {
                    "rollout/ep_rew_mean": mean_reward,
                    "rollout/ep_len_mean": mean_length,
                    "time/total_timesteps": self.num_timesteps,
                },
                step=self.num_timesteps,
            )
        return True

    def _on_rollout_end(self) -> None:
        """Log policy-update losses after each PPO update."""
        logger_kvs: dict = self.model.logger.name_to_value  # type: ignore[attr-defined]
        if logger_kvs:
            wandb.log(
                {f"train/{k}": v for k, v in logger_kvs.items()},
                step=self.num_timesteps,
            )

training.train.RewardBreakdownCallback

Bases: BaseCallback

Logs per-component reward breakdown to W&B at episode boundaries.

Accumulates reward components from info dicts (populated by envs.battalion_env.BattalionEnv) across all parallel environments every step and rolls them into per-episode totals when an episode ends. The episode means are logged to W&B every log_freq timesteps. Any remaining episodes at the end of training are flushed in _on_training_end().

Parameters:

- log_freq (int, default: 1000): How often (in environment steps) to flush accumulated episode means to W&B.
- verbose (int, default: 0): Verbosity level (0 = silent, 1 = info).
Source code in training/train.py
class RewardBreakdownCallback(BaseCallback):
    """Logs per-component reward breakdown to W&B at episode boundaries.

    Accumulates reward components from ``info`` dicts (populated by
    :class:`~envs.battalion_env.BattalionEnv`) across all parallel
    environments every step and rolls them into per-episode totals when
    an episode ends.  The episode means are logged to W&B every
    ``log_freq`` timesteps.  Any remaining episodes at the end of
    training are flushed in ``_on_training_end()``.

    Parameters
    ----------
    log_freq:
        How often (in environment steps) to flush accumulated episode
        means to W&B.
    verbose:
        Verbosity level (0 = silent, 1 = info).
    """

    _COMPONENT_KEYS: tuple[str, ...] = (
        "reward/delta_enemy_strength",
        "reward/delta_own_strength",
        "reward/survival_bonus",
        "reward/win_bonus",
        "reward/loss_penalty",
        "reward/time_penalty",
        "reward/total",
    )

    def __init__(self, log_freq: int = 1000, verbose: int = 0) -> None:
        super().__init__(verbose)
        self.log_freq = log_freq
        # Per-env step accumulators, indexed by env index.  Initialised in
        # _on_training_start() once the number of parallel envs is known.
        self._step_sums: list[dict[str, float]] = []
        # Completed-episode accumulators (sum across episodes and episode count).
        self._ep_sums: dict[str, float] = {k: 0.0 for k in self._COMPONENT_KEYS}
        self._ep_count: int = 0

    def _on_training_start(self) -> None:
        """Initialise per-env step accumulators once the env count is known."""
        n_envs = self.training_env.num_envs  # type: ignore[union-attr]
        self._step_sums = [
            {k: 0.0 for k in self._COMPONENT_KEYS} for _ in range(n_envs)
        ]

    def _on_step(self) -> bool:
        infos = self.locals.get("infos", [])
        dones = self.locals.get("dones", np.zeros(len(infos), dtype=bool))

        for env_idx, (info, done) in enumerate(zip(infos, dones)):
            # Accumulate each component value for this step.
            for key in self._COMPONENT_KEYS:
                self._step_sums[env_idx][key] += float(info.get(key, 0.0))

            if done:
                # Episode complete — transfer step sums to episode accumulators.
                for key in self._COMPONENT_KEYS:
                    self._ep_sums[key] += self._step_sums[env_idx][key]
                    self._step_sums[env_idx][key] = 0.0
                self._ep_count += 1

        if self.num_timesteps % self.log_freq == 0 and self._ep_count > 0:
            self._flush()
        return True

    def _on_training_end(self) -> None:
        """Flush any remaining accumulated episodes at the end of training."""
        if self._ep_count > 0:
            self._flush()

    def _flush(self) -> None:
        """Log episode means to W&B and reset accumulators."""
        means = {
            f"reward_breakdown/{k.split('/')[-1]}": v / self._ep_count
            for k, v in self._ep_sums.items()
        }
        means["time/total_timesteps"] = self.num_timesteps
        wandb.log(means, step=self.num_timesteps)
        self._ep_sums = {k: 0.0 for k in self._COMPONENT_KEYS}
        self._ep_count = 0

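The accumulate-then-roll-up pattern in _on_step/_flush above can be illustrated without SB3 or W&B. This sketch uses two of the real component keys and two hypothetical parallel environments:

```python
KEYS = ("reward/win_bonus", "reward/time_penalty")

step_sums = [{k: 0.0 for k in KEYS} for _ in range(2)]  # one dict per env
ep_sums = {k: 0.0 for k in KEYS}
ep_count = 0


def on_step(infos, dones):
    """Accumulate per-step components; roll them into episode totals on done."""
    global ep_count
    for env_idx, (info, done) in enumerate(zip(infos, dones)):
        for k in KEYS:
            step_sums[env_idx][k] += float(info.get(k, 0.0))
        if done:
            # Episode complete: transfer this env's step sums and reset them.
            for k in KEYS:
                ep_sums[k] += step_sums[env_idx][k]
                step_sums[env_idx][k] = 0.0
            ep_count += 1


on_step([{"reward/win_bonus": 0.0}, {"reward/time_penalty": -0.01}], [False, False])
on_step([{"reward/win_bonus": 10.0}, {"reward/time_penalty": -0.01}], [True, False])
print(ep_sums["reward/win_bonus"], ep_count)  # 10.0 1
```

Env 0's totals move into ep_sums as soon as its episode ends, while env 1 keeps accumulating; _flush() then divides ep_sums by ep_count to log per-episode means.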
training.train.EloEvalCallback

Bases: BaseCallback

Evaluate the current policy vs scripted opponents and log Elo to W&B.

Every eval_freq environment steps the callback runs n_eval_episodes episodes against each opponent in opponents using the live model, updates the training.elo.EloRegistry, persists it to disk, and logs per-opponent Elo ratings and win rates to W&B.

Parameters:

- opponents (list[str], required): List of opponent identifiers (e.g. ["scripted_l1", "scripted_l3", "scripted_l5"]). Each must be a valid argument to training.evaluate.run_episodes_with_model.
- n_eval_episodes (int, required): Number of episodes to run per opponent per evaluation.
- registry (EloRegistry, required): training.elo.EloRegistry instance used for ratings.
- agent_name (str, required): Key used to identify this training run in the registry.
- eval_freq (int, required): How often (in environment steps) to trigger evaluation.
- env_kwargs (Optional[dict], default: None): Keyword arguments forwarded to envs.battalion_env.BattalionEnv when creating evaluation environments. This ensures Elo evaluation uses the same map size, terrain settings, and reward weights as the training run.
- seed (Optional[int], default: None): Base random seed for evaluation episodes.
- verbose (int, default: 0): Verbosity level (0 = silent, 1 = info).
Source code in training/train.py
class EloEvalCallback(BaseCallback):
    """Evaluate the current policy vs scripted opponents and log Elo to W&B.

    Every ``eval_freq`` environment steps the callback runs *n_eval_episodes*
    episodes against each opponent in *opponents* using the live model,
    updates the :class:`~training.elo.EloRegistry`, persists it to disk, and
    logs per-opponent Elo ratings and win rates to W&B.

    Parameters
    ----------
    opponents:
        List of opponent identifiers (e.g. ``["scripted_l1", "scripted_l3",
        "scripted_l5"]``).  Each must be a valid argument to
        :func:`~training.evaluate.run_episodes_with_model`.
    n_eval_episodes:
        Number of episodes to run per opponent per evaluation.
    registry:
        :class:`~training.elo.EloRegistry` instance used for ratings.
    agent_name:
        Key used to identify this training run in the registry.
    eval_freq:
        How often (in environment steps) to trigger evaluation.
    env_kwargs:
        Keyword arguments forwarded to :class:`~envs.battalion_env.BattalionEnv`
        when creating evaluation environments.  This ensures Elo evaluation
        uses the same map size, terrain settings, and reward weights as the
        training run.
    seed:
        Base random seed for evaluation episodes.
    verbose:
        Verbosity level (0 = silent, 1 = info).
    """

    def __init__(
        self,
        opponents: list[str],
        n_eval_episodes: int,
        registry: EloRegistry,
        agent_name: str,
        eval_freq: int,
        env_kwargs: Optional[dict] = None,
        seed: Optional[int] = None,
        verbose: int = 0,
    ) -> None:
        super().__init__(verbose)
        self.opponents = opponents
        self.n_eval_episodes = n_eval_episodes
        self.registry = registry
        self.agent_name = agent_name
        self.eval_freq = eval_freq
        self.env_kwargs = env_kwargs or {}
        self.seed = seed
        self._last_eval_step: int = 0

    def _on_step(self) -> bool:
        if (
            self.num_timesteps - self._last_eval_step >= self.eval_freq
            and self.num_timesteps > 0
        ):
            self._run_elo_eval()
            self._last_eval_step = self.num_timesteps
        return True

    def _run_elo_eval(self) -> None:
        """Evaluate vs all opponents and update the Elo registry."""
        log_dict: dict = {"time/total_timesteps": self.num_timesteps}
        for opponent in self.opponents:
            result = run_episodes_with_model(
                self.model,
                opponent=opponent,
                n_episodes=self.n_eval_episodes,
                deterministic=True,
                seed=self.seed,
                env_kwargs=self.env_kwargs,
            )
            outcome = (result.wins + 0.5 * result.draws) / result.n_episodes
            delta = self.registry.update(
                agent=self.agent_name,
                opponent=opponent,
                outcome=outcome,
                n_games=result.n_episodes,
            )
            elo_rating = self.registry.get_rating(self.agent_name)
            log_dict[f"elo/rating_vs_{opponent}"] = elo_rating
            log_dict[f"elo/win_rate_vs_{opponent}"] = result.win_rate
            log_dict[f"elo/delta_vs_{opponent}"] = delta
            if self.verbose >= 1:
                log.info(
                    "EloEval [%d steps] vs %s — win %.1f%% Elo %.1f%+.1f)",
                    self.num_timesteps,
                    opponent,
                    result.win_rate * 100,
                    elo_rating,
                    delta,
                )
        # Persist only when the registry has a backing file.
        if self.registry.can_save:
            self.registry.save()
        wandb.log(log_dict, step=self.num_timesteps)
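Two pieces of the callback above are worth isolating: the throttle in `_on_step` fires an evaluation once at least `eval_freq` steps have elapsed since the last one, and the outcome score fed to the registry weights a win as 1 and a draw as 0.5. A minimal sketch (the function names here are illustrative, not part of the training package API):

```python
# Sketch of the eval-frequency throttle used by EloEvalCallback._on_step.
def due_for_eval(num_timesteps: int, last_eval_step: int, eval_freq: int) -> bool:
    """True once at least eval_freq steps have passed since the last eval."""
    return num_timesteps > 0 and num_timesteps - last_eval_step >= eval_freq

# Outcome score passed to EloRegistry.update: win = 1, draw = 0.5, loss = 0.
def outcome_score(wins: int, draws: int, n_episodes: int) -> float:
    return (wins + 0.5 * draws) / n_episodes
```

Because the throttle compares against the last evaluation step rather than using a modulo, it stays correct even when `num_timesteps` advances by more than one per call (as it does with vectorized environments).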

training.train.ManifestCheckpointCallback

Bases: CheckpointCallback

Checkpoint callback that appends periodic saves to the manifest immediately.

Source code in training/train.py
class ManifestCheckpointCallback(CheckpointCallback):
    """Checkpoint callback that appends periodic saves to the manifest immediately."""

    def __init__(
        self,
        *,
        manifest: Optional[CheckpointManifest],
        seed: int,
        curriculum_level: int,
        run_id: Optional[str],
        config_hash: str,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self._manifest = manifest
        self._seed = int(seed)
        self._curriculum_level = int(curriculum_level)
        self._run_id = run_id
        self._config_hash = str(config_hash)

    def _on_step(self) -> bool:
        result = super()._on_step()
        if self._manifest is None:
            return result
        if self.n_calls % self.save_freq != 0:
            return result

        checkpoint_path = Path(self._checkpoint_path(extension="zip"))
        if checkpoint_path.exists():
            self._manifest.register(
                checkpoint_path,
                artifact_type="periodic",
                seed=self._seed,
                curriculum_level=self._curriculum_level,
                run_id=self._run_id,
                config_hash=self._config_hash,
                step=parse_step_from_checkpoint_name(checkpoint_path),
            )
        return result

training.train.ManifestEvalCallback

Bases: EvalCallback

Eval callback that materializes best-model metadata at creation time.

Source code in training/train.py
class ManifestEvalCallback(EvalCallback):
    """Eval callback that materializes best-model metadata at creation time."""

    def __init__(
        self,
        eval_env,
        *,
        manifest: Optional[CheckpointManifest],
        seed: int,
        curriculum_level: int,
        run_id: Optional[str],
        config_hash: str,
        enable_naming_v2: bool,
        **kwargs,
    ) -> None:
        super().__init__(eval_env, **kwargs)
        self._manifest = manifest
        self._seed = int(seed)
        self._curriculum_level = int(curriculum_level)
        self._run_id = run_id
        self._config_hash = str(config_hash)
        self._enable_naming_v2 = bool(enable_naming_v2)

    def _on_step(self) -> bool:
        best_before = self.best_mean_reward
        result = super()._on_step()
        if self._manifest is None:
            return result
        if self.best_mean_reward <= best_before:
            return result

        best_alias_zip = Path(self.best_model_save_path) / "best_model.zip"
        if not best_alias_zip.exists():
            return result

        self._manifest.register(
            best_alias_zip,
            artifact_type="best_alias",
            seed=self._seed,
            curriculum_level=self._curriculum_level,
            run_id=self._run_id,
            config_hash=self._config_hash,
            step=int(self.num_timesteps),
        )

        best_canonical_zip = Path(self.best_model_save_path) / checkpoint_best_filename(
            seed=self._seed,
            curriculum_level=self._curriculum_level,
            enable_v2=self._enable_naming_v2,
        )
        if best_canonical_zip != best_alias_zip:
            best_canonical_zip.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(best_alias_zip, best_canonical_zip)
        if best_canonical_zip.exists():
            self._manifest.register(
                best_canonical_zip,
                artifact_type="best",
                seed=self._seed,
                curriculum_level=self._curriculum_level,
                run_id=self._run_id,
                config_hash=self._config_hash,
                step=int(self.num_timesteps),
            )
        return result
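The control flow above registers manifest entries only when `EvalCallback` raised `best_mean_reward` during the step. The pattern in isolation, with a stand-in for `CheckpointManifest` (whose API is assumed here, not taken from the source):

```python
# Stand-in for CheckpointManifest — records registrations in a list.
class Manifest:
    def __init__(self):
        self.entries = []

    def register(self, path, artifact_type, step):
        self.entries.append((path, artifact_type, step))

# Register-on-improvement: only a strictly better mean reward triggers a save.
def maybe_register_best(manifest, best_before, best_after, step):
    if best_after > best_before:
        manifest.register("best_model.zip", "best_alias", step)

m = Manifest()
maybe_register_best(m, -5.0, -3.2, 40_000)   # improvement -> registered
maybe_register_best(m, -3.2, -3.2, 80_000)   # no improvement -> skipped
```

Capturing `best_before` before delegating to `super()._on_step()` is the key detail: `EvalCallback` mutates `best_mean_reward` internally, so the comparison must bracket that call.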

Evaluation

training.evaluate

Evaluate a saved PPO checkpoint against a configurable opponent.

Loads a Stable-Baselines3 PPO model from a .zip checkpoint, runs it against a chosen opponent in :class:~envs.battalion_env.BattalionEnv for a configurable number of episodes, and reports the Blue win rate and optional Elo delta to stdout.

Supported opponent identifiers

scripted_l1 … scripted_l5
    Built-in scripted Red opponent at the specified curriculum level.
random
    A Red opponent that samples uniformly random actions every step.
<path>
    Any file-system path to an SB3 .zip checkpoint; that model drives Red.

A win is defined as Red routing or being destroyed without Blue having routed or been destroyed in the same step. A draw occurs when both sides lose simultaneously or the episode reaches the step limit with neither side eliminated.
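The win/draw rule above can be written down directly. The real `_classify_outcome` reads the episode info and environment state; the boolean flags here are hypothetical stand-ins for that state:

```python
# The outcome rule described above, as a standalone sketch.
def classify(blue_lost: bool, red_lost: bool) -> int:
    """+1 = Blue win, 0 = draw, -1 = Blue loss."""
    if red_lost and not blue_lost:
        return 1
    if blue_lost and not red_lost:
        return -1
    # Both sides lost in the same step, or the episode hit the step
    # limit with neither side eliminated.
    return 0
```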

Usage::

python training/evaluate.py --checkpoint checkpoints/run/final \
    --opponent scripted_l3
python training/evaluate.py --checkpoint checkpoints/run/final \
    --opponent scripted_l3 --n-episodes 100 --seed 0 \
    --elo-registry checkpoints/elo_registry.json \
    --agent-name my_run_v1

EvaluationResult

Bases: NamedTuple

Structured result from an evaluation run.

Attributes:

Name Type Description
wins int

Number of episodes Blue won.

draws int

Number of episodes that ended as a draw (both sides lost or timeout).

losses int

Number of episodes Blue lost.

n_episodes int

Total episodes evaluated (wins + draws + losses).

win_rate float

wins / n_episodes.

draw_rate float

draws / n_episodes.

loss_rate float

losses / n_episodes.

Source code in training/evaluate.py
class EvaluationResult(NamedTuple):
    """Structured result from an evaluation run.

    Attributes
    ----------
    wins:
        Number of episodes Blue won.
    draws:
        Number of episodes that ended as a draw (both sides lost or timeout).
    losses:
        Number of episodes Blue lost.
    n_episodes:
        Total episodes evaluated (``wins + draws + losses``).
    win_rate:
        ``wins / n_episodes``.
    draw_rate:
        ``draws / n_episodes``.
    loss_rate:
        ``losses / n_episodes``.
    """

    wins: int
    draws: int
    losses: int
    n_episodes: int
    win_rate: float
    draw_rate: float
    loss_rate: float

evaluate(checkpoint_path, n_episodes=50, deterministic=True, seed=None, opponent='scripted_l5')

Load a checkpoint, run n_episodes, and return the Blue win rate.

This is a thin wrapper around :func:evaluate_detailed kept for backward compatibility.

Parameters:

Name Type Description Default
checkpoint_path str

Path to the .zip checkpoint (extension may be omitted).

required
n_episodes int

Number of evaluation episodes (must be ≥ 1).

50
deterministic bool

Whether the policy acts deterministically.

True
seed Optional[int]

Base random seed; episode i uses seed + i when provided.

None
opponent str

Opponent identifier — see module docstring for valid values. Defaults to "scripted_l5" (full-combat scripted Red) to match the previous behaviour.

'scripted_l5'

Returns:

Type Description
float

Win rate in [0, 1].

Raises:

Type Description
ValueError

If n_episodes is less than 1.

Source code in training/evaluate.py
def evaluate(
    checkpoint_path: str,
    n_episodes: int = 50,
    deterministic: bool = True,
    seed: Optional[int] = None,
    opponent: str = "scripted_l5",
) -> float:
    """Load a checkpoint, run *n_episodes*, and return the Blue win rate.

    This is a thin wrapper around :func:`evaluate_detailed` kept for
    backward compatibility.

    Parameters
    ----------
    checkpoint_path:
        Path to the ``.zip`` checkpoint (extension may be omitted).
    n_episodes:
        Number of evaluation episodes (must be ≥ 1).
    deterministic:
        Whether the policy acts deterministically.
    seed:
        Base random seed; episode *i* uses ``seed + i`` when provided.
    opponent:
        Opponent identifier — see module docstring for valid values.
        Defaults to ``"scripted_l5"`` (full-combat scripted Red) to match
        the previous behaviour.

    Returns
    -------
    float
        Win rate in ``[0, 1]``.

    Raises
    ------
    ValueError
        If *n_episodes* is less than 1.
    """
    return evaluate_detailed(
        checkpoint_path,
        n_episodes=n_episodes,
        deterministic=deterministic,
        seed=seed,
        opponent=opponent,
    ).win_rate

evaluate_detailed(checkpoint_path, n_episodes=50, deterministic=True, seed=None, opponent='scripted_l5', env_kwargs=None)

Load a checkpoint and run n_episodes, returning a full result struct.

Parameters:

Name Type Description Default
checkpoint_path str

Path to the .zip checkpoint (extension may be omitted).

required
n_episodes int

Number of evaluation episodes (must be ≥ 1).

50
deterministic bool

Whether the policy acts deterministically.

True
seed Optional[int]

Base random seed; episode i uses seed + i when provided.

None
opponent str

Opponent identifier — see module docstring for valid values.

'scripted_l5'
env_kwargs Optional[dict]

Extra keyword arguments forwarded to :class:BattalionEnv (e.g. map_width, reward_weights).

None

Returns:

Type Description
EvaluationResult

Raises:

Type Description
ValueError

If n_episodes < 1.

Source code in training/evaluate.py
def evaluate_detailed(
    checkpoint_path: str,
    n_episodes: int = 50,
    deterministic: bool = True,
    seed: Optional[int] = None,
    opponent: str = "scripted_l5",
    env_kwargs: Optional[dict] = None,
) -> EvaluationResult:
    """Load a checkpoint and run *n_episodes*, returning a full result struct.

    Parameters
    ----------
    checkpoint_path:
        Path to the ``.zip`` checkpoint (extension may be omitted).
    n_episodes:
        Number of evaluation episodes (must be ≥ 1).
    deterministic:
        Whether the policy acts deterministically.
    seed:
        Base random seed; episode *i* uses ``seed + i`` when provided.
    opponent:
        Opponent identifier — see module docstring for valid values.
    env_kwargs:
        Extra keyword arguments forwarded to :class:`BattalionEnv`
        (e.g. ``map_width``, ``reward_weights``).

    Returns
    -------
    EvaluationResult

    Raises
    ------
    ValueError
        If *n_episodes* < 1.
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}.")
    # Create a single env and reuse it for both model loading and episode runs.
    env = _make_env(opponent, seed=seed, env_kwargs=env_kwargs)
    try:
        model = PPO.load(checkpoint_path, env=env)
        result = run_episodes_with_model(
            model,
            opponent=opponent,
            n_episodes=n_episodes,
            deterministic=deterministic,
            seed=seed,
            env=env,  # reuse the already-created env; caller closes below
            env_kwargs=env_kwargs,
        )
    finally:
        env.close()
    return result

main(argv=None)

CLI entry point.

Source code in training/evaluate.py
def main(argv: Optional[list[str]] = None) -> None:
    """CLI entry point."""
    parser = argparse.ArgumentParser(
        description=(
            "Evaluate a PPO checkpoint against a chosen opponent "
            "and optionally update an Elo registry."
        ),
    )
    parser.add_argument(
        "--checkpoint",
        required=True,
        help="Path to the SB3 .zip checkpoint (extension optional).",
    )
    # ── Multi-echelon policy selection (E3.6) ────────────────────────────
    policy_group = parser.add_argument_group(
        "multi-echelon policy selection (E3.6)",
        description=(
            "Use a versioned policy from the PolicyRegistry.  "
            "Requires --policy-registry to be set.  "
            "Each flag accepts a version string (e.g. 'v2_final')."
        ),
    )
    policy_group.add_argument(
        "--policy-registry",
        metavar="PATH",
        default=None,
        help=(
            "Path to the PolicyRegistry JSON manifest.  "
            "When provided the --*-policy flags resolve versions from this "
            "registry."
        ),
    )
    policy_group.add_argument(
        "--battalion-policy",
        metavar="VERSION",
        default=None,
        help=(
            "Version of the battalion policy to load from --policy-registry.  "
            "Resolves and prints the registry entry for logging/tracking; "
            "use --checkpoint to specify the SB3 .zip checkpoint to evaluate."
        ),
    )
    policy_group.add_argument(
        "--brigade-policy",
        metavar="VERSION",
        default=None,
        help=(
            "Version of the brigade policy to load from --policy-registry.  "
            "Stored for use by the HRL evaluation pipeline."
        ),
    )
    policy_group.add_argument(
        "--division-policy",
        metavar="VERSION",
        default=None,
        help=(
            "Version of the division policy to load from --policy-registry.  "
            "Stored for use by the HRL evaluation pipeline."
        ),
    )
    parser.add_argument(
        "--opponent",
        default="scripted_l5",
        help=(
            "Opponent to evaluate against.  "
            "One of 'scripted_l1'…'scripted_l5', 'random', "
            "or a path to an SB3 .zip checkpoint.  "
            "(default: scripted_l5)"
        ),
    )
    parser.add_argument(
        "--n-episodes",
        type=int,
        default=50,
        help="Number of evaluation episodes (default: 50, minimum: 1).",
    )
    action_group = parser.add_mutually_exclusive_group()
    action_group.add_argument(
        "--deterministic",
        dest="deterministic",
        action="store_true",
        help="Use deterministic actions (default).",
    )
    action_group.add_argument(
        "--stochastic",
        dest="deterministic",
        action="store_false",
        help="Use stochastic actions instead of deterministic.",
    )
    parser.set_defaults(deterministic=True)
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for the evaluation environment.",
    )
    parser.add_argument(
        "--elo-registry",
        default=None,
        help=(
            "Path to the Elo registry JSON file.  "
            "When provided, the agent's rating is updated and persisted.  "
            "(default: no persistence)"
        ),
    )
    parser.add_argument(
        "--agent-name",
        default=None,
        help=(
            "Name used as the agent key in the Elo registry.  "
            "Defaults to the checkpoint path when not specified."
        ),
    )
    parser.add_argument(
        "--render",
        action="store_true",
        default=False,
        help=(
            "Open a pygame window and display each episode.  "
            "Requires pygame to be installed and a display to be available."
        ),
    )
    parser.add_argument(
        "--record",
        metavar="DIR",
        default=None,
        help=(
            "Save each episode trajectory as a JSON file under DIR "
            "(e.g. 'replays').  Files are named '<checkpoint>_ep<N>.json'."
        ),
    )

    args = parser.parse_args(argv)
    if args.n_episodes < 1:
        parser.error(f"--n-episodes must be >= 1, got {args.n_episodes}.")

    # ------------------------------------------------------------------
    # Resolve policy versions from PolicyRegistry when --*-policy flags given
    # ------------------------------------------------------------------
    has_policy_flag = (
        args.battalion_policy is not None
        or args.brigade_policy is not None
        or args.division_policy is not None
    )
    if has_policy_flag and args.policy_registry is None:
        parser.error(
            "--policy-registry is required when using "
            "--battalion-policy / --brigade-policy / --division-policy."
        )

    if has_policy_flag:
        from training.policy_registry import PolicyRegistry  # noqa: PLC0415
        _pol_reg = PolicyRegistry(path=args.policy_registry)

        if args.battalion_policy is not None:
            try:
                _entry = _pol_reg.get("battalion", args.battalion_policy)
                print(
                    f"Battalion policy: {args.battalion_policy} \u2192 {_entry.path}"
                    + (f" (run_id={_entry.run_id})" if _entry.run_id else "")
                )
            except KeyError as exc:
                parser.error(str(exc))

        if args.brigade_policy is not None:
            try:
                _entry = _pol_reg.get("brigade", args.brigade_policy)
                print(
                    f"Brigade policy:   {args.brigade_policy} \u2192 {_entry.path}"
                    + (f" (run_id={_entry.run_id})" if _entry.run_id else "")
                )
            except KeyError as exc:
                parser.error(str(exc))

        if args.division_policy is not None:
            try:
                _entry = _pol_reg.get("division", args.division_policy)
                print(
                    f"Division policy:  {args.division_policy} \u2192 {_entry.path}"
                    + (f" (run_id={_entry.run_id})" if _entry.run_id else "")
                )
            except KeyError as exc:
                parser.error(str(exc))

    # ------------------------------------------------------------------
    # Rendered / recorded path
    # ------------------------------------------------------------------
    if args.render or args.record:
        env = _make_env(args.opponent, seed=args.seed)
        model = PPO.load(args.checkpoint, env=env)
        wins = draws = losses = 0
        record_dir = Path(args.record) if args.record else None
        try:
            for ep in range(args.n_episodes):
                ep_seed = None if args.seed is None else args.seed + ep
                recorder = None
                if record_dir is not None:
                    from envs.rendering.recorder import EpisodeRecorder  # noqa: PLC0415
                    recorder = EpisodeRecorder()

                if args.render:
                    outcome = _run_rendered_episode(
                        model,
                        env,
                        ep_seed=ep_seed,
                        deterministic=args.deterministic,
                        recorder=recorder,
                    )
                else:
                    # Record-only (no window)
                    obs, _ = env.reset(seed=ep_seed)
                    done = False
                    step_info: dict = {}
                    current_step = 0
                    if recorder is not None:
                        recorder.record_step(current_step, env.blue, env.red)  # type: ignore[arg-type]
                    while not done:
                        action, _ = model.predict(obs, deterministic=args.deterministic)
                        obs, reward, terminated, truncated, step_info = env.step(action)
                        current_step += 1
                        done = terminated or truncated
                        if recorder is not None:
                            recorder.record_step(current_step, env.blue, env.red, float(reward), step_info)  # type: ignore[arg-type]
                    outcome = _classify_outcome(step_info, env)

                if outcome == 1:
                    wins += 1
                elif outcome == -1:
                    losses += 1
                else:
                    draws += 1

                if recorder is not None and record_dir is not None:
                    ckpt_stem = Path(args.checkpoint).stem
                    save_path = record_dir / f"{ckpt_stem}_ep{ep:04d}.json"
                    recorder.save(save_path)
                    print(f"Recorded:  {save_path}")
        finally:
            env.close()

        result = EvaluationResult(
            wins=wins,
            draws=draws,
            losses=losses,
            n_episodes=args.n_episodes,
            win_rate=wins / args.n_episodes,
            draw_rate=draws / args.n_episodes,
            loss_rate=losses / args.n_episodes,
        )
    else:
        result = evaluate_detailed(
            checkpoint_path=args.checkpoint,
            n_episodes=args.n_episodes,
            deterministic=args.deterministic,
            seed=args.seed,
            opponent=args.opponent,
        )

    # Instantiate a registry for Elo computation.  When --elo-registry is
    # given we load from (and later persist to) that file; otherwise we use an
    # in-memory registry (path=None) so no file is created without explicit
    # opt-in.
    agent_name = args.agent_name or args.checkpoint
    registry = EloRegistry(args.elo_registry)  # None → in-memory

    opp_elo = registry.get_rating(args.opponent)
    print(f"Opponent:  {args.opponent} (Elo: {opp_elo:.0f})")
    print(
        f"Win rate:  {result.win_rate:.2%} "
        f"({result.wins}W / {result.draws}D / {result.losses}L "
        f"in {result.n_episodes} episodes)"
    )

    old_rating = registry.get_rating(agent_name)
    # outcome score: win=1, draw=0.5, loss=0
    outcome = (result.wins + 0.5 * result.draws) / result.n_episodes
    delta = registry.update(
        agent=agent_name,
        opponent=args.opponent,
        outcome=outcome,
        n_games=result.n_episodes,
    )
    new_rating = registry.get_rating(agent_name)
    print(
        f"Elo:       {old_rating:.1f}{new_rating:.1f} "
        f"(Δ {delta:+.1f})"
    )

    if args.elo_registry is not None:
        registry.save()
        print(f"Registry:  saved to {args.elo_registry}")

run_episodes_with_model(model, opponent='scripted_l5', n_episodes=50, deterministic=True, seed=None, env=None, env_kwargs=None)

Run evaluation episodes using an already-loaded model object.

This is useful for in-training callbacks that have direct access to a :class:~stable_baselines3.PPO model without needing to save and reload a checkpoint file.

Parameters:

Name Type Description Default
model Any

Any object with a predict(obs, deterministic) method (e.g. an SB3 PPO instance).

required
opponent str

Opponent identifier — see module docstring for valid values.

'scripted_l5'
n_episodes int

Number of episodes to run (must be ≥ 1).

50
deterministic bool

Whether the policy acts deterministically.

True
seed Optional[int]

Base random seed; episode i uses seed + i when provided. Also used to seed :class:_RandomPolicy when opponent="random".

None
env Optional[BattalionEnv]

Pre-built :class:BattalionEnv to reuse. When provided the caller owns the environment and is responsible for closing it. When None, a new environment is created and closed automatically.

None
env_kwargs Optional[dict]

Extra keyword arguments forwarded to :class:BattalionEnv when creating a new environment (ignored when env is provided).

None

Returns:

Type Description
EvaluationResult

Raises:

Type Description
ValueError

If n_episodes < 1.

Source code in training/evaluate.py
def run_episodes_with_model(
    model: Any,
    opponent: str = "scripted_l5",
    n_episodes: int = 50,
    deterministic: bool = True,
    seed: Optional[int] = None,
    env: Optional[BattalionEnv] = None,
    env_kwargs: Optional[dict] = None,
) -> EvaluationResult:
    """Run evaluation episodes using an already-loaded model object.

    This is useful for in-training callbacks that have direct access to a
    :class:`~stable_baselines3.PPO` model without needing to save and reload
    a checkpoint file.

    Parameters
    ----------
    model:
        Any object with a ``predict(obs, deterministic)`` method (e.g. an
        SB3 ``PPO`` instance).
    opponent:
        Opponent identifier — see module docstring for valid values.
    n_episodes:
        Number of episodes to run (must be ≥ 1).
    deterministic:
        Whether the policy acts deterministically.
    seed:
        Base random seed; episode *i* uses ``seed + i`` when provided.
        Also used to seed :class:`_RandomPolicy` when ``opponent="random"``.
    env:
        Pre-built :class:`BattalionEnv` to reuse.  When provided the caller
        owns the environment and is responsible for closing it.  When
        ``None``, a new environment is created and closed automatically.
    env_kwargs:
        Extra keyword arguments forwarded to :class:`BattalionEnv` when
        creating a new environment (ignored when *env* is provided).

    Returns
    -------
    EvaluationResult

    Raises
    ------
    ValueError
        If *n_episodes* < 1.
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}.")

    owns_env = env is None
    active_env: BattalionEnv = (
        env if env is not None
        else _make_env(opponent, seed=seed, env_kwargs=env_kwargs)
    )
    wins = draws = losses = 0

    for ep in range(n_episodes):
        ep_seed = None if seed is None else seed + ep
        obs, _ = active_env.reset(seed=ep_seed)
        done = False
        while not done:
            action, _ = model.predict(obs, deterministic=deterministic)
            obs, _reward, terminated, truncated, info = active_env.step(action)
            done = terminated or truncated

        outcome = _classify_outcome(info, active_env)
        if outcome == 1:
            wins += 1
        elif outcome == -1:
            losses += 1
        else:
            draws += 1

    if owns_env:
        active_env.close()
    return EvaluationResult(
        wins=wins,
        draws=draws,
        losses=losses,
        n_episodes=n_episodes,
        win_rate=wins / n_episodes,
        draw_rate=draws / n_episodes,
        loss_rate=losses / n_episodes,
    )
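As the docstring notes, `run_episodes_with_model` only requires a `predict(obs, deterministic)` method and a Gymnasium-style environment, so both can be stubbed to illustrate the loop shape. The stubs below are purely illustrative, not part of the package:

```python
class StubModel:
    """Any object with an SB3-style predict() works."""
    def predict(self, obs, deterministic=True):
        return 0, None  # (action, state)

class StubEnv:
    """Gymnasium-style env that always terminates after 3 steps with a Blue win."""
    def reset(self, seed=None):
        self.t = 0
        return 0, {}

    def step(self, action):
        self.t += 1
        done = self.t >= 3
        # (obs, reward, terminated, truncated, info)
        return 0, 0.0, done, False, {"blue_won": done}

def run_stub_episodes(model, env, n_episodes):
    """Same loop shape as run_episodes_with_model; returns the win rate."""
    wins = 0
    for _ in range(n_episodes):
        obs, _ = env.reset()
        done = False
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, _r, terminated, truncated, info = env.step(action)
            done = terminated or truncated
        wins += 1 if info.get("blue_won") else 0
    return wins / n_episodes
```

Passing a pre-built `env` is what lets `EloEvalCallback` and `evaluate_detailed` avoid rebuilding the environment for every evaluation; the `owns_env` flag ensures the function only closes environments it created itself.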

training.evaluate.EvaluationResult

Bases: NamedTuple

Structured result from an evaluation run.

Attributes:

Name Type Description
wins int

Number of episodes Blue won.

draws int

Number of episodes that ended as a draw (both sides lost or timeout).

losses int

Number of episodes Blue lost.

n_episodes int

Total episodes evaluated (wins + draws + losses).

win_rate float

wins / n_episodes.

draw_rate float

draws / n_episodes.

loss_rate float

losses / n_episodes.

Source code in training/evaluate.py
class EvaluationResult(NamedTuple):
    """Structured result from an evaluation run.

    Attributes
    ----------
    wins:
        Number of episodes Blue won.
    draws:
        Number of episodes that ended as a draw (both sides lost or timeout).
    losses:
        Number of episodes Blue lost.
    n_episodes:
        Total episodes evaluated (``wins + draws + losses``).
    win_rate:
        ``wins / n_episodes``.
    draw_rate:
        ``draws / n_episodes``.
    loss_rate:
        ``losses / n_episodes``.
    """

    wins: int
    draws: int
    losses: int
    n_episodes: int
    win_rate: float
    draw_rate: float
    loss_rate: float

training.evaluate.evaluate(checkpoint_path, n_episodes=50, deterministic=True, seed=None, opponent='scripted_l5')

Load a checkpoint, run n_episodes, and return the Blue win rate.

This is a thin wrapper around :func:evaluate_detailed kept for backward compatibility.

Parameters:

Name Type Description Default
checkpoint_path str

Path to the .zip checkpoint (extension may be omitted).

required
n_episodes int

Number of evaluation episodes (must be ≥ 1).

50
deterministic bool

Whether the policy acts deterministically.

True
seed Optional[int]

Base random seed; episode i uses seed + i when provided.

None
opponent str

Opponent identifier — see module docstring for valid values. Defaults to "scripted_l5" (full-combat scripted Red) to match the previous behaviour.

'scripted_l5'

Returns:

Type Description
float

Win rate in [0, 1].

Raises:

Type Description
ValueError

If n_episodes is less than 1.

Source code in training/evaluate.py
def evaluate(
    checkpoint_path: str,
    n_episodes: int = 50,
    deterministic: bool = True,
    seed: Optional[int] = None,
    opponent: str = "scripted_l5",
) -> float:
    """Load a checkpoint, run *n_episodes*, and return the Blue win rate.

    This is a thin wrapper around :func:`evaluate_detailed` kept for
    backward compatibility.

    Parameters
    ----------
    checkpoint_path:
        Path to the ``.zip`` checkpoint (extension may be omitted).
    n_episodes:
        Number of evaluation episodes (must be ≥ 1).
    deterministic:
        Whether the policy acts deterministically.
    seed:
        Base random seed; episode *i* uses ``seed + i`` when provided.
    opponent:
        Opponent identifier — see module docstring for valid values.
        Defaults to ``"scripted_l5"`` (full-combat scripted Red) to match
        the previous behaviour.

    Returns
    -------
    float
        Win rate in ``[0, 1]``.

    Raises
    ------
    ValueError
        If *n_episodes* is less than 1.
    """
    return evaluate_detailed(
        checkpoint_path,
        n_episodes=n_episodes,
        deterministic=deterministic,
        seed=seed,
        opponent=opponent,
    ).win_rate

training.evaluate.evaluate_detailed(checkpoint_path, n_episodes=50, deterministic=True, seed=None, opponent='scripted_l5', env_kwargs=None)

Load a checkpoint and run n_episodes, returning a full result struct.

Parameters:

Name Type Description Default
checkpoint_path str

Path to the .zip checkpoint (extension may be omitted).

required
n_episodes int

Number of evaluation episodes (must be ≥ 1).

50
deterministic bool

Whether the policy acts deterministically.

True
seed Optional[int]

Base random seed; episode i uses seed + i when provided.

None
opponent str

Opponent identifier — see module docstring for valid values.

'scripted_l5'
env_kwargs Optional[dict]

Extra keyword arguments forwarded to :class:BattalionEnv (e.g. map_width, reward_weights).

None

Returns:

Type Description
EvaluationResult

Raises:

Type Description
ValueError

If n_episodes < 1.

Source code in training/evaluate.py
def evaluate_detailed(
    checkpoint_path: str,
    n_episodes: int = 50,
    deterministic: bool = True,
    seed: Optional[int] = None,
    opponent: str = "scripted_l5",
    env_kwargs: Optional[dict] = None,
) -> EvaluationResult:
    """Load a checkpoint and run *n_episodes*, returning a full result struct.

    Parameters
    ----------
    checkpoint_path:
        Path to the ``.zip`` checkpoint (extension may be omitted).
    n_episodes:
        Number of evaluation episodes (must be ≥ 1).
    deterministic:
        Whether the policy acts deterministically.
    seed:
        Base random seed; episode *i* uses ``seed + i`` when provided.
    opponent:
        Opponent identifier — see module docstring for valid values.
    env_kwargs:
        Extra keyword arguments forwarded to :class:`BattalionEnv`
        (e.g. ``map_width``, ``reward_weights``).

    Returns
    -------
    EvaluationResult

    Raises
    ------
    ValueError
        If *n_episodes* < 1.
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}.")
    # Create a single env and reuse it for both model loading and episode runs.
    env = _make_env(opponent, seed=seed, env_kwargs=env_kwargs)
    try:
        model = PPO.load(checkpoint_path, env=env)
        result = run_episodes_with_model(
            model,
            opponent=opponent,
            n_episodes=n_episodes,
            deterministic=deterministic,
            seed=seed,
            env=env,  # reuse the already-created env; caller closes below
            env_kwargs=env_kwargs,
        )
    finally:
        env.close()
    return result
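The seed parameter above derives one seed per episode (seed + i), so a fixed base seed makes every episode reproducible while seed=None leaves all episodes unseeded. A small sketch of that scheme (episode_seeds is a hypothetical helper, not a package symbol):

```python
def episode_seeds(seed, n_episodes):
    """Per-episode seeds as described in the docstring: seed + i, or None."""
    return [None if seed is None else seed + i for i in range(n_episodes)]

print(episode_seeds(42, 3))    # [42, 43, 44]
print(episode_seeds(None, 2))  # [None, None]
```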

training.evaluate.run_episodes_with_model(model, opponent='scripted_l5', n_episodes=50, deterministic=True, seed=None, env=None, env_kwargs=None)

Run evaluation episodes using an already-loaded model object.

This is useful for in-training callbacks that have direct access to an SB3 PPO model without needing to save and reload a checkpoint file.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| model | Any | Any object with a predict(obs, deterministic) method (e.g. an SB3 PPO instance). | required |
| opponent | str | Opponent identifier — see module docstring for valid values. | 'scripted_l5' |
| n_episodes | int | Number of episodes to run (must be ≥ 1). | 50 |
| deterministic | bool | Whether the policy acts deterministically. | True |
| seed | Optional[int] | Base random seed; episode i uses seed + i when provided. Also used to seed _RandomPolicy when opponent="random". | None |
| env | Optional[BattalionEnv] | Pre-built BattalionEnv to reuse. When provided, the caller owns the environment and is responsible for closing it. When None, a new environment is created and closed automatically. | None |
| env_kwargs | Optional[dict] | Extra keyword arguments forwarded to BattalionEnv when creating a new environment (ignored when env is provided). | None |

Returns:

| Type | Description |
| ---- | ----------- |
| EvaluationResult | Aggregated win/draw/loss counts and rates. |

Raises:

| Type | Description |
| ---- | ----------- |
| ValueError | If n_episodes < 1. |

Source code in training/evaluate.py
def run_episodes_with_model(
    model: Any,
    opponent: str = "scripted_l5",
    n_episodes: int = 50,
    deterministic: bool = True,
    seed: Optional[int] = None,
    env: Optional[BattalionEnv] = None,
    env_kwargs: Optional[dict] = None,
) -> EvaluationResult:
    """Run evaluation episodes using an already-loaded model object.

    This is useful for in-training callbacks that have direct access to a
    :class:`~stable_baselines3.PPO` model without needing to save and reload
    a checkpoint file.

    Parameters
    ----------
    model:
        Any object with a ``predict(obs, deterministic)`` method (e.g. an
        SB3 ``PPO`` instance).
    opponent:
        Opponent identifier — see module docstring for valid values.
    n_episodes:
        Number of episodes to run (must be ≥ 1).
    deterministic:
        Whether the policy acts deterministically.
    seed:
        Base random seed; episode *i* uses ``seed + i`` when provided.
        Also used to seed :class:`_RandomPolicy` when ``opponent="random"``.
    env:
        Pre-built :class:`BattalionEnv` to reuse.  When provided the caller
        owns the environment and is responsible for closing it.  When
        ``None``, a new environment is created and closed automatically.
    env_kwargs:
        Extra keyword arguments forwarded to :class:`BattalionEnv` when
        creating a new environment (ignored when *env* is provided).

    Returns
    -------
    EvaluationResult

    Raises
    ------
    ValueError
        If *n_episodes* < 1.
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}.")

    owns_env = env is None
    active_env: BattalionEnv = (
        env if env is not None
        else _make_env(opponent, seed=seed, env_kwargs=env_kwargs)
    )
    wins = draws = losses = 0

    for ep in range(n_episodes):
        ep_seed = None if seed is None else seed + ep
        obs, _ = active_env.reset(seed=ep_seed)
        done = False
        while not done:
            action, _ = model.predict(obs, deterministic=deterministic)
            obs, _reward, terminated, truncated, info = active_env.step(action)
            done = terminated or truncated

        outcome = _classify_outcome(info, active_env)
        if outcome == 1:
            wins += 1
        elif outcome == -1:
            losses += 1
        else:
            draws += 1

    if owns_env:
        active_env.close()
    return EvaluationResult(
        wins=wins,
        draws=draws,
        losses=losses,
        n_episodes=n_episodes,
        win_rate=wins / n_episodes,
        draw_rate=draws / n_episodes,
        loss_rate=losses / n_episodes,
    )
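The episode loop above reduces each episode to an outcome of +1 (win), 0 (draw), or -1 (loss) and converts the tally into rates. A self-contained sketch of that aggregation step (Result and tally are illustrative stand-ins, not package symbols):

```python
from dataclasses import dataclass

@dataclass
class Result:
    # Mirrors the EvaluationResult fields used above (illustrative only).
    wins: int
    draws: int
    losses: int
    n_episodes: int
    win_rate: float
    draw_rate: float
    loss_rate: float

def tally(outcomes):
    """Aggregate +1/0/-1 episode outcomes the way the loop above does."""
    wins = sum(o == 1 for o in outcomes)
    losses = sum(o == -1 for o in outcomes)
    draws = len(outcomes) - wins - losses
    n = len(outcomes)
    return Result(wins, draws, losses, n, wins / n, draws / n, losses / n)

r = tally([1, 1, 0, -1])
print(r.win_rate)  # 0.5
```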

Self-play

training.self_play

Self-play training utilities.

Provides:

  • OpponentPool — a fixed-size pool of frozen policy snapshots that can be sampled uniformly as opponents during self-play training.
  • SelfPlayCallback — an SB3 callback that periodically snapshots the current policy into the pool and swaps the Red opponent in the vectorized training environment.
  • WinRateVsPoolCallback — an SB3 callback that evaluates the current policy against a random opponent from the pool and logs the win rate to W&B.
  • evaluate_vs_pool — a standalone helper that runs n evaluation episodes against an opponent sampled from the pool and returns the win rate.

Multi-agent (MAPPO) additions:

  • TeamOpponentPool — a fixed-size pool of frozen models.mappo_policy.MAPPOPolicy snapshots for team self-play.
  • evaluate_team_vs_pool — evaluate a MAPPO policy (Blue) against a frozen team opponent (Red) and return the win rate.
  • nash_exploitability_proxy — estimate exploitability as max(opp_win_rates) − mean(opp_win_rates) across all pool members.

Typical usage::

from training.self_play import OpponentPool, SelfPlayCallback, WinRateVsPoolCallback
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from envs.battalion_env import BattalionEnv

pool = OpponentPool(pool_dir="checkpoints/pool", max_size=10)
env = make_vec_env(BattalionEnv, n_envs=8)

model = PPO("MlpPolicy", env)
sp_cb = SelfPlayCallback(pool=pool, snapshot_freq=50_000, vec_env=env)
wr_cb = WinRateVsPoolCallback(pool=pool, eval_freq=50_000)

model.learn(total_timesteps=1_000_000, callback=[sp_cb, wr_cb])
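The exploitability proxy listed in the module summary is simply the gap between the best and the average per-opponent win rate: it is 0 when every pool member does equally well against the candidate, and grows when a single opponent does much better. A minimal re-implementation of that stated formula (the signature here is illustrative; the packaged function may accept different arguments):

```python
def nash_exploitability_proxy(opp_win_rates):
    """max(opp_win_rates) - mean(opp_win_rates), per the module summary."""
    return max(opp_win_rates) - sum(opp_win_rates) / len(opp_win_rates)

print(nash_exploitability_proxy([0.5, 0.5, 0.5]))    # 0.0 (uniform pool)
print(nash_exploitability_proxy([0.25, 0.5, 0.75]))  # 0.75 - 0.5 = 0.25
```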

OpponentPool

Fixed-size pool of frozen PPO policy snapshots.

Snapshots are stored as Stable-Baselines3 .zip files under pool_dir. The pool keeps at most max_size snapshots; when full, the oldest snapshot is evicted to make room for the newest one.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| pool_dir | str \| Path | Directory where snapshot .zip files are persisted. Created on first add() if it does not already exist. | required |
| max_size | int | Maximum number of snapshots to retain. | 10 |

Attributes:

| Name | Type | Description |
| ---- | ---- | ----------- |
| pool_dir | Path | Resolved path of the snapshot directory. |
| max_size | int | Maximum number of snapshots retained in the pool. |

Source code in training/self_play.py
class OpponentPool:
    """Fixed-size pool of frozen PPO policy snapshots.

    Snapshots are stored as Stable-Baselines3 ``.zip`` files under
    *pool_dir*.  The pool keeps at most *max_size* snapshots; when full,
    the oldest snapshot is evicted to make room for the newest one.

    Parameters
    ----------
    pool_dir:
        Directory where snapshot ``.zip`` files are persisted.  Created
        on first :meth:`add` if it does not already exist.
    max_size:
        Maximum number of snapshots to retain (default 10).

    Attributes
    ----------
    pool_dir : Path
        Resolved path of the snapshot directory.
    max_size : int
        Maximum number of snapshots retained in the pool.
    """

    def __init__(self, pool_dir: str | Path, max_size: int = 10) -> None:
        if max_size < 1:
            raise ValueError(f"max_size must be >= 1, got {max_size}")
        self.pool_dir = Path(pool_dir)
        self.max_size = int(max_size)
        # Ordered list of snapshot file paths (oldest first).
        self._snapshots: List[Path] = []
        # Shared RNG instance used for uniform sampling when no external RNG is given.
        self._rng: np.random.Generator = np.random.default_rng()
        # Restore existing snapshots from disk if the directory exists.
        self._reload_from_disk()

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def add(self, model: PPO, version: int) -> Path:
        """Save *model* as a new snapshot and add it to the pool.

        If the pool already contains *max_size* snapshots the oldest is
        removed from disk before saving the new one.

        Parameters
        ----------
        model:
            The current PPO model to snapshot.
        version:
            Monotonically increasing version number embedded in the
            snapshot file name for traceability.

        Returns
        -------
        Path
            Path of the newly saved snapshot file (including ``.zip``).
        """
        self.pool_dir.mkdir(parents=True, exist_ok=True)
        snapshot_path = self.pool_dir / f"snapshot_v{version:06d}"
        model.save(str(snapshot_path))
        full_path = snapshot_path.with_suffix(".zip")

        # Evict oldest if at capacity.
        while len(self._snapshots) >= self.max_size:
            oldest = self._snapshots.pop(0)
            try:
                oldest.unlink(missing_ok=True)
                log.debug("Evicted snapshot %s", oldest)
            except OSError as exc:
                log.warning("Failed to evict snapshot %s: %s", oldest, exc)

        self._snapshots.append(full_path)
        log.info("Saved snapshot %s (pool size %d/%d)", full_path, len(self._snapshots), self.max_size)
        return full_path

    def sample(self, rng: Optional[np.random.Generator] = None) -> Optional[PPO]:
        """Load and return a uniformly sampled snapshot as a PPO model.

        Parameters
        ----------
        rng:
            Optional NumPy random generator for reproducible sampling.
            When ``None``, the pool's internal shared RNG instance is used
            (seeded once at pool construction time).

        Returns
        -------
        PPO or None
            A loaded PPO model, or ``None`` if the pool is empty.
        """
        if not self._snapshots:
            return None
        _rng = rng if rng is not None else self._rng
        idx = int(_rng.integers(0, len(self._snapshots)))
        path = self._snapshots[idx]
        try:
            model = PPO.load(str(path))
            log.debug("Sampled snapshot %s", path)
            return model
        except Exception as exc:
            log.warning("Failed to load snapshot %s: %s", path, exc)
            return None

    def sample_latest(self) -> Optional[PPO]:
        """Load and return the most recently added snapshot.

        Returns
        -------
        PPO or None
            The latest PPO model, or ``None`` if the pool is empty.
        """
        if not self._snapshots:
            return None
        path = self._snapshots[-1]
        try:
            return PPO.load(str(path))
        except Exception as exc:
            log.warning("Failed to load latest snapshot %s: %s", path, exc)
            return None

    @property
    def size(self) -> int:
        """Current number of snapshots in the pool."""
        return len(self._snapshots)

    @property
    def snapshot_paths(self) -> List[Path]:
        """Ordered list of snapshot paths (oldest first, read-only copy)."""
        return list(self._snapshots)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _reload_from_disk(self) -> None:
        """Populate *_snapshots* from existing files in *pool_dir*.

        Files beyond the pool's *max_size* capacity (the oldest ones) are
        deleted from disk to enforce the pool invariant across restarts.
        """
        if not self.pool_dir.exists():
            return
        existing = sorted(self.pool_dir.glob("snapshot_v*.zip"))
        # Delete excess (oldest) snapshots so on-disk state matches the pool
        # invariant of keeping at most *max_size* files.
        excess = existing[: max(0, len(existing) - self.max_size)]  # guard for len <= max_size
        for path in excess:
            try:
                path.unlink(missing_ok=True)
                log.debug("_reload_from_disk: deleted excess snapshot %s", path)
            except OSError as exc:
                log.warning("_reload_from_disk: failed to delete %s: %s", path, exc)
        self._snapshots = existing[max(0, len(existing) - self.max_size) :]
        if self._snapshots:
            log.info(
                "Restored %d snapshot(s) from %s", len(self._snapshots), self.pool_dir
            )

size property

Current number of snapshots in the pool.

snapshot_paths property

Ordered list of snapshot paths (oldest first, read-only copy).

add(model, version)

Save model as a new snapshot and add it to the pool.

If the pool already contains max_size snapshots the oldest is removed from disk before saving the new one.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| model | PPO | The current PPO model to snapshot. | required |
| version | int | Monotonically increasing version number embedded in the snapshot file name for traceability. | required |

Returns:

| Type | Description |
| ---- | ----------- |
| Path | Path of the newly saved snapshot file (including .zip). |

Source code in training/self_play.py
def add(self, model: PPO, version: int) -> Path:
    """Save *model* as a new snapshot and add it to the pool.

    If the pool already contains *max_size* snapshots the oldest is
    removed from disk before saving the new one.

    Parameters
    ----------
    model:
        The current PPO model to snapshot.
    version:
        Monotonically increasing version number embedded in the
        snapshot file name for traceability.

    Returns
    -------
    Path
        Path of the newly saved snapshot file (including ``.zip``).
    """
    self.pool_dir.mkdir(parents=True, exist_ok=True)
    snapshot_path = self.pool_dir / f"snapshot_v{version:06d}"
    model.save(str(snapshot_path))
    full_path = snapshot_path.with_suffix(".zip")

    # Evict oldest if at capacity.
    while len(self._snapshots) >= self.max_size:
        oldest = self._snapshots.pop(0)
        try:
            oldest.unlink(missing_ok=True)
            log.debug("Evicted snapshot %s", oldest)
        except OSError as exc:
            log.warning("Failed to evict snapshot %s: %s", oldest, exc)

    self._snapshots.append(full_path)
    log.info("Saved snapshot %s (pool size %d/%d)", full_path, len(self._snapshots), self.max_size)
    return full_path

sample(rng=None)

Load and return a uniformly sampled snapshot as a PPO model.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| rng | Optional[Generator] | Optional NumPy random generator for reproducible sampling. When None, the pool's internal shared RNG instance is used (seeded once at pool construction time). | None |

Returns:

| Type | Description |
| ---- | ----------- |
| PPO or None | A loaded PPO model, or None if the pool is empty. |

Source code in training/self_play.py
def sample(self, rng: Optional[np.random.Generator] = None) -> Optional[PPO]:
    """Load and return a uniformly sampled snapshot as a PPO model.

    Parameters
    ----------
    rng:
        Optional NumPy random generator for reproducible sampling.
        When ``None``, the pool's internal shared RNG instance is used
        (seeded once at pool construction time).

    Returns
    -------
    PPO or None
        A loaded PPO model, or ``None`` if the pool is empty.
    """
    if not self._snapshots:
        return None
    _rng = rng if rng is not None else self._rng
    idx = int(_rng.integers(0, len(self._snapshots)))
    path = self._snapshots[idx]
    try:
        model = PPO.load(str(path))
        log.debug("Sampled snapshot %s", path)
        return model
    except Exception as exc:
        log.warning("Failed to load snapshot %s: %s", path, exc)
        return None

sample_latest()

Load and return the most recently added snapshot.

Returns:

| Type | Description |
| ---- | ----------- |
| PPO or None | The latest PPO model, or None if the pool is empty. |

Source code in training/self_play.py
def sample_latest(self) -> Optional[PPO]:
    """Load and return the most recently added snapshot.

    Returns
    -------
    PPO or None
        The latest PPO model, or ``None`` if the pool is empty.
    """
    if not self._snapshots:
        return None
    path = self._snapshots[-1]
    try:
        return PPO.load(str(path))
    except Exception as exc:
        log.warning("Failed to load latest snapshot %s: %s", path, exc)
        return None
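The pool's eviction rule in add (keep at most max_size snapshots, dropping the oldest first) can be mimicked with a plain list. A sketch under those FIFO semantics (evict_to_capacity is illustrative, not part of the package):

```python
def evict_to_capacity(snapshots, new_item, max_size):
    """FIFO eviction as in OpponentPool.add: drop oldest entries until there
    is room, then append the new snapshot. Pure-list sketch, not the real class."""
    snapshots = list(snapshots)
    while len(snapshots) >= max_size:
        snapshots.pop(0)  # oldest first
    snapshots.append(new_item)
    return snapshots

pool = []
for version in range(1, 5):
    pool = evict_to_capacity(pool, f"snapshot_v{version:06d}.zip", max_size=3)
print(pool)
# ['snapshot_v000002.zip', 'snapshot_v000003.zip', 'snapshot_v000004.zip']
```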

SelfPlayCallback

Bases: BaseCallback

Periodically snapshots the current policy and updates the Red opponent.

Every snapshot_freq environment steps, the current model is saved to the OpponentPool. If the pool contains at least one snapshot, a uniformly sampled opponent is loaded and injected into each environment in vec_env via BattalionEnv.set_red_policy.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| pool | OpponentPool | The OpponentPool to save snapshots into. | required |
| snapshot_freq | int | How often (in environment steps) to take a snapshot. | 50000 |
| vec_env | Optional[VecEnv] | The vectorized training environment whose Red opponents should be updated. When None, the callback uses self.training_env (set automatically by SB3 during model.learn()). | None |
| verbose | int | Verbosity level (0 = silent, 1 = info). | 0 |
Source code in training/self_play.py
class SelfPlayCallback(BaseCallback):
    """Periodically snapshots the current policy and updates the Red opponent.

    Every *snapshot_freq* environment steps the current model is saved to
    the :class:`OpponentPool`.  If the pool contains at least one snapshot,
    a uniformly sampled opponent is loaded and injected into each
    environment in *vec_env* via :meth:`~envs.battalion_env.BattalionEnv.set_red_policy`.

    Parameters
    ----------
    pool:
        The :class:`OpponentPool` to save snapshots into.
    snapshot_freq:
        How often (in environment steps) to take a snapshot.
    vec_env:
        The vectorized training environment whose Red opponents should be
        updated.  When ``None``, the callback uses
        ``self.training_env`` (set automatically by SB3 during
        ``model.learn()``).
    verbose:
        Verbosity level (0 = silent, 1 = info).
    """

    def __init__(
        self,
        pool: OpponentPool,
        snapshot_freq: int = 50_000,
        vec_env: Optional[VecEnv] = None,
        verbose: int = 0,
        manifest: Optional[CheckpointManifest] = None,
        seed: int = 0,
        curriculum_level: int = 5,
        run_id: Optional[str] = None,
        config_hash: str = "",
    ) -> None:
        super().__init__(verbose)
        if int(snapshot_freq) < 1:
            raise ValueError(f"snapshot_freq must be >= 1, got {snapshot_freq}")
        self.pool = pool
        self.snapshot_freq = int(snapshot_freq)
        self._vec_env = vec_env
        # Initialize version counter from any snapshots already in the pool so
        # that a training restart doesn't overwrite existing snapshot files.
        self._version: int = _max_version_in_pool(pool)
        # Provenance manifest — optional; when set, every snapshot is indexed.
        self._sp_manifest = manifest
        self._sp_seed = int(seed)
        self._sp_curriculum_level = int(curriculum_level)
        self._sp_run_id = run_id
        self._sp_config_hash = str(config_hash)

    def _on_step(self) -> bool:
        if self.num_timesteps % self.snapshot_freq == 0 and self.num_timesteps > 0:
            self._take_snapshot_and_update()
        return True

    def _take_snapshot_and_update(self) -> None:
        """Save current model to pool and refresh all Red opponents."""
        self._version += 1
        snapshot_path = self.pool.add(self.model, self._version)

        if self._sp_manifest is not None:
            self._sp_manifest.register(
                snapshot_path,
                artifact_type="self_play_snapshot",
                seed=self._sp_seed,
                curriculum_level=self._sp_curriculum_level,
                run_id=self._sp_run_id,
                config_hash=self._sp_config_hash,
                step=int(self.num_timesteps),
            )

        opponent = self.pool.sample()
        if opponent is None:
            return

        env = self._vec_env if self._vec_env is not None else self.training_env
        if env is None:
            log.warning("SelfPlayCallback: no environment available to update Red policy.")
            return

        # Propagate to every sub-environment.
        for env_instance in _iter_envs(env):
            env_instance.set_red_policy(opponent)

        if self.verbose >= 1:
            log.info(
                "SelfPlayCallback: snapshot v%d saved; Red policy updated (pool=%d).",
                self._version,
                self.pool.size,
            )
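The _on_step guard above fires only on positive multiples of snapshot_freq. A quick illustration of which timestep values trigger a snapshot (snapshot_steps is a hypothetical helper, and the timestep values are illustrative):

```python
def snapshot_steps(timesteps, snapshot_freq):
    """Timesteps at which the _on_step guard above would take a snapshot."""
    return [t for t in timesteps if t % snapshot_freq == 0 and t > 0]

print(snapshot_steps(range(0, 200_001, 25_000), 50_000))
# [50000, 100000, 150000, 200000]  -- step 0 never triggers
```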

TeamOpponentPool

Fixed-size pool of frozen models.mappo_policy.MAPPOPolicy snapshots.

Snapshots are stored as PyTorch .pt files under pool_dir. Each file contains the policy state_dict plus the constructor kwargs (obs_dim, action_dim, state_dim, n_agents, share_parameters) needed to reconstruct the policy at load time. The pool keeps at most max_size snapshots; when full, the oldest is evicted to make room for the newest.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| pool_dir | str \| Path | Directory where snapshot .pt files are persisted. Created on first add() if it does not already exist. | required |
| max_size | int | Maximum number of snapshots to retain. | 10 |
Source code in training/self_play.py
class TeamOpponentPool:
    """Fixed-size pool of frozen :class:`~models.mappo_policy.MAPPOPolicy` snapshots.

    Snapshots are stored as PyTorch ``.pt`` files under *pool_dir*.  Each
    file contains the policy ``state_dict`` plus the constructor kwargs
    (``obs_dim``, ``action_dim``, ``state_dim``, ``n_agents``,
    ``share_parameters``) needed to reconstruct the policy at load time.
    The pool keeps at most *max_size* snapshots; when full, the oldest is
    evicted to make room for the newest.

    Parameters
    ----------
    pool_dir:
        Directory where snapshot ``.pt`` files are persisted.  Created on
        first :meth:`add` if it does not already exist.
    max_size:
        Maximum number of snapshots to retain (default 10).
    """

    def __init__(self, pool_dir: str | Path, max_size: int = 10) -> None:
        if max_size < 1:
            raise ValueError(f"max_size must be >= 1, got {max_size}")
        self.pool_dir = Path(pool_dir)
        self.max_size = int(max_size)
        self._snapshots: List[Path] = []
        self._rng: np.random.Generator = np.random.default_rng()
        self._reload_from_disk()

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def add(self, policy: "MAPPOPolicy", version: int) -> Path:
        """Save *policy* as a new snapshot and add it to the pool.

        The snapshot stores both the ``state_dict`` and the constructor
        kwargs so the policy can be fully reconstructed later.

        Parameters
        ----------
        policy:
            The :class:`~models.mappo_policy.MAPPOPolicy` to snapshot.
        version:
            Monotonically increasing version number embedded in the file
            name for traceability.

        Returns
        -------
        Path
            Path of the newly saved snapshot file.
        """
        import torch  # local import to avoid hard dependency at module load

        self.pool_dir.mkdir(parents=True, exist_ok=True)
        snapshot_path = self.pool_dir / f"team_snapshot_v{version:06d}.pt"

        torch.save(
            {
                "state_dict": policy.state_dict(),
                "kwargs": {
                    "obs_dim": policy.obs_dim,
                    "action_dim": policy.action_dim,
                    "state_dim": policy.state_dim,
                    "n_agents": policy.n_agents,
                    "share_parameters": policy.share_parameters,
                    "actor_hidden_sizes": policy.actor_hidden_sizes,
                    "critic_hidden_sizes": policy.critic_hidden_sizes,
                },
            },
            snapshot_path,
        )

        # Evict oldest if at capacity.
        while len(self._snapshots) >= self.max_size:
            oldest = self._snapshots.pop(0)
            try:
                oldest.unlink(missing_ok=True)
                log.debug("TeamOpponentPool: evicted snapshot %s", oldest)
            except OSError as exc:
                log.warning("TeamOpponentPool: failed to evict %s: %s", oldest, exc)

        self._snapshots.append(snapshot_path)
        log.info(
            "TeamOpponentPool: saved snapshot %s (pool=%d/%d)",
            snapshot_path,
            len(self._snapshots),
            self.max_size,
        )
        return snapshot_path

    def sample(self, rng: Optional[np.random.Generator] = None, device: Optional[str] = None) -> Optional["MAPPOPolicy"]:
        """Load and return a uniformly sampled snapshot.

        Parameters
        ----------
        rng:
            Optional NumPy random generator for reproducible sampling.
            When ``None``, the pool's internal RNG is used.
        device:
            Optional PyTorch device string (e.g. ``"cuda:0"``).  When
            provided the loaded policy is moved to *device* before being
            returned.  When ``None`` the policy stays on CPU.

        Returns
        -------
        MAPPOPolicy or None
            A loaded policy in evaluation mode, or ``None`` if the pool is
            empty or loading fails.
        """
        if not self._snapshots:
            return None
        _rng = rng if rng is not None else self._rng
        idx = int(_rng.integers(0, len(self._snapshots)))
        return self._load_snapshot(self._snapshots[idx], device=device)

    def sample_latest(self, device: Optional[str] = None) -> Optional["MAPPOPolicy"]:
        """Load and return the most recently added snapshot.

        Parameters
        ----------
        device:
            Optional PyTorch device string.  When provided the loaded
            policy is moved to *device* before being returned.

        Returns
        -------
        MAPPOPolicy or None
            The latest policy, or ``None`` if the pool is empty.
        """
        if not self._snapshots:
            return None
        return self._load_snapshot(self._snapshots[-1], device=device)

    @property
    def size(self) -> int:
        """Current number of snapshots in the pool."""
        return len(self._snapshots)

    @property
    def snapshot_paths(self) -> List[Path]:
        """Ordered list of snapshot paths (oldest first, read-only copy)."""
        return list(self._snapshots)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _load_snapshot(self, path: Path, device: Optional[str] = None) -> Optional["MAPPOPolicy"]:
        """Load a :class:`~models.mappo_policy.MAPPOPolicy` from *path*.

        Parameters
        ----------
        path:
            Path to the ``.pt`` snapshot file.
        device:
            Optional PyTorch device string.  When provided the loaded
            policy is moved to *device* after loading (which always uses
            CPU as the intermediate map_location for safety).
        """
        import torch  # local import
        from models.mappo_policy import MAPPOPolicy

        try:
            # weights_only=True restricts deserialization to safe tensor/primitive
            # types, mitigating arbitrary-code-execution risk from tampered files.
            data = torch.load(str(path), map_location="cpu", weights_only=True)
            policy = MAPPOPolicy(**data["kwargs"])
            policy.load_state_dict(data["state_dict"])
            if device is not None:
                policy = policy.to(device)
            policy.eval()
            actual_device = next(policy.parameters()).device
            log.debug("TeamOpponentPool: loaded snapshot %s (device=%s)", path, actual_device)
            return policy
        except Exception as exc:
            log.warning("TeamOpponentPool: failed to load %s: %s", path, exc)
            return None

    def _reload_from_disk(self) -> None:
        """Populate *_snapshots* from existing files in *pool_dir*."""
        if not self.pool_dir.exists():
            return
        existing = sorted(self.pool_dir.glob("team_snapshot_v*.pt"))
        excess = existing[: max(0, len(existing) - self.max_size)]
        for path in excess:
            try:
                path.unlink(missing_ok=True)
                log.debug("TeamOpponentPool._reload_from_disk: deleted excess %s", path)
            except OSError as exc:
                log.warning(
                    "TeamOpponentPool._reload_from_disk: failed to delete %s: %s", path, exc
                )
        self._snapshots = existing[max(0, len(existing) - self.max_size) :]
        if self._snapshots:
            log.info(
                "TeamOpponentPool: restored %d snapshot(s) from %s",
                len(self._snapshots),
                self.pool_dir,
            )
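Snapshot restoration relies on sorted() over file names, which works because the zero-padded v%06d version field used by add makes lexicographic order match numeric version order. An illustrative check:

```python
# Without zero padding, "v10" would sort before "v2"; the 06d field fixes this.
names = [f"team_snapshot_v{v:06d}.pt" for v in (2, 10, 1)]
print(sorted(names))
# ['team_snapshot_v000001.pt', 'team_snapshot_v000002.pt', 'team_snapshot_v000010.pt']
```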

size property

Current number of snapshots in the pool.

snapshot_paths property

Ordered list of snapshot paths (oldest first, read-only copy).

add(policy, version)

Save policy as a new snapshot and add it to the pool.

The snapshot stores both the state_dict and the constructor kwargs so the policy can be fully reconstructed later.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| policy | MAPPOPolicy | The MAPPOPolicy to snapshot. | required |
| version | int | Monotonically increasing version number embedded in the file name for traceability. | required |

Returns:

| Type | Description |
| ---- | ----------- |
| Path | Path of the newly saved snapshot file. |

Source code in training/self_play.py
def add(self, policy: "MAPPOPolicy", version: int) -> Path:
    """Save *policy* as a new snapshot and add it to the pool.

    The snapshot stores both the ``state_dict`` and the constructor
    kwargs so the policy can be fully reconstructed later.

    Parameters
    ----------
    policy:
        The :class:`~models.mappo_policy.MAPPOPolicy` to snapshot.
    version:
        Monotonically increasing version number embedded in the file
        name for traceability.

    Returns
    -------
    Path
        Path of the newly saved snapshot file.
    """
    import torch  # local import to avoid hard dependency at module load

    self.pool_dir.mkdir(parents=True, exist_ok=True)
    snapshot_path = self.pool_dir / f"team_snapshot_v{version:06d}.pt"

    torch.save(
        {
            "state_dict": policy.state_dict(),
            "kwargs": {
                "obs_dim": policy.obs_dim,
                "action_dim": policy.action_dim,
                "state_dim": policy.state_dim,
                "n_agents": policy.n_agents,
                "share_parameters": policy.share_parameters,
                "actor_hidden_sizes": policy.actor_hidden_sizes,
                "critic_hidden_sizes": policy.critic_hidden_sizes,
            },
        },
        snapshot_path,
    )

    # Evict oldest if at capacity.
    while len(self._snapshots) >= self.max_size:
        oldest = self._snapshots.pop(0)
        try:
            oldest.unlink(missing_ok=True)
            log.debug("TeamOpponentPool: evicted snapshot %s", oldest)
        except OSError as exc:
            log.warning("TeamOpponentPool: failed to evict %s: %s", oldest, exc)

    self._snapshots.append(snapshot_path)
    log.info(
        "TeamOpponentPool: saved snapshot %s (pool=%d/%d)",
        snapshot_path,
        len(self._snapshots),
        self.max_size,
    )
    return snapshot_path
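A note on the file-name scheme: the zero-padded `:06d` version field makes lexicographic file order match numeric version order, which is what the plain `sorted(...)` over the glob results in `_reload_from_disk` relies on. A standalone sketch (not project code):

```python
# Zero-padded versions sort chronologically under plain string sorting.
padded = [f"team_snapshot_v{v:06d}.pt" for v in (2, 10, 1)]
assert sorted(padded) == [
    "team_snapshot_v000001.pt",
    "team_snapshot_v000002.pt",
    "team_snapshot_v000010.pt",
]

# Without padding, lexicographic order diverges from numeric order:
unpadded = sorted(f"team_snapshot_v{v}.pt" for v in (2, 10, 1))
assert unpadded == ["team_snapshot_v1.pt", "team_snapshot_v10.pt", "team_snapshot_v2.pt"]
```

This also means the scheme silently breaks past version 999999; in practice snapshot counts stay far below that.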

sample(rng=None, device=None)

Load and return a uniformly sampled snapshot.

Parameters:

rng : Optional[numpy.random.Generator] (default None)
    Optional NumPy random generator for reproducible sampling. When None, the pool's internal RNG is used.
device : Optional[str] (default None)
    Optional PyTorch device string (e.g. "cuda:0"). When provided, the loaded policy is moved to device before being returned. When None, the policy stays on the CPU.

Returns:

MAPPOPolicy or None
    A loaded policy in evaluation mode, or None if the pool is empty or loading fails.

Source code in training/self_play.py
def sample(self, rng: Optional[np.random.Generator] = None, device: Optional[str] = None) -> Optional["MAPPOPolicy"]:
    """Load and return a uniformly sampled snapshot.

    Parameters
    ----------
    rng:
        Optional NumPy random generator for reproducible sampling.
        When ``None``, the pool's internal RNG is used.
    device:
        Optional PyTorch device string (e.g. ``"cuda:0"``).  When
        provided the loaded policy is moved to *device* before being
        returned.  When ``None`` the policy stays on CPU.

    Returns
    -------
    MAPPOPolicy or None
        A loaded policy in evaluation mode, or ``None`` if the pool is
        empty or loading fails.
    """
    if not self._snapshots:
        return None
    _rng = rng if rng is not None else self._rng
    idx = int(_rng.integers(0, len(self._snapshots)))
    return self._load_snapshot(self._snapshots[idx], device=device)
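Passing your own seeded `numpy.random.Generator` makes snapshot selection reproducible across runs. A standalone sketch of the index computation used by `sample` (the `pool_size` value is illustrative):

```python
import numpy as np

pool_size = 7
rng_a = np.random.default_rng(123)
rng_b = np.random.default_rng(123)

# Mirrors the index computation in sample(): uniform over [0, pool_size).
draws_a = [int(rng_a.integers(0, pool_size)) for _ in range(5)]
draws_b = [int(rng_b.integers(0, pool_size)) for _ in range(5)]
assert draws_a == draws_b                      # same seed -> same snapshot choices
assert all(0 <= i < pool_size for i in draws_a)
```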

sample_latest(device=None)

Load and return the most recently added snapshot.

Parameters:

device : Optional[str] (default None)
    Optional PyTorch device string. When provided, the loaded policy is moved to device before being returned.

Returns:

MAPPOPolicy or None
    The latest policy, or None if the pool is empty.

Source code in training/self_play.py
def sample_latest(self, device: Optional[str] = None) -> Optional["MAPPOPolicy"]:
    """Load and return the most recently added snapshot.

    Parameters
    ----------
    device:
        Optional PyTorch device string.  When provided the loaded
        policy is moved to *device* before being returned.

    Returns
    -------
    MAPPOPolicy or None
        The latest policy, or ``None`` if the pool is empty.
    """
    if not self._snapshots:
        return None
    return self._load_snapshot(self._snapshots[-1], device=device)

WinRateVsPoolCallback

Bases: BaseCallback

Evaluates the current policy against a pool opponent and logs the win rate.

Every eval_freq environment steps, runs n_eval_episodes episodes in a temporary BattalionEnv where Red is driven by an opponent sampled from the pool. The resulting win rate is logged to W&B (if available) and to the SB3 logger.

Parameters:

pool : OpponentPool (required)
    The OpponentPool to sample opponents from.
eval_freq : int (default 50000)
    How often (in environment steps) to run the evaluation.
n_eval_episodes : int (default 20)
    Number of episodes per evaluation.
deterministic : bool (default True)
    Whether the policy acts deterministically during evaluation.
use_latest : bool (default False)
    When True, always evaluate against the latest snapshot instead of a random one.
verbose : int (default 0)
    Verbosity level (0 = silent, 1 = info).
Source code in training/self_play.py
class WinRateVsPoolCallback(BaseCallback):
    """Evaluates the current policy vs. a pool opponent and logs win rate.

    Every *eval_freq* environment steps, runs *n_eval_episodes* episodes
    in a temporary :class:`~envs.battalion_env.BattalionEnv` where Red is
    driven by an opponent sampled from *pool*.  The resulting win rate is
    logged to W&B (if available) and to the SB3 logger.

    Parameters
    ----------
    pool:
        :class:`OpponentPool` to sample opponents from.
    eval_freq:
        How often (in environment steps) to run the evaluation.
    n_eval_episodes:
        Number of episodes per evaluation (default 20).
    deterministic:
        Whether the policy acts deterministically during evaluation
        (default ``True``).
    use_latest:
        When ``True``, always evaluate against the *latest* snapshot
        instead of a random one (default ``False``).
    verbose:
        Verbosity level (0 = silent, 1 = info).
    """

    def __init__(
        self,
        pool: OpponentPool,
        eval_freq: int = 50_000,
        n_eval_episodes: int = 20,
        deterministic: bool = True,
        use_latest: bool = False,
        verbose: int = 0,
    ) -> None:
        super().__init__(verbose)
        if int(eval_freq) < 1:
            raise ValueError(f"eval_freq must be >= 1, got {eval_freq}")
        if int(n_eval_episodes) < 1:
            raise ValueError(f"n_eval_episodes must be >= 1, got {n_eval_episodes}")
        self.pool = pool
        self.eval_freq = int(eval_freq)
        self.n_eval_episodes = int(n_eval_episodes)
        self.deterministic = deterministic
        self.use_latest = use_latest

    def _on_step(self) -> bool:
        if self.num_timesteps % self.eval_freq == 0 and self.num_timesteps > 0:
            self._evaluate()
        return True

    def _evaluate(self) -> None:
        """Run evaluation and log the win rate."""
        if self.pool.size == 0:
            return

        opponent = (
            self.pool.sample_latest() if self.use_latest else self.pool.sample()
        )
        if opponent is None:
            return

        win_rate = evaluate_vs_pool(
            model=self.model,
            opponent=opponent,
            n_episodes=self.n_eval_episodes,
            deterministic=self.deterministic,
        )

        if self.verbose >= 1:
            log.info(
                "WinRateVsPoolCallback: win_rate_vs_pool=%.3f (n=%d, step=%d)",
                win_rate,
                self.n_eval_episodes,
                self.num_timesteps,
            )

        # Log to SB3 logger (also picked up by TensorBoard if configured).
        self.logger.record("self_play/win_rate_vs_pool", win_rate)

        # Log to W&B if active.
        try:
            import wandb

            if wandb.run is not None:
                wandb.log(
                    {
                        "self_play/win_rate_vs_pool": win_rate,
                        "time/total_timesteps": self.num_timesteps,
                    },
                    step=self.num_timesteps,
                )
        except ImportError:
            pass
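The gate in `_on_step` only fires when `num_timesteps` is an exact multiple of `eval_freq`. Keep in mind that with vectorized environments SB3 typically advances the counter by `n_envs` per step, so an `eval_freq` that is not a multiple of `n_envs` may never be hit exactly. A standalone sketch of the gating logic:

```python
def should_eval(num_timesteps: int, eval_freq: int) -> bool:
    # Mirrors the gate in _on_step(): skip step 0, fire on exact multiples.
    return num_timesteps > 0 and num_timesteps % eval_freq == 0

# With 4 parallel envs the counter advances in steps of 4; an eval_freq of
# 50_000 (a multiple of 4) is hit exactly, while an odd eval_freq may be skipped.
counter = [4 * k for k in range(1, 30_001)]    # 4, 8, ..., 120_000
fired = [t for t in counter if should_eval(t, 50_000)]
assert fired == [50_000, 100_000]
```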

evaluate_team_vs_pool(policy, opponent, n_blue=2, n_red=2, n_episodes=20, deterministic=True, seed=None, env_kwargs=None)

Evaluate a MAPPO policy (Blue) against an opponent (Red) in self-play.

Runs n_episodes episodes of MultiBattalionEnv where Blue is driven by policy and Red is driven by opponent.

For symmetric self-play (n_blue == n_red) the opponent is used directly as the Red policy. When team sizes differ, the opponent's shared actor is applied to each Red agent in turn.

Parameters:

policy : MAPPOPolicy (required)
    The Blue MAPPOPolicy under evaluation.
opponent : MAPPOPolicy (required)
    The frozen Red opponent policy.
n_blue : int (default 2)
n_red : int (default 2)
    Team sizes; must match the training configuration.
n_episodes : int (default 20)
    Number of evaluation episodes.
deterministic : bool (default True)
    Blue acts deterministically when True; Red always acts stochastically to simulate a diverse pool.
seed : Optional[int] (default None)
    Base random seed; episode i uses seed + i when provided.
env_kwargs : Optional[Dict] (default None)
    Extra keyword arguments forwarded to MultiBattalionEnv.

Returns:

float
    Blue win rate in [0, 1].

Raises:

ValueError
    If n_episodes < 1.

Source code in training/self_play.py
def evaluate_team_vs_pool(
    policy: "MAPPOPolicy",
    opponent: "MAPPOPolicy",
    n_blue: int = 2,
    n_red: int = 2,
    n_episodes: int = 20,
    deterministic: bool = True,
    seed: Optional[int] = None,
    env_kwargs: Optional[Dict] = None,
) -> float:
    """Evaluate a MAPPO *policy* (Blue) against *opponent* (Red) in self-play.

    Runs *n_episodes* episodes of
    :class:`~envs.multi_battalion_env.MultiBattalionEnv` where Blue is
    driven by *policy* and Red is driven by *opponent*.

    For symmetric self-play (``n_blue == n_red``) the opponent is used
    directly as a Red policy.  When team sizes differ, the opponent's
    shared actor is applied to each Red agent in turn.

    Parameters
    ----------
    policy:
        The Blue :class:`~models.mappo_policy.MAPPOPolicy` under evaluation.
    opponent:
        The frozen Red opponent policy.
    n_blue, n_red:
        Team sizes (must match the training configuration).
    n_episodes:
        Number of evaluation episodes (default 20).
    deterministic:
        Blue acts deterministically when ``True``; Red always acts
        stochastically to simulate a diverse pool.
    seed:
        Base random seed; episode *i* uses ``seed + i`` when provided.
    env_kwargs:
        Extra keyword arguments forwarded to
        :class:`~envs.multi_battalion_env.MultiBattalionEnv`.

    Returns
    -------
    float
        Blue win rate in ``[0, 1]``.

    Raises
    ------
    ValueError
        If *n_episodes* < 1.
    """
    import torch  # local import
    from envs.multi_battalion_env import MultiBattalionEnv

    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")

    # Derive evaluation device from policy parameters; fall back to CPU if the
    # policy has no parameters (edge case / mocked policies in tests).
    try:
        eval_device = next(policy.parameters()).device
    except StopIteration:
        eval_device = torch.device("cpu")

    # Move opponent to the same device so all tensor operations are consistent.
    opponent = opponent.to(eval_device)
    opponent.eval()

    _env_kwargs: dict = env_kwargs or {}
    env = MultiBattalionEnv(n_blue=n_blue, n_red=n_red, **_env_kwargs)
    act_low = env._act_space.low
    act_high = env._act_space.high
    obs_dim = env._obs_dim

    wins = 0

    for ep in range(n_episodes):
        ep_seed = None if seed is None else seed + ep
        obs, _ = env.reset(seed=ep_seed)
        blue_won = False

        while env.agents:
            action_dict: dict[str, np.ndarray] = {}

            # Blue actions (controlled by *policy*)
            for i in range(n_blue):
                agent_id = f"blue_{i}"
                if agent_id in env.agents:
                    agent_obs = obs.get(agent_id, np.zeros(obs_dim, dtype=np.float32))
                    obs_t = torch.as_tensor(agent_obs, device=eval_device).unsqueeze(0)
                    with torch.no_grad():
                        acts_t, _ = policy.act(obs_t, agent_idx=i, deterministic=deterministic)
                    action_dict[agent_id] = np.clip(
                        acts_t[0].cpu().numpy(), act_low, act_high
                    )

            # Red actions (controlled by *opponent*)
            for i in range(n_red):
                agent_id = f"red_{i}"
                if agent_id in env.agents:
                    agent_obs = obs.get(agent_id, np.zeros(obs_dim, dtype=np.float32))
                    obs_t = torch.as_tensor(agent_obs, device=eval_device).unsqueeze(0)
                    with torch.no_grad():
                        acts_t, _ = opponent.act(
                            obs_t,
                            agent_idx=i % opponent.n_agents,
                            deterministic=False,
                        )
                    action_dict[agent_id] = np.clip(
                        acts_t[0].cpu().numpy(), act_low, act_high
                    )

            obs, _, _, _, _ = env.step(action_dict)

            # Win condition: Red fully eliminated while at least one Blue alive
            red_alive = any(a.startswith("red_") for a in env.agents)
            blue_alive = any(a.startswith("blue_") for a in env.agents)
            if not red_alive and blue_alive and not blue_won:
                blue_won = True

        if blue_won:
            wins += 1

    env.close()
    return wins / n_episodes
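When team sizes differ, the `i % opponent.n_agents` indexing above cycles the Red agents through the opponent's (possibly smaller) set of actors. A standalone sketch of that mapping:

```python
def red_actor_index(i: int, opponent_n_agents: int) -> int:
    # Red agent i is driven by opponent actor i modulo the opponent's team size.
    return i % opponent_n_agents

# A 4-agent Red team evaluated against a snapshot trained with 2 agents:
assert [red_actor_index(i, 2) for i in range(4)] == [0, 1, 0, 1]
# Symmetric self-play (sizes match) leaves indices unchanged:
assert [red_actor_index(i, 4) for i in range(4)] == [0, 1, 2, 3]
```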

evaluate_vs_pool(model, opponent, n_episodes=20, deterministic=True, seed=None)

Evaluate model against opponent in self-play episodes.

Runs n_episodes episodes of BattalionEnv where Blue is driven by model and Red is driven by opponent.

Parameters:

model : PPO (required)
    The policy under evaluation (controls Blue).
opponent : PPO (required)
    The frozen snapshot policy (controls Red).
n_episodes : int (default 20)
    Number of evaluation episodes.
deterministic : bool (default True)
    Whether model acts deterministically; opponent always acts stochastically to simulate a diverse pool.
seed : Optional[int] (default None)
    Base random seed; episode i uses seed + i when provided.

Returns:

float
    Win rate in [0, 1] (Blue wins / total episodes).

Raises:

ValueError
    If n_episodes < 1.

Source code in training/self_play.py
def evaluate_vs_pool(
    model: PPO,
    opponent: PPO,
    n_episodes: int = 20,
    deterministic: bool = True,
    seed: Optional[int] = None,
) -> float:
    """Evaluate *model* against *opponent* in self-play episodes.

    Runs *n_episodes* episodes of :class:`~envs.battalion_env.BattalionEnv`
    where Blue is driven by *model* and Red is driven by *opponent*.

    Parameters
    ----------
    model:
        The policy under evaluation (controls Blue).
    opponent:
        The frozen snapshot policy (controls Red).
    n_episodes:
        Number of evaluation episodes (default 20).
    deterministic:
        Whether *model* acts deterministically (default ``True``).
        *opponent* always acts stochastically to simulate a diverse pool.
    seed:
        Base random seed; episode *i* uses ``seed + i`` when provided.

    Returns
    -------
    float
        Win rate in ``[0, 1]`` (Blue wins / total episodes).

    Raises
    ------
    ValueError
        If *n_episodes* < 1.
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")

    env = BattalionEnv(red_policy=opponent)
    wins = 0

    for ep in range(n_episodes):
        ep_seed = None if seed is None else seed + ep
        obs, _ = env.reset(seed=ep_seed)
        done = False
        while not done:
            action, _ = model.predict(obs, deterministic=deterministic)
            obs, _reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated

        red_lost = (
            info.get("red_routed", False)
            or env.red.strength <= DESTROYED_THRESHOLD  # type: ignore[union-attr]
        )
        blue_lost = (
            info.get("blue_routed", False)
            or env.blue.strength <= DESTROYED_THRESHOLD  # type: ignore[union-attr]
        )
        if red_lost and not blue_lost:
            wins += 1

    env.close()
    return wins / n_episodes
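The tally above counts a Blue win only when Red is lost and Blue is not, so mutual destruction scores zero. A standalone sketch of the outcome table:

```python
def blue_wins(red_lost: bool, blue_lost: bool) -> bool:
    # Mirrors the tally above: mutual destruction counts as no win.
    return red_lost and not blue_lost

outcomes = [(r, b, blue_wins(r, b)) for r in (False, True) for b in (False, True)]
# Only (red_lost=True, blue_lost=False) scores a win.
assert [w for _, _, w in outcomes] == [False, False, True, False]
```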

nash_exploitability_proxy(policy, pool, n_blue=2, n_red=2, n_episodes_per_opponent=10, seed=None, env_kwargs=None)

Estimate Nash exploitability using the current self-play pool.

Evaluates policy (as Blue) against every snapshot in the pool and returns the nemesis gap:

    ExplProxy = max_i (1 - wr_i) - mean_i (1 - wr_i)

where wr_i is the Blue win rate against opponent i.

Interpretation:

* 0.0 means the policy performs equally well (or poorly) against every pool member; it is hard to exploit from within the pool.
* A high value means one pool member significantly outperforms the average against the policy; the policy has an exploitable weakness.

Parameters:

policy : MAPPOPolicy (required)
    The Blue MAPPOPolicy to evaluate.
pool : TeamOpponentPool (required)
    TeamOpponentPool of frozen opponent snapshots.
n_blue : int (default 2)
n_red : int (default 2)
    Team sizes.
n_episodes_per_opponent : int (default 10)
    Episodes per pool member; smaller than the regular evaluation budget to keep the cost tractable.
seed : Optional[int] (default None)
    Base random seed.
env_kwargs : Optional[Dict] (default None)
    Extra kwargs forwarded to MultiBattalionEnv.

Returns:

float
    Exploitability proxy in [0, 1]. Returns 0.0 if the pool is empty or all snapshots fail to load.

Source code in training/self_play.py
def nash_exploitability_proxy(
    policy: "MAPPOPolicy",
    pool: TeamOpponentPool,
    n_blue: int = 2,
    n_red: int = 2,
    n_episodes_per_opponent: int = 10,
    seed: Optional[int] = None,
    env_kwargs: Optional[Dict] = None,
) -> float:
    """Estimate Nash exploitability using the current self-play pool.

    Evaluates *policy* (as Blue) against **every** snapshot in *pool* and
    returns the *nemesis gap*:

    .. math::

        \\text{ExplProxy} = \\max_i (1 - \\text{wr}_i) - \\text{mean}_i (1 - \\text{wr}_i)

    where :math:`\\text{wr}_i` is the Blue win rate against opponent *i*.

    Interpretation:

    * ``0.0`` — policy performs equally well (or poorly) against every pool
      member; hard to exploit from the pool.
    * High value — one pool member significantly outperforms the average
      against the policy; the policy has an exploitable weakness.

    Parameters
    ----------
    policy:
        The Blue :class:`~models.mappo_policy.MAPPOPolicy` to evaluate.
    pool:
        :class:`TeamOpponentPool` of frozen opponent snapshots.
    n_blue, n_red:
        Team sizes.
    n_episodes_per_opponent:
        Episodes per pool member (default 10).  Smaller than the regular
        evaluation budget to keep the cost tractable.
    seed:
        Base random seed.
    env_kwargs:
        Extra kwargs forwarded to
        :class:`~envs.multi_battalion_env.MultiBattalionEnv`.

    Returns
    -------
    float
        Exploitability proxy in ``[0, 1]``.  Returns ``0.0`` if the pool
        is empty or all snapshots fail to load.
    """
    if pool.size == 0:
        return 0.0

    opp_win_rates: list[float] = []
    for path in pool.snapshot_paths:
        opponent = pool._load_snapshot(path)
        if opponent is None:
            continue
        blue_wr = evaluate_team_vs_pool(
            policy=policy,
            opponent=opponent,
            n_blue=n_blue,
            n_red=n_red,
            n_episodes=n_episodes_per_opponent,
            deterministic=True,
            seed=seed,
            env_kwargs=env_kwargs,
        )
        opp_win_rates.append(1.0 - blue_wr)

    if not opp_win_rates:
        return 0.0

    mean_opp = sum(opp_win_rates) / len(opp_win_rates)
    max_opp = max(opp_win_rates)
    return max_opp - mean_opp
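A worked example of the nemesis gap with hypothetical win rates:

```python
# Hypothetical Blue win rates against three pool snapshots.
blue_wr = [0.8, 0.6, 0.4]
opp_wr = [1.0 - wr for wr in blue_wr]          # opponent win rates: 0.2, 0.4, 0.6

mean_opp = sum(opp_wr) / len(opp_wr)           # 0.4
max_opp = max(opp_wr)                          # 0.6 (the "nemesis")
proxy = max_opp - mean_opp                     # 0.2: one snapshot exploits Blue
assert abs(proxy - 0.2) < 1e-9

# A policy that fares identically against every snapshot has a proxy of 0:
uniform = [0.5, 0.5, 0.5]
assert max(uniform) - sum(uniform) / len(uniform) == 0.0
```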

training.self_play.OpponentPool

Fixed-size pool of frozen PPO policy snapshots.

Snapshots are stored as Stable-Baselines3 .zip files under pool_dir. The pool keeps at most max_size snapshots; when full, the oldest snapshot is evicted to make room for the newest one.

Parameters:

pool_dir : str | Path (required)
    Directory where snapshot .zip files are persisted. Created on first add() if it does not already exist.
max_size : int (default 10)
    Maximum number of snapshots to retain.

Attributes:

pool_dir : Path
    Resolved path of the snapshot directory.
max_size : int
    Maximum number of snapshots retained in the pool.

Source code in training/self_play.py
class OpponentPool:
    """Fixed-size pool of frozen PPO policy snapshots.

    Snapshots are stored as Stable-Baselines3 ``.zip`` files under
    *pool_dir*.  The pool keeps at most *max_size* snapshots; when full,
    the oldest snapshot is evicted to make room for the newest one.

    Parameters
    ----------
    pool_dir:
        Directory where snapshot ``.zip`` files are persisted.  Created
        on first :meth:`add` if it does not already exist.
    max_size:
        Maximum number of snapshots to retain (default 10).

    Attributes
    ----------
    pool_dir : Path
        Resolved path of the snapshot directory.
    max_size : int
        Maximum number of snapshots retained in the pool.
    """

    def __init__(self, pool_dir: str | Path, max_size: int = 10) -> None:
        if max_size < 1:
            raise ValueError(f"max_size must be >= 1, got {max_size}")
        self.pool_dir = Path(pool_dir)
        self.max_size = int(max_size)
        # Ordered list of snapshot file paths (oldest first).
        self._snapshots: List[Path] = []
        # Shared RNG instance used for uniform sampling when no external RNG is given.
        self._rng: np.random.Generator = np.random.default_rng()
        # Restore existing snapshots from disk if the directory exists.
        self._reload_from_disk()

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def add(self, model: PPO, version: int) -> Path:
        """Save *model* as a new snapshot and add it to the pool.

        If the pool already contains *max_size* snapshots the oldest is
        removed from disk before saving the new one.

        Parameters
        ----------
        model:
            The current PPO model to snapshot.
        version:
            Monotonically increasing version number embedded in the
            snapshot file name for traceability.

        Returns
        -------
        Path
            Path of the newly saved snapshot file (including ``.zip``).
        """
        self.pool_dir.mkdir(parents=True, exist_ok=True)
        snapshot_path = self.pool_dir / f"snapshot_v{version:06d}"
        model.save(str(snapshot_path))
        full_path = snapshot_path.with_suffix(".zip")

        # Evict oldest if at capacity.
        while len(self._snapshots) >= self.max_size:
            oldest = self._snapshots.pop(0)
            try:
                oldest.unlink(missing_ok=True)
                log.debug("Evicted snapshot %s", oldest)
            except OSError as exc:
                log.warning("Failed to evict snapshot %s: %s", oldest, exc)

        self._snapshots.append(full_path)
        log.info("Saved snapshot %s (pool size %d/%d)", full_path, len(self._snapshots), self.max_size)
        return full_path

    def sample(self, rng: Optional[np.random.Generator] = None) -> Optional[PPO]:
        """Load and return a uniformly sampled snapshot as a PPO model.

        Parameters
        ----------
        rng:
            Optional NumPy random generator for reproducible sampling.
            When ``None``, the pool's internal shared RNG instance is used
            (seeded once at pool construction time).

        Returns
        -------
        PPO or None
            A loaded PPO model, or ``None`` if the pool is empty.
        """
        if not self._snapshots:
            return None
        _rng = rng if rng is not None else self._rng
        idx = int(_rng.integers(0, len(self._snapshots)))
        path = self._snapshots[idx]
        try:
            model = PPO.load(str(path))
            log.debug("Sampled snapshot %s", path)
            return model
        except Exception as exc:
            log.warning("Failed to load snapshot %s: %s", path, exc)
            return None

    def sample_latest(self) -> Optional[PPO]:
        """Load and return the most recently added snapshot.

        Returns
        -------
        PPO or None
            The latest PPO model, or ``None`` if the pool is empty.
        """
        if not self._snapshots:
            return None
        path = self._snapshots[-1]
        try:
            return PPO.load(str(path))
        except Exception as exc:
            log.warning("Failed to load latest snapshot %s: %s", path, exc)
            return None

    @property
    def size(self) -> int:
        """Current number of snapshots in the pool."""
        return len(self._snapshots)

    @property
    def snapshot_paths(self) -> List[Path]:
        """Ordered list of snapshot paths (oldest first, read-only copy)."""
        return list(self._snapshots)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _reload_from_disk(self) -> None:
        """Populate *_snapshots* from existing files in *pool_dir*.

        Files beyond the pool's *max_size* capacity (the oldest ones) are
        deleted from disk to enforce the pool invariant across restarts.
        """
        if not self.pool_dir.exists():
            return
        existing = sorted(self.pool_dir.glob("snapshot_v*.zip"))
        # Delete excess (oldest) snapshots so on-disk state matches the pool
        # invariant of keeping at most *max_size* files.
        excess = existing[: max(0, len(existing) - self.max_size)]  # guard for len <= max_size
        for path in excess:
            try:
                path.unlink(missing_ok=True)
                log.debug("_reload_from_disk: deleted excess snapshot %s", path)
            except OSError as exc:
                log.warning("_reload_from_disk: failed to delete %s: %s", path, exc)
        self._snapshots = existing[max(0, len(existing) - self.max_size) :]
        if self._snapshots:
            log.info(
                "Restored %d snapshot(s) from %s", len(self._snapshots), self.pool_dir
            )
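The restart invariant enforced here (keep only the newest max_size files, delete the rest) can be sketched standalone with throwaway files:

```python
import tempfile
from pathlib import Path

max_size = 3
with tempfile.TemporaryDirectory() as tmp:
    pool_dir = Path(tmp)
    for v in range(1, 6):                      # 5 stale snapshots on disk
        (pool_dir / f"snapshot_v{v:06d}.zip").touch()

    # Same slicing as _reload_from_disk: delete the oldest, keep the newest.
    existing = sorted(pool_dir.glob("snapshot_v*.zip"))
    for path in existing[: max(0, len(existing) - max_size)]:
        path.unlink(missing_ok=True)
    kept = existing[max(0, len(existing) - max_size):]

    assert [p.name for p in kept] == [
        "snapshot_v000003.zip", "snapshot_v000004.zip", "snapshot_v000005.zip",
    ]
    assert len(list(pool_dir.glob("snapshot_v*.zip"))) == max_size
```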

size property

Current number of snapshots in the pool.

snapshot_paths property

Ordered list of snapshot paths (oldest first, read-only copy).

add(model, version)

Save model as a new snapshot and add it to the pool.

If the pool already contains max_size snapshots the oldest is removed from disk before saving the new one.

Parameters:

model : PPO (required)
    The current PPO model to snapshot.
version : int (required)
    Monotonically increasing version number embedded in the snapshot file name for traceability.

Returns:

Path
    Path of the newly saved snapshot file (including .zip).

Source code in training/self_play.py
def add(self, model: PPO, version: int) -> Path:
    """Save *model* as a new snapshot and add it to the pool.

    If the pool already contains *max_size* snapshots the oldest is
    removed from disk before saving the new one.

    Parameters
    ----------
    model:
        The current PPO model to snapshot.
    version:
        Monotonically increasing version number embedded in the
        snapshot file name for traceability.

    Returns
    -------
    Path
        Path of the newly saved snapshot file (including ``.zip``).
    """
    self.pool_dir.mkdir(parents=True, exist_ok=True)
    snapshot_path = self.pool_dir / f"snapshot_v{version:06d}"
    model.save(str(snapshot_path))
    full_path = snapshot_path.with_suffix(".zip")

    # Evict oldest if at capacity.
    while len(self._snapshots) >= self.max_size:
        oldest = self._snapshots.pop(0)
        try:
            oldest.unlink(missing_ok=True)
            log.debug("Evicted snapshot %s", oldest)
        except OSError as exc:
            log.warning("Failed to evict snapshot %s: %s", oldest, exc)

    self._snapshots.append(full_path)
    log.info("Saved snapshot %s (pool size %d/%d)", full_path, len(self._snapshots), self.max_size)
    return full_path

sample(rng=None)

Load and return a uniformly sampled snapshot as a PPO model.

Parameters:

rng : Optional[numpy.random.Generator] (default None)
    Optional NumPy random generator for reproducible sampling. When None, the pool's internal shared RNG instance is used (seeded once at pool construction time).

Returns:

PPO or None
    A loaded PPO model, or None if the pool is empty.

Source code in training/self_play.py
def sample(self, rng: Optional[np.random.Generator] = None) -> Optional[PPO]:
    """Load and return a uniformly sampled snapshot as a PPO model.

    Parameters
    ----------
    rng:
        Optional NumPy random generator for reproducible sampling.
        When ``None``, the pool's internal shared RNG instance is used
        (seeded once at pool construction time).

    Returns
    -------
    PPO or None
        A loaded PPO model, or ``None`` if the pool is empty.
    """
    if not self._snapshots:
        return None
    _rng = rng if rng is not None else self._rng
    idx = int(_rng.integers(0, len(self._snapshots)))
    path = self._snapshots[idx]
    try:
        model = PPO.load(str(path))
        log.debug("Sampled snapshot %s", path)
        return model
    except Exception as exc:
        log.warning("Failed to load snapshot %s: %s", path, exc)
        return None

sample_latest()

Load and return the most recently added snapshot.

Returns:

PPO or None
    The latest PPO model, or None if the pool is empty.

Source code in training/self_play.py
def sample_latest(self) -> Optional[PPO]:
    """Load and return the most recently added snapshot.

    Returns
    -------
    PPO or None
        The latest PPO model, or ``None`` if the pool is empty.
    """
    if not self._snapshots:
        return None
    path = self._snapshots[-1]
    try:
        return PPO.load(str(path))
    except Exception as exc:
        log.warning("Failed to load latest snapshot %s: %s", path, exc)
        return None

training.self_play.SelfPlayCallback

Bases: BaseCallback

Periodically snapshots the current policy and updates the Red opponent.

Every snapshot_freq environment steps the current model is saved to the OpponentPool. If the pool contains at least one snapshot, a uniformly sampled opponent is loaded and injected into each environment in vec_env via BattalionEnv.set_red_policy().

Parameters:

pool : OpponentPool (required)
    The OpponentPool to save snapshots into.
snapshot_freq : int (default 50000)
    How often (in environment steps) to take a snapshot.
vec_env : Optional[VecEnv] (default None)
    The vectorized training environment whose Red opponents should be updated. When None, the callback uses self.training_env (set automatically by SB3 during model.learn()).
verbose : int (default 0)
    Verbosity level (0 = silent, 1 = info).
manifest : Optional[CheckpointManifest] (default None)
    Optional provenance manifest; when set, every snapshot is indexed in it.
seed : int (default 0)
curriculum_level : int (default 5)
run_id : Optional[str] (default None)
config_hash : str (default "")
    Provenance metadata recorded alongside each snapshot.
Source code in training/self_play.py
class SelfPlayCallback(BaseCallback):
    """Periodically snapshots the current policy and updates the Red opponent.

    Every *snapshot_freq* environment steps the current model is saved to
    the :class:`OpponentPool`.  If the pool contains at least one snapshot,
    a uniformly sampled opponent is loaded and injected into each
    environment in *vec_env* via :meth:`~envs.battalion_env.BattalionEnv.set_red_policy`.

    Parameters
    ----------
    pool:
        The :class:`OpponentPool` to save snapshots into.
    snapshot_freq:
        How often (in environment steps) to take a snapshot.
    vec_env:
        The vectorized training environment whose Red opponents should be
        updated.  When ``None``, the callback uses
        ``self.training_env`` (set automatically by SB3 during
        ``model.learn()``).
    verbose:
        Verbosity level (0 = silent, 1 = info).
    """

    def __init__(
        self,
        pool: OpponentPool,
        snapshot_freq: int = 50_000,
        vec_env: Optional[VecEnv] = None,
        verbose: int = 0,
        manifest: Optional[CheckpointManifest] = None,
        seed: int = 0,
        curriculum_level: int = 5,
        run_id: Optional[str] = None,
        config_hash: str = "",
    ) -> None:
        super().__init__(verbose)
        if int(snapshot_freq) < 1:
            raise ValueError(f"snapshot_freq must be >= 1, got {snapshot_freq}")
        self.pool = pool
        self.snapshot_freq = int(snapshot_freq)
        self._vec_env = vec_env
        # Initialize version counter from any snapshots already in the pool so
        # that a training restart doesn't overwrite existing snapshot files.
        self._version: int = _max_version_in_pool(pool)
        # Provenance manifest — optional; when set, every snapshot is indexed.
        self._sp_manifest = manifest
        self._sp_seed = int(seed)
        self._sp_curriculum_level = int(curriculum_level)
        self._sp_run_id = run_id
        self._sp_config_hash = str(config_hash)

    def _on_step(self) -> bool:
        if self.num_timesteps % self.snapshot_freq == 0 and self.num_timesteps > 0:
            self._take_snapshot_and_update()
        return True

    def _take_snapshot_and_update(self) -> None:
        """Save current model to pool and refresh all Red opponents."""
        self._version += 1
        snapshot_path = self.pool.add(self.model, self._version)

        if self._sp_manifest is not None:
            self._sp_manifest.register(
                snapshot_path,
                artifact_type="self_play_snapshot",
                seed=self._sp_seed,
                curriculum_level=self._sp_curriculum_level,
                run_id=self._sp_run_id,
                config_hash=self._sp_config_hash,
                step=int(self.num_timesteps),
            )

        opponent = self.pool.sample()
        if opponent is None:
            return

        env = self._vec_env if self._vec_env is not None else self.training_env
        if env is None:
            log.warning("SelfPlayCallback: no environment available to update Red policy.")
            return

        # Propagate to every sub-environment.
        for env_instance in _iter_envs(env):
            env_instance.set_red_policy(opponent)

        if self.verbose >= 1:
            log.info(
                "SelfPlayCallback: snapshot v%d saved; Red policy updated (pool=%d).",
                self._version,
                self.pool.size,
            )
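A subtlety worth noting: SB3 advances `num_timesteps` by `n_envs` on every callback step, so the `num_timesteps % snapshot_freq == 0` check only fires when a multiple of `snapshot_freq` is actually landed on. The schedule can be sketched in plain Python (illustrative only, not part of the library):

```python
def snapshot_steps(total_timesteps, n_envs, snapshot_freq):
    """Timesteps at which SelfPlayCallback's trigger would fire.

    Mirrors the `num_timesteps % snapshot_freq == 0` check, with
    num_timesteps advancing by n_envs per callback invocation.
    """
    steps = []
    t = 0
    while t < total_timesteps:
        t += n_envs
        if t % snapshot_freq == 0 and t > 0:
            steps.append(t)
    return steps

# snapshot_freq divisible by n_envs: snapshots fire as intended
print(snapshot_steps(200, n_envs=4, snapshot_freq=40))  # [40, 80, 120, 160, 200]
# snapshot_freq NOT divisible by n_envs: only multiples of lcm(4, 50) fire
print(snapshot_steps(200, n_envs=4, snapshot_freq=50))  # [100, 200]
```

Choosing `snapshot_freq` as a multiple of `n_envs` keeps the snapshot cadence predictable.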

training.self_play.WinRateVsPoolCallback

Bases: BaseCallback

Evaluates the current policy vs. a pool opponent and logs win rate.

Every `eval_freq` environment steps, runs `n_eval_episodes` episodes in a temporary `BattalionEnv` where Red is driven by an opponent sampled from `pool`. The resulting win rate is logged to W&B (if available) and to the SB3 logger.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `pool` | `OpponentPool` | `OpponentPool` to sample opponents from. | *required* |
| `eval_freq` | `int` | How often (in environment steps) to run the evaluation. | `50000` |
| `n_eval_episodes` | `int` | Number of episodes per evaluation. | `20` |
| `deterministic` | `bool` | Whether the policy acts deterministically during evaluation. | `True` |
| `use_latest` | `bool` | When `True`, always evaluate against the *latest* snapshot instead of a random one. | `False` |
| `verbose` | `int` | Verbosity level (0 = silent, 1 = info). | `0` |
Source code in training/self_play.py
class WinRateVsPoolCallback(BaseCallback):
    """Evaluates the current policy vs. a pool opponent and logs win rate.

    Every *eval_freq* environment steps, runs *n_eval_episodes* episodes
    in a temporary :class:`~envs.battalion_env.BattalionEnv` where Red is
    driven by an opponent sampled from *pool*.  The resulting win rate is
    logged to W&B (if available) and to the SB3 logger.

    Parameters
    ----------
    pool:
        :class:`OpponentPool` to sample opponents from.
    eval_freq:
        How often (in environment steps) to run the evaluation.
    n_eval_episodes:
        Number of episodes per evaluation (default 20).
    deterministic:
        Whether the policy acts deterministically during evaluation
        (default ``True``).
    use_latest:
        When ``True``, always evaluate against the *latest* snapshot
        instead of a random one (default ``False``).
    verbose:
        Verbosity level (0 = silent, 1 = info).
    """

    def __init__(
        self,
        pool: OpponentPool,
        eval_freq: int = 50_000,
        n_eval_episodes: int = 20,
        deterministic: bool = True,
        use_latest: bool = False,
        verbose: int = 0,
    ) -> None:
        super().__init__(verbose)
        if int(eval_freq) < 1:
            raise ValueError(f"eval_freq must be >= 1, got {eval_freq}")
        if int(n_eval_episodes) < 1:
            raise ValueError(f"n_eval_episodes must be >= 1, got {n_eval_episodes}")
        self.pool = pool
        self.eval_freq = int(eval_freq)
        self.n_eval_episodes = int(n_eval_episodes)
        self.deterministic = deterministic
        self.use_latest = use_latest

    def _on_step(self) -> bool:
        if self.num_timesteps % self.eval_freq == 0 and self.num_timesteps > 0:
            self._evaluate()
        return True

    def _evaluate(self) -> None:
        """Run evaluation and log the win rate."""
        if self.pool.size == 0:
            return

        opponent = (
            self.pool.sample_latest() if self.use_latest else self.pool.sample()
        )
        if opponent is None:
            return

        win_rate = evaluate_vs_pool(
            model=self.model,
            opponent=opponent,
            n_episodes=self.n_eval_episodes,
            deterministic=self.deterministic,
        )

        if self.verbose >= 1:
            log.info(
                "WinRateVsPoolCallback: win_rate_vs_pool=%.3f (n=%d, step=%d)",
                win_rate,
                self.n_eval_episodes,
                self.num_timesteps,
            )

        # Log to SB3 logger (also picked up by TensorBoard if configured).
        self.logger.record("self_play/win_rate_vs_pool", win_rate)

        # Log to W&B if active.
        try:
            import wandb

            if wandb.run is not None:
                wandb.log(
                    {
                        "self_play/win_rate_vs_pool": win_rate,
                        "time/total_timesteps": self.num_timesteps,
                    },
                    step=self.num_timesteps,
                )
        except ImportError:
            pass

training.self_play.evaluate_vs_pool(model, opponent, n_episodes=20, deterministic=True, seed=None)

Evaluate `model` against `opponent` in self-play episodes.

Runs `n_episodes` episodes of `BattalionEnv` where Blue is driven by `model` and Red is driven by `opponent`.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model` | `PPO` | The policy under evaluation (controls Blue). | *required* |
| `opponent` | `PPO` | The frozen snapshot policy (controls Red). | *required* |
| `n_episodes` | `int` | Number of evaluation episodes. | `20` |
| `deterministic` | `bool` | Whether `model` acts deterministically. `opponent` always acts stochastically to simulate a diverse pool. | `True` |
| `seed` | `Optional[int]` | Base random seed; episode `i` uses `seed + i` when provided. | `None` |

Returns:

| Type | Description |
|------|-------------|
| `float` | Win rate in `[0, 1]` (Blue wins / total episodes). |

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If `n_episodes < 1`. |

Source code in training/self_play.py
def evaluate_vs_pool(
    model: PPO,
    opponent: PPO,
    n_episodes: int = 20,
    deterministic: bool = True,
    seed: Optional[int] = None,
) -> float:
    """Evaluate *model* against *opponent* in self-play episodes.

    Runs *n_episodes* episodes of :class:`~envs.battalion_env.BattalionEnv`
    where Blue is driven by *model* and Red is driven by *opponent*.

    Parameters
    ----------
    model:
        The policy under evaluation (controls Blue).
    opponent:
        The frozen snapshot policy (controls Red).
    n_episodes:
        Number of evaluation episodes (default 20).
    deterministic:
        Whether *model* acts deterministically (default ``True``).
        *opponent* always acts stochastically to simulate a diverse pool.
    seed:
        Base random seed; episode *i* uses ``seed + i`` when provided.

    Returns
    -------
    float
        Win rate in ``[0, 1]`` (Blue wins / total episodes).

    Raises
    ------
    ValueError
        If *n_episodes* < 1.
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")

    env = BattalionEnv(red_policy=opponent)
    wins = 0

    for ep in range(n_episodes):
        ep_seed = None if seed is None else seed + ep
        obs, _ = env.reset(seed=ep_seed)
        done = False
        while not done:
            action, _ = model.predict(obs, deterministic=deterministic)
            obs, _reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated

        red_lost = (
            info.get("red_routed", False)
            or env.red.strength <= DESTROYED_THRESHOLD  # type: ignore[union-attr]
        )
        blue_lost = (
            info.get("blue_routed", False)
            or env.blue.strength <= DESTROYED_THRESHOLD  # type: ignore[union-attr]
        )
        if red_lost and not blue_lost:
            wins += 1

    env.close()
    return wins / n_episodes
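The per-episode win check counts only decisive Blue wins: draws (both sides lost, or neither) are excluded from the numerator. A minimal restatement of that bookkeeping as a standalone function, with a hypothetical `DESTROYED_THRESHOLD` value chosen purely for illustration:

```python
DESTROYED_THRESHOLD = 0.1  # assumed value, for illustration only

def blue_won(info, red_strength, blue_strength):
    """Mirror of the per-episode win check in evaluate_vs_pool."""
    red_lost = info.get("red_routed", False) or red_strength <= DESTROYED_THRESHOLD
    blue_lost = info.get("blue_routed", False) or blue_strength <= DESTROYED_THRESHOLD
    # Only a decisive outcome counts: Red lost AND Blue did not.
    return red_lost and not blue_lost

# Decisive Blue win: Red routed while Blue remains intact
assert blue_won({"red_routed": True}, red_strength=0.4, blue_strength=0.8)
# Mutual destruction is a draw, not a win
assert not blue_won({}, red_strength=0.05, blue_strength=0.05)
```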

training.self_play.TeamOpponentPool

Fixed-size pool of frozen `MAPPOPolicy` snapshots.

Snapshots are stored as PyTorch `.pt` files under `pool_dir`. Each file contains the policy `state_dict` plus the constructor kwargs (`obs_dim`, `action_dim`, `state_dim`, `n_agents`, `share_parameters`) needed to reconstruct the policy at load time. The pool keeps at most `max_size` snapshots; when full, the oldest is evicted to make room for the newest.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `pool_dir` | `str \| Path` | Directory where snapshot `.pt` files are persisted. Created on first `add` if it does not already exist. | *required* |
| `max_size` | `int` | Maximum number of snapshots to retain. | `10` |
Source code in training/self_play.py
class TeamOpponentPool:
    """Fixed-size pool of frozen :class:`~models.mappo_policy.MAPPOPolicy` snapshots.

    Snapshots are stored as PyTorch ``.pt`` files under *pool_dir*.  Each
    file contains the policy ``state_dict`` plus the constructor kwargs
    (``obs_dim``, ``action_dim``, ``state_dim``, ``n_agents``,
    ``share_parameters``) needed to reconstruct the policy at load time.
    The pool keeps at most *max_size* snapshots; when full, the oldest is
    evicted to make room for the newest.

    Parameters
    ----------
    pool_dir:
        Directory where snapshot ``.pt`` files are persisted.  Created on
        first :meth:`add` if it does not already exist.
    max_size:
        Maximum number of snapshots to retain (default 10).
    """

    def __init__(self, pool_dir: str | Path, max_size: int = 10) -> None:
        if max_size < 1:
            raise ValueError(f"max_size must be >= 1, got {max_size}")
        self.pool_dir = Path(pool_dir)
        self.max_size = int(max_size)
        self._snapshots: List[Path] = []
        self._rng: np.random.Generator = np.random.default_rng()
        self._reload_from_disk()

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def add(self, policy: "MAPPOPolicy", version: int) -> Path:
        """Save *policy* as a new snapshot and add it to the pool.

        The snapshot stores both the ``state_dict`` and the constructor
        kwargs so the policy can be fully reconstructed later.

        Parameters
        ----------
        policy:
            The :class:`~models.mappo_policy.MAPPOPolicy` to snapshot.
        version:
            Monotonically increasing version number embedded in the file
            name for traceability.

        Returns
        -------
        Path
            Path of the newly saved snapshot file.
        """
        import torch  # local import to avoid hard dependency at module load

        self.pool_dir.mkdir(parents=True, exist_ok=True)
        snapshot_path = self.pool_dir / f"team_snapshot_v{version:06d}.pt"

        torch.save(
            {
                "state_dict": policy.state_dict(),
                "kwargs": {
                    "obs_dim": policy.obs_dim,
                    "action_dim": policy.action_dim,
                    "state_dim": policy.state_dim,
                    "n_agents": policy.n_agents,
                    "share_parameters": policy.share_parameters,
                    "actor_hidden_sizes": policy.actor_hidden_sizes,
                    "critic_hidden_sizes": policy.critic_hidden_sizes,
                },
            },
            snapshot_path,
        )

        # Evict oldest if at capacity.
        while len(self._snapshots) >= self.max_size:
            oldest = self._snapshots.pop(0)
            try:
                oldest.unlink(missing_ok=True)
                log.debug("TeamOpponentPool: evicted snapshot %s", oldest)
            except OSError as exc:
                log.warning("TeamOpponentPool: failed to evict %s: %s", oldest, exc)

        self._snapshots.append(snapshot_path)
        log.info(
            "TeamOpponentPool: saved snapshot %s (pool=%d/%d)",
            snapshot_path,
            len(self._snapshots),
            self.max_size,
        )
        return snapshot_path

    def sample(self, rng: Optional[np.random.Generator] = None, device: Optional[str] = None) -> Optional["MAPPOPolicy"]:
        """Load and return a uniformly sampled snapshot.

        Parameters
        ----------
        rng:
            Optional NumPy random generator for reproducible sampling.
            When ``None``, the pool's internal RNG is used.
        device:
            Optional PyTorch device string (e.g. ``"cuda:0"``).  When
            provided the loaded policy is moved to *device* before being
            returned.  When ``None`` the policy stays on CPU.

        Returns
        -------
        MAPPOPolicy or None
            A loaded policy in evaluation mode, or ``None`` if the pool is
            empty or loading fails.
        """
        if not self._snapshots:
            return None
        _rng = rng if rng is not None else self._rng
        idx = int(_rng.integers(0, len(self._snapshots)))
        return self._load_snapshot(self._snapshots[idx], device=device)

    def sample_latest(self, device: Optional[str] = None) -> Optional["MAPPOPolicy"]:
        """Load and return the most recently added snapshot.

        Parameters
        ----------
        device:
            Optional PyTorch device string.  When provided the loaded
            policy is moved to *device* before being returned.

        Returns
        -------
        MAPPOPolicy or None
            The latest policy, or ``None`` if the pool is empty.
        """
        if not self._snapshots:
            return None
        return self._load_snapshot(self._snapshots[-1], device=device)

    @property
    def size(self) -> int:
        """Current number of snapshots in the pool."""
        return len(self._snapshots)

    @property
    def snapshot_paths(self) -> List[Path]:
        """Ordered list of snapshot paths (oldest first, read-only copy)."""
        return list(self._snapshots)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _load_snapshot(self, path: Path, device: Optional[str] = None) -> Optional["MAPPOPolicy"]:
        """Load a :class:`~models.mappo_policy.MAPPOPolicy` from *path*.

        Parameters
        ----------
        path:
            Path to the ``.pt`` snapshot file.
        device:
            Optional PyTorch device string.  When provided the loaded
            policy is moved to *device* after loading (which always uses
            CPU as the intermediate map_location for safety).
        """
        import torch  # local import
        from models.mappo_policy import MAPPOPolicy

        try:
            # weights_only=True restricts deserialization to safe tensor/primitive
            # types, mitigating arbitrary-code-execution risk from tampered files.
            data = torch.load(str(path), map_location="cpu", weights_only=True)
            policy = MAPPOPolicy(**data["kwargs"])
            policy.load_state_dict(data["state_dict"])
            if device is not None:
                policy = policy.to(device)
            policy.eval()
            actual_device = next(policy.parameters()).device
            log.debug("TeamOpponentPool: loaded snapshot %s (device=%s)", path, actual_device)
            return policy
        except Exception as exc:
            log.warning("TeamOpponentPool: failed to load %s: %s", path, exc)
            return None

    def _reload_from_disk(self) -> None:
        """Populate *_snapshots* from existing files in *pool_dir*."""
        if not self.pool_dir.exists():
            return
        existing = sorted(self.pool_dir.glob("team_snapshot_v*.pt"))
        excess = existing[: max(0, len(existing) - self.max_size)]
        for path in excess:
            try:
                path.unlink(missing_ok=True)
                log.debug("TeamOpponentPool._reload_from_disk: deleted excess %s", path)
            except OSError as exc:
                log.warning(
                    "TeamOpponentPool._reload_from_disk: failed to delete %s: %s", path, exc
                )
        self._snapshots = existing[max(0, len(existing) - self.max_size) :]
        if self._snapshots:
            log.info(
                "TeamOpponentPool: restored %d snapshot(s) from %s",
                len(self._snapshots),
                self.pool_dir,
            )
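The pool's bookkeeping reduces to a FIFO over zero-padded filenames: when the pool is at capacity, the oldest entry is removed before the newest is appended. A plain-list sketch of that eviction logic (no files or policies involved):

```python
def add_snapshot(snapshots, version, max_size=10):
    """FIFO eviction as in TeamOpponentPool.add, using names instead of files."""
    # Zero-padded version keeps lexicographic order == chronological order,
    # which is what _reload_from_disk relies on via sorted().
    name = f"team_snapshot_v{version:06d}.pt"
    evicted = []
    while len(snapshots) >= max_size:
        evicted.append(snapshots.pop(0))  # drop the oldest first
    snapshots.append(name)
    return evicted

pool = []
for v in range(1, 5):
    add_snapshot(pool, v, max_size=3)
print(pool)
# ['team_snapshot_v000002.pt', 'team_snapshot_v000003.pt', 'team_snapshot_v000004.pt']
```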

size property

Current number of snapshots in the pool.

snapshot_paths property

Ordered list of snapshot paths (oldest first, read-only copy).

add(policy, version)

Save policy as a new snapshot and add it to the pool.

The snapshot stores both the state_dict and the constructor kwargs so the policy can be fully reconstructed later.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `policy` | `MAPPOPolicy` | The `MAPPOPolicy` to snapshot. | *required* |
| `version` | `int` | Monotonically increasing version number embedded in the file name for traceability. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `Path` | Path of the newly saved snapshot file. |

Source code in training/self_play.py
def add(self, policy: "MAPPOPolicy", version: int) -> Path:
    """Save *policy* as a new snapshot and add it to the pool.

    The snapshot stores both the ``state_dict`` and the constructor
    kwargs so the policy can be fully reconstructed later.

    Parameters
    ----------
    policy:
        The :class:`~models.mappo_policy.MAPPOPolicy` to snapshot.
    version:
        Monotonically increasing version number embedded in the file
        name for traceability.

    Returns
    -------
    Path
        Path of the newly saved snapshot file.
    """
    import torch  # local import to avoid hard dependency at module load

    self.pool_dir.mkdir(parents=True, exist_ok=True)
    snapshot_path = self.pool_dir / f"team_snapshot_v{version:06d}.pt"

    torch.save(
        {
            "state_dict": policy.state_dict(),
            "kwargs": {
                "obs_dim": policy.obs_dim,
                "action_dim": policy.action_dim,
                "state_dim": policy.state_dim,
                "n_agents": policy.n_agents,
                "share_parameters": policy.share_parameters,
                "actor_hidden_sizes": policy.actor_hidden_sizes,
                "critic_hidden_sizes": policy.critic_hidden_sizes,
            },
        },
        snapshot_path,
    )

    # Evict oldest if at capacity.
    while len(self._snapshots) >= self.max_size:
        oldest = self._snapshots.pop(0)
        try:
            oldest.unlink(missing_ok=True)
            log.debug("TeamOpponentPool: evicted snapshot %s", oldest)
        except OSError as exc:
            log.warning("TeamOpponentPool: failed to evict %s: %s", oldest, exc)

    self._snapshots.append(snapshot_path)
    log.info(
        "TeamOpponentPool: saved snapshot %s (pool=%d/%d)",
        snapshot_path,
        len(self._snapshots),
        self.max_size,
    )
    return snapshot_path

sample(rng=None, device=None)

Load and return a uniformly sampled snapshot.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `rng` | `Optional[Generator]` | Optional NumPy random generator for reproducible sampling. When `None`, the pool's internal RNG is used. | `None` |
| `device` | `Optional[str]` | Optional PyTorch device string (e.g. `"cuda:0"`). When provided, the loaded policy is moved to `device` before being returned. When `None`, the policy stays on CPU. | `None` |

Returns:

| Type | Description |
|------|-------------|
| `MAPPOPolicy` or `None` | A loaded policy in evaluation mode, or `None` if the pool is empty or loading fails. |

Source code in training/self_play.py
def sample(self, rng: Optional[np.random.Generator] = None, device: Optional[str] = None) -> Optional["MAPPOPolicy"]:
    """Load and return a uniformly sampled snapshot.

    Parameters
    ----------
    rng:
        Optional NumPy random generator for reproducible sampling.
        When ``None``, the pool's internal RNG is used.
    device:
        Optional PyTorch device string (e.g. ``"cuda:0"``).  When
        provided the loaded policy is moved to *device* before being
        returned.  When ``None`` the policy stays on CPU.

    Returns
    -------
    MAPPOPolicy or None
        A loaded policy in evaluation mode, or ``None`` if the pool is
        empty or loading fails.
    """
    if not self._snapshots:
        return None
    _rng = rng if rng is not None else self._rng
    idx = int(_rng.integers(0, len(self._snapshots)))
    return self._load_snapshot(self._snapshots[idx], device=device)
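Because sampling delegates to a NumPy generator, passing an explicitly seeded `rng` makes opponent selection reproducible: two generators with the same seed draw the same index sequence. A small sketch of that property, assuming only `numpy`:

```python
import numpy as np

paths = [f"team_snapshot_v{v:06d}.pt" for v in range(1, 6)]

def pick(rng, n):
    # The same uniform index draw TeamOpponentPool.sample performs internally.
    return int(rng.integers(0, n))

a = np.random.default_rng(42)
b = np.random.default_rng(42)
draws_a = [pick(a, len(paths)) for _ in range(5)]
draws_b = [pick(b, len(paths)) for _ in range(5)]
assert draws_a == draws_b  # identical seeds -> identical opponent sequence
```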

sample_latest(device=None)

Load and return the most recently added snapshot.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `device` | `Optional[str]` | Optional PyTorch device string. When provided, the loaded policy is moved to `device` before being returned. | `None` |

Returns:

| Type | Description |
|------|-------------|
| `MAPPOPolicy` or `None` | The latest policy, or `None` if the pool is empty. |

Source code in training/self_play.py
def sample_latest(self, device: Optional[str] = None) -> Optional["MAPPOPolicy"]:
    """Load and return the most recently added snapshot.

    Parameters
    ----------
    device:
        Optional PyTorch device string.  When provided the loaded
        policy is moved to *device* before being returned.

    Returns
    -------
    MAPPOPolicy or None
        The latest policy, or ``None`` if the pool is empty.
    """
    if not self._snapshots:
        return None
    return self._load_snapshot(self._snapshots[-1], device=device)

training.self_play.evaluate_team_vs_pool(policy, opponent, n_blue=2, n_red=2, n_episodes=20, deterministic=True, seed=None, env_kwargs=None)

Evaluate a MAPPO `policy` (Blue) against `opponent` (Red) in self-play.

Runs `n_episodes` episodes of `MultiBattalionEnv` where Blue is driven by `policy` and Red is driven by `opponent`.

For symmetric self-play (`n_blue == n_red`) the opponent is used directly as a Red policy. When team sizes differ, the opponent's shared actor is applied to each Red agent in turn.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `policy` | `MAPPOPolicy` | The Blue `MAPPOPolicy` under evaluation. | *required* |
| `opponent` | `MAPPOPolicy` | The frozen Red opponent policy. | *required* |
| `n_blue` | `int` | Blue team size (must match the training configuration). | `2` |
| `n_red` | `int` | Red team size (must match the training configuration). | `2` |
| `n_episodes` | `int` | Number of evaluation episodes. | `20` |
| `deterministic` | `bool` | Blue acts deterministically when `True`; Red always acts stochastically to simulate a diverse pool. | `True` |
| `seed` | `Optional[int]` | Base random seed; episode `i` uses `seed + i` when provided. | `None` |
| `env_kwargs` | `Optional[Dict]` | Extra keyword arguments forwarded to `MultiBattalionEnv`. | `None` |

Returns:

| Type | Description |
|------|-------------|
| `float` | Blue win rate in `[0, 1]`. |

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If `n_episodes < 1`. |

Source code in training/self_play.py
def evaluate_team_vs_pool(
    policy: "MAPPOPolicy",
    opponent: "MAPPOPolicy",
    n_blue: int = 2,
    n_red: int = 2,
    n_episodes: int = 20,
    deterministic: bool = True,
    seed: Optional[int] = None,
    env_kwargs: Optional[Dict] = None,
) -> float:
    """Evaluate a MAPPO *policy* (Blue) against *opponent* (Red) in self-play.

    Runs *n_episodes* episodes of
    :class:`~envs.multi_battalion_env.MultiBattalionEnv` where Blue is
    driven by *policy* and Red is driven by *opponent*.

    For symmetric self-play (``n_blue == n_red``) the opponent is used
    directly as a Red policy.  When team sizes differ, the opponent's
    shared actor is applied to each Red agent in turn.

    Parameters
    ----------
    policy:
        The Blue :class:`~models.mappo_policy.MAPPOPolicy` under evaluation.
    opponent:
        The frozen Red opponent policy.
    n_blue, n_red:
        Team sizes (must match the training configuration).
    n_episodes:
        Number of evaluation episodes (default 20).
    deterministic:
        Blue acts deterministically when ``True``; Red always acts
        stochastically to simulate a diverse pool.
    seed:
        Base random seed; episode *i* uses ``seed + i`` when provided.
    env_kwargs:
        Extra keyword arguments forwarded to
        :class:`~envs.multi_battalion_env.MultiBattalionEnv`.

    Returns
    -------
    float
        Blue win rate in ``[0, 1]``.

    Raises
    ------
    ValueError
        If *n_episodes* < 1.
    """
    import torch  # local import
    from envs.multi_battalion_env import MultiBattalionEnv

    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")

    # Derive evaluation device from policy parameters; fall back to CPU if the
    # policy has no parameters (edge case / mocked policies in tests).
    try:
        eval_device = next(policy.parameters()).device
    except StopIteration:
        eval_device = torch.device("cpu")

    # Move opponent to the same device so all tensor operations are consistent.
    opponent = opponent.to(eval_device)
    opponent.eval()

    _env_kwargs: dict = env_kwargs or {}
    env = MultiBattalionEnv(n_blue=n_blue, n_red=n_red, **_env_kwargs)
    act_low = env._act_space.low
    act_high = env._act_space.high
    obs_dim = env._obs_dim

    wins = 0

    for ep in range(n_episodes):
        ep_seed = None if seed is None else seed + ep
        obs, _ = env.reset(seed=ep_seed)
        blue_won = False

        while env.agents:
            action_dict: dict[str, np.ndarray] = {}

            # Blue actions (controlled by *policy*)
            for i in range(n_blue):
                agent_id = f"blue_{i}"
                if agent_id in env.agents:
                    agent_obs = obs.get(agent_id, np.zeros(obs_dim, dtype=np.float32))
                    obs_t = torch.as_tensor(agent_obs, device=eval_device).unsqueeze(0)
                    with torch.no_grad():
                        acts_t, _ = policy.act(obs_t, agent_idx=i, deterministic=deterministic)
                    action_dict[agent_id] = np.clip(
                        acts_t[0].cpu().numpy(), act_low, act_high
                    )

            # Red actions (controlled by *opponent*)
            for i in range(n_red):
                agent_id = f"red_{i}"
                if agent_id in env.agents:
                    agent_obs = obs.get(agent_id, np.zeros(obs_dim, dtype=np.float32))
                    obs_t = torch.as_tensor(agent_obs, device=eval_device).unsqueeze(0)
                    with torch.no_grad():
                        acts_t, _ = opponent.act(
                            obs_t,
                            agent_idx=i % opponent.n_agents,
                            deterministic=False,
                        )
                    action_dict[agent_id] = np.clip(
                        acts_t[0].cpu().numpy(), act_low, act_high
                    )

            obs, _, _, _, _ = env.step(action_dict)

            # Win condition: Red fully eliminated while at least one Blue alive
            red_alive = any(a.startswith("red_") for a in env.agents)
            blue_alive = any(a.startswith("blue_") for a in env.agents)
            if not red_alive and blue_alive and not blue_won:
                blue_won = True

        if blue_won:
            wins += 1

    env.close()
    return wins / n_episodes
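For asymmetric team sizes, Red agent `i` is driven by the opponent's actor head `i % opponent.n_agents`, wrapping around the available heads. The mapping can be sketched as:

```python
def red_actor_indices(n_red, opponent_n_agents):
    """Which opponent actor head drives each Red agent (as in evaluate_team_vs_pool)."""
    return [i % opponent_n_agents for i in range(n_red)]

# Symmetric 2v2: identity mapping
print(red_actor_indices(2, 2))  # [0, 1]
# 4 Red agents driven by a 2-agent opponent: heads wrap around
print(red_actor_indices(4, 2))  # [0, 1, 0, 1]
```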

training.self_play.nash_exploitability_proxy(policy, pool, n_blue=2, n_red=2, n_episodes_per_opponent=10, seed=None, env_kwargs=None)

Estimate Nash exploitability using the current self-play pool.

Evaluates `policy` (as Blue) against **every** snapshot in `pool` and returns the *nemesis gap*:

    ExplProxy = max_i (1 - wr_i) - mean_i (1 - wr_i)

where `wr_i` is the Blue win rate against opponent `i`.

Interpretation:

* `0.0` — policy performs equally well (or poorly) against every pool member; hard to exploit from the pool.
* High value — one pool member significantly outperforms the average against the policy; the policy has an exploitable weakness.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `policy` | `MAPPOPolicy` | The Blue `MAPPOPolicy` to evaluate. | *required* |
| `pool` | `TeamOpponentPool` | `TeamOpponentPool` of frozen opponent snapshots. | *required* |
| `n_blue` | `int` | Blue team size. | `2` |
| `n_red` | `int` | Red team size. | `2` |
| `n_episodes_per_opponent` | `int` | Episodes per pool member. Smaller than the regular evaluation budget to keep the cost tractable. | `10` |
| `seed` | `Optional[int]` | Base random seed. | `None` |
| `env_kwargs` | `Optional[Dict]` | Extra kwargs forwarded to `MultiBattalionEnv`. | `None` |

Returns:

| Type | Description |
|------|-------------|
| `float` | Exploitability proxy in `[0, 1]`. Returns `0.0` if the pool is empty or all snapshots fail to load. |

Source code in training/self_play.py
def nash_exploitability_proxy(
    policy: "MAPPOPolicy",
    pool: TeamOpponentPool,
    n_blue: int = 2,
    n_red: int = 2,
    n_episodes_per_opponent: int = 10,
    seed: Optional[int] = None,
    env_kwargs: Optional[Dict] = None,
) -> float:
    """Estimate Nash exploitability using the current self-play pool.

    Evaluates *policy* (as Blue) against **every** snapshot in *pool* and
    returns the *nemesis gap*:

    .. math::

        \\text{ExplProxy} = \\max_i (1 - \\text{wr}_i) - \\text{mean}_i (1 - \\text{wr}_i)

    where :math:`\\text{wr}_i` is the Blue win rate against opponent *i*.

    Interpretation:

    * ``0.0`` — policy performs equally well (or poorly) against every pool
      member; hard to exploit from the pool.
    * High value — one pool member significantly outperforms the average
      against the policy; the policy has an exploitable weakness.

    Parameters
    ----------
    policy:
        The Blue :class:`~models.mappo_policy.MAPPOPolicy` to evaluate.
    pool:
        :class:`TeamOpponentPool` of frozen opponent snapshots.
    n_blue, n_red:
        Team sizes.
    n_episodes_per_opponent:
        Episodes per pool member (default 10).  Smaller than the regular
        evaluation budget to keep the cost tractable.
    seed:
        Base random seed.
    env_kwargs:
        Extra kwargs forwarded to
        :class:`~envs.multi_battalion_env.MultiBattalionEnv`.

    Returns
    -------
    float
        Exploitability proxy in ``[0, 1]``.  Returns ``0.0`` if the pool
        is empty or all snapshots fail to load.
    """
    if pool.size == 0:
        return 0.0

    opp_win_rates: list[float] = []
    for path in pool.snapshot_paths:
        opponent = pool._load_snapshot(path)
        if opponent is None:
            continue
        blue_wr = evaluate_team_vs_pool(
            policy=policy,
            opponent=opponent,
            n_blue=n_blue,
            n_red=n_red,
            n_episodes=n_episodes_per_opponent,
            deterministic=True,
            seed=seed,
            env_kwargs=env_kwargs,
        )
        opp_win_rates.append(1.0 - blue_wr)

    if not opp_win_rates:
        return 0.0

    mean_opp = sum(opp_win_rates) / len(opp_win_rates)
    max_opp = max(opp_win_rates)
    return max_opp - mean_opp
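
The nemesis-gap arithmetic can be sketched in plain Python. The win rates below are hypothetical, not outputs of the function above:

```python
# Hypothetical Blue win rates against each of four pool snapshots.
blue_win_rates = [0.80, 0.75, 0.40, 0.70]

# Opponent strength from Blue's perspective: 1 - wr_i.
opp_strength = [1.0 - wr for wr in blue_win_rates]

mean_opp = sum(opp_strength) / len(opp_strength)  # ≈ 0.3375
max_opp = max(opp_strength)                       # 0.6 — the "nemesis"
expl_proxy = max_opp - mean_opp                   # ≈ 0.2625
```

A single weak matchup (40 % win rate here) dominates the proxy even when the average performance is strong, which is exactly the exploitability signal the metric is designed to surface.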

Curriculum

training.curriculum_scheduler.CurriculumScheduler

Tracks episode outcomes and decides when to promote the curriculum stage.

Parameters:

* promote_threshold (float, default 0.70) — Rolling win rate (in [0, 1]) that must be reached to advance to the next stage (70 % win rate by default).
* win_rate_window (int, default 50) — Number of most-recent episodes used to compute the rolling win rate.
* initial_stage (CurriculumStage, default STAGE_1V1) — The curriculum stage to begin from.
Source code in training/curriculum_scheduler.py
class CurriculumScheduler:
    """Tracks episode outcomes and decides when to promote the curriculum stage.

    Parameters
    ----------
    promote_threshold:
        Rolling win rate (in ``[0, 1]``) that must be reached to advance to
        the next stage.  Defaults to ``0.70`` (70 % win rate).
    win_rate_window:
        Number of most-recent episodes used to compute the rolling win rate.
        Defaults to ``50``.
    initial_stage:
        The curriculum stage to begin from.  Defaults to
        :attr:`CurriculumStage.STAGE_1V1`.
    """

    def __init__(
        self,
        promote_threshold: float = 0.70,
        win_rate_window: int = 50,
        initial_stage: CurriculumStage = CurriculumStage.STAGE_1V1,
    ) -> None:
        if not (0.0 < promote_threshold <= 1.0):
            raise ValueError(
                f"promote_threshold must be in (0, 1], got {promote_threshold}"
            )
        if win_rate_window < 1:
            raise ValueError(
                f"win_rate_window must be >= 1, got {win_rate_window}"
            )

        self.promote_threshold = float(promote_threshold)
        self.win_rate_window = int(win_rate_window)
        self._stage: CurriculumStage = initial_stage

        # Rolling window of episode outcomes (True=win, False=loss/draw)
        self._outcomes: Deque[bool] = deque(maxlen=win_rate_window)

        # Cumulative episode counter
        self._total_episodes: int = 0

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def stage(self) -> CurriculumStage:
        """The current curriculum stage."""
        return self._stage

    @property
    def stage_label(self) -> str:
        """Human-readable label for the current stage (e.g. ``"2v1"``)."""
        return STAGE_LABELS[self._stage]

    @property
    def is_final_stage(self) -> bool:
        """``True`` when the scheduler is already at the last stage (2v2)."""
        return self._stage == CurriculumStage.STAGE_2V2

    @property
    def total_episodes(self) -> int:
        """Total episodes recorded since creation."""
        return self._total_episodes

    # ------------------------------------------------------------------
    # Episode tracking
    # ------------------------------------------------------------------

    def record_episode(self, win: bool) -> None:
        """Record the outcome of a completed episode.

        Parameters
        ----------
        win:
            ``True`` if the Blue team won the episode, ``False`` otherwise.
        """
        self._outcomes.append(bool(win))
        self._total_episodes += 1

    def win_rate(self) -> float:
        """Return the rolling win rate over the last ``win_rate_window`` episodes.

        Returns ``0.0`` if no episodes have been recorded yet.
        """
        if not self._outcomes:
            return 0.0
        return sum(self._outcomes) / len(self._outcomes)

    # ------------------------------------------------------------------
    # Promotion logic
    # ------------------------------------------------------------------

    def should_promote(self) -> bool:
        """Return ``True`` if promotion criteria are met.

        Criteria:
        * At least ``win_rate_window`` episodes have been recorded since the
          last promotion (or since creation).
        * The rolling win rate meets or exceeds ``promote_threshold``.
        * The current stage is not already the final stage.
        """
        if self.is_final_stage:
            return False
        if len(self._outcomes) < self.win_rate_window:
            return False
        return self.win_rate() >= self.promote_threshold

    def promote(self) -> CurriculumStage:
        """Advance to the next curriculum stage and reset the outcome window.

        Returns
        -------
        The new :class:`CurriculumStage` after promotion.

        Raises
        ------
        RuntimeError
            If called when already at the final stage.
        """
        if self.is_final_stage:
            raise RuntimeError(
                "Cannot promote past the final curriculum stage "
                f"({self.stage_label})."
            )

        old_stage = self._stage
        self._stage = CurriculumStage(int(self._stage) + 1)
        # Reset rolling window so the new stage's win rate is measured fresh.
        self._outcomes.clear()

        log.info(
            "Curriculum promoted: %s → %s after %d total episodes",
            STAGE_LABELS[old_stage],
            STAGE_LABELS[self._stage],
            self._total_episodes,
        )
        return self._stage

    def env_kwargs(self) -> dict:
        """Return the environment kwargs for the current stage.

        These match the ``n_blue``/``n_red`` values in the scenario YAML files
        under ``configs/scenarios/``.
        """
        return dict(STAGE_ENV_KWARGS[self._stage])

    def wandb_metrics(self) -> dict:
        """Return a dict of W&B metrics for the current state.

        Keys:
        * ``curriculum/stage`` — integer stage index
        * ``curriculum/stage_label`` — e.g. ``"2v1"``
        * ``curriculum/win_rate`` — rolling win rate in ``[0, 1]``
        * ``curriculum/total_episodes`` — cumulative episode count
        """
        return {
            "curriculum/stage": int(self._stage),
            "curriculum/stage_label": self.stage_label,
            "curriculum/win_rate": self.win_rate(),
            "curriculum/total_episodes": self._total_episodes,
        }

    def log_promotion_event(self, total_steps: int, wandb_run: object = None) -> None:
        """Log a curriculum stage-promotion event to W&B.

        Parameters
        ----------
        total_steps:
            Current total environment step count (used as the W&B x-axis).
        wandb_run:
            An active ``wandb.run`` object.  When ``None`` the event is only
            written to the Python logger.
        """
        metrics = self.wandb_metrics()
        metrics["curriculum/promotion_step"] = total_steps
        log.info(
            "Curriculum stage transition → %s at step %d (win_rate=%.3f, "
            "total_episodes=%d)",
            self.stage_label,
            total_steps,
            metrics["curriculum/win_rate"],
            self._total_episodes,
        )
        if wandb_run is not None:
            try:
                import wandb  # local import to keep module importable without wandb

                wandb.log(metrics, step=total_steps)
            except Exception as exc:  # pragma: no cover
                log.warning("W&B logging failed during promotion: %s", exc)

is_final_stage property

True when the scheduler is already at the last stage (2v2).

stage property

The current curriculum stage.

stage_label property

Human-readable label for the current stage (e.g. "2v1").

total_episodes property

Total episodes recorded since creation.

env_kwargs()

Return the environment kwargs for the current stage.

These match the n_blue/n_red values in the scenario YAML files under configs/scenarios/.

Source code in training/curriculum_scheduler.py
def env_kwargs(self) -> dict:
    """Return the environment kwargs for the current stage.

    These match the ``n_blue``/``n_red`` values in the scenario YAML files
    under ``configs/scenarios/``.
    """
    return dict(STAGE_ENV_KWARGS[self._stage])

log_promotion_event(total_steps, wandb_run=None)

Log a curriculum stage-promotion event to W&B.

Parameters:

* total_steps (int, required) — Current total environment step count (used as the W&B x-axis).
* wandb_run (object, default None) — An active wandb.run object. When None the event is only written to the Python logger.
Source code in training/curriculum_scheduler.py
def log_promotion_event(self, total_steps: int, wandb_run: object = None) -> None:
    """Log a curriculum stage-promotion event to W&B.

    Parameters
    ----------
    total_steps:
        Current total environment step count (used as the W&B x-axis).
    wandb_run:
        An active ``wandb.run`` object.  When ``None`` the event is only
        written to the Python logger.
    """
    metrics = self.wandb_metrics()
    metrics["curriculum/promotion_step"] = total_steps
    log.info(
        "Curriculum stage transition → %s at step %d (win_rate=%.3f, "
        "total_episodes=%d)",
        self.stage_label,
        total_steps,
        metrics["curriculum/win_rate"],
        self._total_episodes,
    )
    if wandb_run is not None:
        try:
            import wandb  # local import to keep module importable without wandb

            wandb.log(metrics, step=total_steps)
        except Exception as exc:  # pragma: no cover
            log.warning("W&B logging failed during promotion: %s", exc)

promote()

Advance to the next curriculum stage and reset the outcome window.

Returns:

* CurriculumStage — The new stage after promotion.

Raises:

* RuntimeError — If called when already at the final stage.

Source code in training/curriculum_scheduler.py
def promote(self) -> CurriculumStage:
    """Advance to the next curriculum stage and reset the outcome window.

    Returns
    -------
    The new :class:`CurriculumStage` after promotion.

    Raises
    ------
    RuntimeError
        If called when already at the final stage.
    """
    if self.is_final_stage:
        raise RuntimeError(
            "Cannot promote past the final curriculum stage "
            f"({self.stage_label})."
        )

    old_stage = self._stage
    self._stage = CurriculumStage(int(self._stage) + 1)
    # Reset rolling window so the new stage's win rate is measured fresh.
    self._outcomes.clear()

    log.info(
        "Curriculum promoted: %s → %s after %d total episodes",
        STAGE_LABELS[old_stage],
        STAGE_LABELS[self._stage],
        self._total_episodes,
    )
    return self._stage

record_episode(win)

Record the outcome of a completed episode.

Parameters:

* win (bool, required) — True if the Blue team won the episode, False otherwise.
Source code in training/curriculum_scheduler.py
def record_episode(self, win: bool) -> None:
    """Record the outcome of a completed episode.

    Parameters
    ----------
    win:
        ``True`` if the Blue team won the episode, ``False`` otherwise.
    """
    self._outcomes.append(bool(win))
    self._total_episodes += 1

should_promote()

Return True if promotion criteria are met.

Criteria:

* At least win_rate_window episodes have been recorded since the last promotion (or since creation).
* The rolling win rate meets or exceeds promote_threshold.
* The current stage is not already the final stage.

Source code in training/curriculum_scheduler.py
def should_promote(self) -> bool:
    """Return ``True`` if promotion criteria are met.

    Criteria:
    * At least ``win_rate_window`` episodes have been recorded since the
      last promotion (or since creation).
    * The rolling win rate meets or exceeds ``promote_threshold``.
    * The current stage is not already the final stage.
    """
    if self.is_final_stage:
        return False
    if len(self._outcomes) < self.win_rate_window:
        return False
    return self.win_rate() >= self.promote_threshold

wandb_metrics()

Return a dict of W&B metrics for the current state.

Keys:

* curriculum/stage — integer stage index
* curriculum/stage_label — e.g. "2v1"
* curriculum/win_rate — rolling win rate in [0, 1]
* curriculum/total_episodes — cumulative episode count

Source code in training/curriculum_scheduler.py
def wandb_metrics(self) -> dict:
    """Return a dict of W&B metrics for the current state.

    Keys:
    * ``curriculum/stage`` — integer stage index
    * ``curriculum/stage_label`` — e.g. ``"2v1"``
    * ``curriculum/win_rate`` — rolling win rate in ``[0, 1]``
    * ``curriculum/total_episodes`` — cumulative episode count
    """
    return {
        "curriculum/stage": int(self._stage),
        "curriculum/stage_label": self.stage_label,
        "curriculum/win_rate": self.win_rate(),
        "curriculum/total_episodes": self._total_episodes,
    }

win_rate()

Return the rolling win rate over the last win_rate_window episodes.

Returns 0.0 if no episodes have been recorded yet.

Source code in training/curriculum_scheduler.py
def win_rate(self) -> float:
    """Return the rolling win rate over the last ``win_rate_window`` episodes.

    Returns ``0.0`` if no episodes have been recorded yet.
    """
    if not self._outcomes:
        return 0.0
    return sum(self._outcomes) / len(self._outcomes)
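
The promotion rule (full outcome window and threshold met) can be illustrated with the stdlib alone. This sketch uses a window of 5 instead of the default 50; the real class additionally tracks stages and validates its arguments:

```python
from collections import deque

promote_threshold = 0.70
win_rate_window = 5
outcomes = deque(maxlen=win_rate_window)  # rolling window of wins/losses

def should_promote() -> bool:
    # A full window is required so the win rate is not estimated from
    # just a handful of episodes right after a promotion reset.
    if len(outcomes) < win_rate_window:
        return False
    return sum(outcomes) / len(outcomes) >= promote_threshold

outcomes.extend([True, True, False, True])
early = should_promote()        # False: window not yet full
outcomes.extend([True, True])   # deque now holds only the last 5 outcomes
ready = should_promote()        # True: win rate 4/5 = 0.8 >= 0.7
```

Note that `deque(maxlen=...)` silently discards the oldest outcome on overflow, which is what makes the win rate "rolling" rather than cumulative.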

training.curriculum_scheduler.CurriculumStage

Bases: IntEnum

Ordered curriculum stages.

The integer value doubles as the index into the STAGE_ENV_KWARGS mapping defined in the same module.

Source code in training/curriculum_scheduler.py
class CurriculumStage(IntEnum):
    """Ordered curriculum stages.

    The integer value doubles as the stage index used when indexing into the
    ``STAGE_ENV_KWARGS`` mapping defined in :data:`STAGE_ENV_KWARGS`.
    """

    STAGE_1V1 = 0  #: Bootstrap stage — 1 Blue vs 1 Red (frozen v1 checkpoint)
    STAGE_2V1 = 1  #: Asymmetric advantage — 2 Blue vs 1 Red
    STAGE_2V2 = 2  #: Full cooperative challenge — 2 Blue vs 2 Red
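
Because the stages are an IntEnum, the promotion step in promote() is plain integer arithmetic. A minimal sketch follows; the env-kwargs values here are assumptions inferred from the stage labels, not the actual STAGE_ENV_KWARGS constant:

```python
from enum import IntEnum

class CurriculumStage(IntEnum):
    STAGE_1V1 = 0
    STAGE_2V1 = 1
    STAGE_2V2 = 2

# Hypothetical stand-in for STAGE_ENV_KWARGS, keyed by the enum member.
stage_env_kwargs = {
    CurriculumStage.STAGE_1V1: {"n_blue": 1, "n_red": 1},
    CurriculumStage.STAGE_2V1: {"n_blue": 2, "n_red": 1},
    CurriculumStage.STAGE_2V2: {"n_blue": 2, "n_red": 2},
}

stage = CurriculumStage.STAGE_1V1
stage = CurriculumStage(int(stage) + 1)  # same arithmetic as promote()
kwargs = stage_env_kwargs[stage]
```

Constructing `CurriculumStage(3)` would raise ValueError, which is why promote() guards with is_final_stage before incrementing.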

training.curriculum_scheduler.load_v1_weights_into_mappo(v1_checkpoint_path, mappo_policy, *, strict=False)

Copy shared-trunk weights from a SB3 PPO v1 checkpoint into a MAPPO actor.

SB3 BattalionMlpPolicy (net_arch=[128, 128]) stores its actor network under mlp_extractor.policy_net.* (even indices are Linear layers; odd indices are activation functions with no parameters) and the final action head under action_net.*.

MAPPO MAPPOActor (hidden_sizes=(128, 64)) stores its trunk under actor.trunk.* where Linear layers are at indices 0, 3, 6, … (interleaved with LayerNorm at 1, 4, … and Tanh at 2, 5, …) and the action head under actor.action_mean.*.

Mapping strategy: Linear layers are matched positionally — the i-th Linear layer in mlp_extractor.policy_net maps to the i-th Linear layer in actor.trunk. Layers whose weight shapes do not match are silently skipped (or raise ValueError when strict=True).

Additionally, log_std → actor.log_std and action_net.* → actor.action_mean.* are transferred when shapes match.

Parameters:

* v1_checkpoint_path (str | Path, required) — Path to the SB3 .zip checkpoint produced by training/train.py.
* mappo_policy (Module, required) — A MAPPOPolicy instance whose actor trunk will be warm-started.
* strict (bool, default False) — When True, raises on any shape or key mismatch. When False, mismatches are logged as warnings and skipped, allowing partial weight transfer when the 1v1 obs-dim differs from 2v2.

Returns:

* dict — Keys "loaded" (list of transferred MAPPO layer names) and "skipped" (list of skipped layer names).
Source code in training/curriculum_scheduler.py
def load_v1_weights_into_mappo(
    v1_checkpoint_path: str | Path,
    mappo_policy: torch.nn.Module,
    *,
    strict: bool = False,
) -> dict:
    """Copy shared-trunk weights from a SB3 PPO v1 checkpoint into a MAPPO actor.

    SB3 ``BattalionMlpPolicy`` (``net_arch=[128, 128]``) stores its actor
    network under ``mlp_extractor.policy_net.*`` (even indices are Linear
    layers; odd indices are activation functions with no parameters) and the
    final action head under ``action_net.*``.

    MAPPO ``MAPPOActor`` (``hidden_sizes=(128, 64)``) stores its trunk under
    ``actor.trunk.*`` where Linear layers are at indices 0, 3, 6, …
    (interleaved with LayerNorm at 1, 4, … and Tanh at 2, 5, …) and the
    action head under ``actor.action_mean.*``.

    Mapping strategy
    ~~~~~~~~~~~~~~~~
    Linear layers are matched *positionally* — the *i*-th Linear layer in
    ``mlp_extractor.policy_net`` maps to the *i*-th Linear layer in
    ``actor.trunk``.  Layers whose weight shapes do not match are silently
    skipped (or raise ``ValueError`` when ``strict=True``).

    Additionally ``log_std`` → ``actor.log_std`` and
    ``action_net.*`` → ``actor.action_mean.*`` are transferred when shapes
    match.

    Parameters
    ----------
    v1_checkpoint_path:
        Path to the SB3 ``.zip`` checkpoint produced by ``training/train.py``.
    mappo_policy:
        A :class:`~models.mappo_policy.MAPPOPolicy` instance whose actor
        trunk will be warm-started.
    strict:
        When ``True``, raises on any shape or key mismatch.  When ``False``
        (default) mismatches are logged as warnings and skipped, allowing
        partial weight transfer when the 1v1 obs-dim differs from 2v2.

    Returns
    -------
    A dict with keys ``"loaded"`` (list of transferred MAPPO layer names) and
    ``"skipped"`` (list of skipped layer names).
    """
    import io
    import re
    import zipfile

    v1_path = Path(v1_checkpoint_path)
    if not v1_path.exists():
        raise FileNotFoundError(f"v1 checkpoint not found: {v1_path}")

    # SB3 saves the PyTorch state dict as "policy.pth" inside the zip.
    with zipfile.ZipFile(v1_path, "r") as zf:
        if "policy.pth" not in zf.namelist():
            raise ValueError(
                f"Expected 'policy.pth' inside {v1_path}; "
                f"found: {zf.namelist()}"
            )
        with zf.open("policy.pth") as f:
            v1_state: dict = torch.load(io.BytesIO(f.read()), map_location="cpu")

    transferred: list[str] = []
    skipped: list[str] = []

    mappo_state = mappo_policy.state_dict()
    new_state: dict = {}

    def _try_transfer(v1_key: str, mappo_key: str, v1_tensor: torch.Tensor) -> None:
        """Attempt to transfer one tensor; update transferred/skipped lists."""
        if mappo_key in mappo_state:
            if mappo_state[mappo_key].shape == v1_tensor.shape:
                new_state[mappo_key] = v1_tensor
                transferred.append(mappo_key)
            else:
                msg = (
                    f"Shape mismatch for {mappo_key}: "
                    f"v1={v1_tensor.shape} mappo={mappo_state[mappo_key].shape}"
                )
                if strict:
                    raise ValueError(msg)
                log.warning("load_v1_weights_into_mappo: skipping %s", msg)
                skipped.append(mappo_key)
        else:
            skipped.append(mappo_key)

    # ------------------------------------------------------------------
    # 1. Transfer mlp_extractor.policy_net Linear layers → actor.trunk
    #    positionally (i-th SB3 Linear → i-th MAPPO trunk Linear).
    #
    #    SB3 indices: 0, 2, 4, ... are Linear (odd indices are Tanh, no params).
    #    MAPPO indices: 0, 3, 6, ... are Linear (LayerNorm 1,4,… Tanh 2,5,…).
    #    We detect Linear layers by their weight tensor being 2-D.
    # ------------------------------------------------------------------
    pnet_pat = re.compile(r"^mlp_extractor\.policy_net\.(\d+)\.(weight|bias)$")
    trunk_pat = re.compile(r"^actor\.trunk\.(\d+)\.(weight|bias)$")

    # Collect SB3 policy_net layers indexed by their Sequential index.
    pnet_layers: dict[int, dict[str, tuple[str, torch.Tensor]]] = {}
    for k, v in v1_state.items():
        m = pnet_pat.match(k)
        if m:
            idx, kind = int(m.group(1)), m.group(2)
            pnet_layers.setdefault(idx, {})[kind] = (k, v)

    # Keep only Linear layers (weight is 2-D).
    v1_trunk_linears: list[dict[str, tuple[str, torch.Tensor]]] = []
    for idx in sorted(pnet_layers):
        entry = pnet_layers[idx]
        _, w_tensor = entry.get("weight", (None, None))
        if w_tensor is not None and w_tensor.dim() == 2:
            v1_trunk_linears.append(entry)

    # Collect MAPPO actor.trunk Linear layers (weight is 2-D).
    mappo_trunk_layers: dict[int, dict[str, tuple[str, torch.Tensor]]] = {}
    for k, v in mappo_state.items():
        m = trunk_pat.match(k)
        if m:
            idx, kind = int(m.group(1)), m.group(2)
            mappo_trunk_layers.setdefault(idx, {})[kind] = (k, v)

    mappo_trunk_linears: list[dict[str, tuple[str, torch.Tensor]]] = []
    for idx in sorted(mappo_trunk_layers):
        entry = mappo_trunk_layers[idx]
        _, w_tensor = entry.get("weight", (None, None))
        if w_tensor is not None and w_tensor.dim() == 2:
            mappo_trunk_linears.append(entry)

    # Match by position.
    for i, v1_entry in enumerate(v1_trunk_linears):
        if i >= len(mappo_trunk_linears):
            for kind, (k, _) in v1_entry.items():
                skipped.append(k)
            continue
        mappo_entry = mappo_trunk_linears[i]
        for kind in ("weight", "bias"):
            if kind in v1_entry and kind in mappo_entry:
                v1_key, v1_tensor = v1_entry[kind]
                mappo_key, _ = mappo_entry[kind]
                _try_transfer(v1_key, mappo_key, v1_tensor)

    # ------------------------------------------------------------------
    # 2. Transfer action_net.* → actor.action_mean.*
    # ------------------------------------------------------------------
    action_pat = re.compile(r"^action_net\.(weight|bias)$")
    for k, v in v1_state.items():
        m = action_pat.match(k)
        if m:
            kind = m.group(1)
            mappo_key = f"actor.action_mean.{kind}"
            _try_transfer(k, mappo_key, v)

    # ------------------------------------------------------------------
    # 3. Transfer log_std → actor.log_std
    # ------------------------------------------------------------------
    if "log_std" in v1_state:
        _try_transfer("log_std", "actor.log_std", v1_state["log_std"])

    if new_state:
        mappo_state.update(new_state)
        mappo_policy.load_state_dict(mappo_state, strict=False)
        log.info(
            "load_v1_weights_into_mappo: transferred %d layers, skipped %d layers",
            len(transferred),
            len(skipped),
        )
    else:
        log.warning(
            "load_v1_weights_into_mappo: no layers transferred from %s", v1_path
        )

    return {"loaded": transferred, "skipped": skipped}
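
The positional matching and shape-mismatch skipping can be sketched without torch by working with (name, weight-shape) pairs. Both the layer names and the shapes below are illustrative, chosen so the second trunk layer deliberately mismatches (SB3 net_arch=[128, 128] vs MAPPO hidden_sizes=(128, 64)):

```python
# Hypothetical SB3 policy_net Linear layers: (name, weight shape).
v1_linears = [
    ("mlp_extractor.policy_net.0.weight", (128, 60)),
    ("mlp_extractor.policy_net.2.weight", (128, 128)),
]
# Hypothetical MAPPO actor.trunk Linear layers with hidden sizes (128, 64),
# so the second layer's shape differs from SB3's.
mappo_linears = [
    ("actor.trunk.0.weight", (128, 60)),
    ("actor.trunk.3.weight", (64, 128)),
]

loaded, skipped = [], []
# Match i-th SB3 Linear to i-th MAPPO trunk Linear, by position.
for (v1_name, v1_shape), (m_name, m_shape) in zip(v1_linears, mappo_linears):
    if v1_shape == m_shape:
        loaded.append(m_name)   # shapes agree: transfer the tensor
    else:
        skipped.append(m_name)  # mismatch: skip (or raise when strict=True)
```

This is why non-strict mode still recovers the first trunk layer even when the deeper layers (or the obs-dim) changed between the 1v1 and 2v2 setups.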

Policy Registry

training.policy_registry.PolicyRegistry

Versioned policy registry backed by a JSON manifest.

Parameters:

* path (Union[str, Path, None], default 'checkpoints/policy_registry.json') — Path to the JSON manifest file. The parent directory is created on save(). Pass None for an in-memory-only registry.
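
The manifest is plain JSON. A sketch of the round-trip with one hypothetical entry (the field names follow the save() format; the path and version are made up):

```python
import json

manifest = {
    "entries": [
        {
            "echelon": "battalion",
            "version": "v2_final",
            "path": "checkpoints/battalion_v2.pt",  # hypothetical checkpoint
            "run_id": None,
        }
    ]
}

# save() serializes with indent=2 and sort_keys=True; _load() reads it back.
text = json.dumps(manifest, indent=2, sort_keys=True)
restored = json.loads(text)
```

Because the manifest is ordinary JSON, it can be inspected or edited by hand between runs; the registry simply replays the entries on construction.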
Source code in training/policy_registry.py
class PolicyRegistry:
    """Versioned policy registry backed by a JSON manifest.

    Parameters
    ----------
    path:
        Path to the JSON manifest file.  The parent directory is created on
        :meth:`save`.  Pass ``None`` for an in-memory-only registry.
    """

    def __init__(
        self,
        path: Union[str, Path, None] = "checkpoints/policy_registry.json",
    ) -> None:
        self._path: Path | None = Path(path) if path is not None else None
        self._entries: list[PolicyEntry] = []
        if self._path is not None and self._path.exists():
            self._load()

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def can_save(self) -> bool:
        """``True`` when the registry has a backing file path."""
        return self._path is not None

    # ------------------------------------------------------------------
    # CRUD
    # ------------------------------------------------------------------

    def register(
        self,
        echelon: Union[Echelon, str],
        version: str,
        path: Union[str, Path],
        run_id: Optional[str] = None,
        overwrite: bool = False,
    ) -> PolicyEntry:
        """Register a new policy checkpoint.

        Parameters
        ----------
        echelon:
            The HRL echelon: ``"battalion"``, ``"brigade"``, or
            ``"division"`` (or the :class:`Echelon` enum value).
        version:
            Caller-assigned version string, e.g. ``"v2_final"``.
        path:
            File-system path to the checkpoint.
        run_id:
            Optional W&B run ID linked to this checkpoint.
        overwrite:
            When ``True`` an existing entry with the same echelon+version
            is replaced.  When ``False`` (default) a :exc:`ValueError` is
            raised if the entry already exists.

        Returns
        -------
        PolicyEntry
            The newly created entry.

        Raises
        ------
        ValueError
            If an entry with the same echelon+version already exists and
            *overwrite* is ``False``, or if *echelon* is invalid.
        """
        echelon_str = _echelon_str(echelon)
        path_str = str(path)

        existing_idx = self._find_index(echelon_str, version)
        if existing_idx is not None:
            if not overwrite:
                raise ValueError(
                    f"Entry already exists for echelon='{echelon_str}' "
                    f"version='{version}'. Pass overwrite=True to replace it."
                )
            self._entries.pop(existing_idx)

        entry = PolicyEntry(
            echelon=echelon_str,
            version=version,
            path=path_str,
            run_id=run_id,
        )
        self._entries.append(entry)
        log.info(
            "PolicyRegistry: registered %s/%s → %s (run_id=%s)",
            echelon_str,
            version,
            path_str,
            run_id,
        )
        return entry

    def get(self, echelon: Union[Echelon, str], version: str) -> PolicyEntry:
        """Return the :class:`PolicyEntry` for *echelon* / *version*.

        Parameters
        ----------
        echelon:
            Echelon name or :class:`Echelon` enum.
        version:
            Version string.

        Returns
        -------
        PolicyEntry

        Raises
        ------
        KeyError
            If no matching entry is found.
        """
        echelon_str = _echelon_str(echelon)
        idx = self._find_index(echelon_str, version)
        if idx is None:
            raise KeyError(
                f"No policy registered for echelon='{echelon_str}' "
                f"version='{version}'."
            )
        return self._entries[idx]

    def remove(self, echelon: Union[Echelon, str], version: str) -> None:
        """Remove an entry from the registry.

        Parameters
        ----------
        echelon:
            Echelon name or :class:`Echelon` enum.
        version:
            Version string.

        Raises
        ------
        KeyError
            If no matching entry is found.
        """
        echelon_str = _echelon_str(echelon)
        idx = self._find_index(echelon_str, version)
        if idx is None:
            raise KeyError(
                f"No policy registered for echelon='{echelon_str}' "
                f"version='{version}'."
            )
        self._entries.pop(idx)
        log.info(
            "PolicyRegistry: removed %s/%s", echelon_str, version
        )

    def list(
        self, echelon: Union[Echelon, str, None] = None
    ) -> List[PolicyEntry]:
        """Return all registered entries, optionally filtered by echelon.

        Parameters
        ----------
        echelon:
            When provided, only entries for this echelon are returned.

        Returns
        -------
        list[PolicyEntry]
            A fresh list; mutating it does not affect the registry.
        """
        if echelon is None:
            return list(self._entries)
        echelon_str = _echelon_str(echelon)
        return [e for e in self._entries if e.echelon == echelon_str]

    # ------------------------------------------------------------------
    # Policy loading
    # ------------------------------------------------------------------

    def load(
        self,
        echelon: Union[Echelon, str],
        version: str,
        device: str = "cpu",
        **mappo_kwargs: Any,
    ) -> Any:
        """Load and return a frozen policy for the given echelon / version.

        The checkpoint format is inferred from the echelon:

        * ``battalion`` → MAPPO ``.pt`` checkpoint loaded via
          :func:`~training.utils.freeze_policy.load_and_freeze_mappo`.
          *mappo_kwargs* (``obs_dim``, ``action_dim``, ``state_dim``,
          ``n_agents``) are forwarded to that function and are **required**
          for battalion policies.
        * ``brigade`` / ``division`` → SB3 PPO ``.zip`` checkpoint loaded
          via :func:`~training.utils.freeze_policy.load_and_freeze_sb3`.

        Parameters
        ----------
        echelon:
            Echelon name or :class:`Echelon` enum.
        version:
            Version string.
        device:
            PyTorch device string (default ``"cpu"``).
        **mappo_kwargs:
            Extra keyword arguments forwarded to
            :func:`~training.utils.freeze_policy.load_and_freeze_mappo` when
            *echelon* is ``"battalion"``.  Must include ``obs_dim``,
            ``action_dim``, ``state_dim``, and optionally ``n_agents``.

        Returns
        -------
        Any
            A frozen :class:`~models.mappo_policy.MAPPOPolicy` (battalion)
            or a frozen SB3 ``PPO`` model (brigade / division).

        Raises
        ------
        KeyError
            If no entry is registered for this echelon+version.
        FileNotFoundError
            If the checkpoint file does not exist.
        """
        entry = self.get(echelon, version)
        echelon_enum = Echelon.from_str(entry.echelon)

        checkpoint_path = Path(entry.path)

        if echelon_enum == Echelon.BATTALION:
            required = ("obs_dim", "action_dim", "state_dim")
            missing = [k for k in required if k not in mappo_kwargs]
            if missing:
                raise ValueError(
                    f"Loading a battalion policy requires keyword arguments: "
                    f"{', '.join(missing)}. "
                    "Pass them as keyword arguments to load()."
                )

        from training.utils.freeze_policy import (
            load_and_freeze_mappo,
            load_and_freeze_sb3,
        )

        if echelon_enum == Echelon.BATTALION:
            return load_and_freeze_mappo(
                checkpoint_path=checkpoint_path,
                device=device,
                **mappo_kwargs,
            )

        # brigade or division → SB3 PPO
        return load_and_freeze_sb3(checkpoint_path=checkpoint_path, device=device)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save(self) -> None:
        """Persist the registry to its JSON manifest file.

        Raises
        ------
        ValueError
            If the registry was created without a file path.
        """
        if self._path is None:
            raise ValueError(
                "Cannot save: this PolicyRegistry was created without a "
                "file path.  Pass a path to the constructor."
            )
        self._path.parent.mkdir(parents=True, exist_ok=True)
        data: dict = {
            "entries": [
                {
                    "echelon": e.echelon,
                    "version": e.version,
                    "path": e.path,
                    "run_id": e.run_id,
                }
                for e in self._entries
            ]
        }
        with open(self._path, "w", encoding="utf-8") as fh:
            json.dump(data, fh, indent=2, sort_keys=True)
        log.info("PolicyRegistry: saved %d entries to %s", len(self._entries), self._path)

    def _load(self) -> None:
        """Load entries from the existing JSON manifest."""
        assert self._path is not None
        with open(self._path, encoding="utf-8") as fh:
            try:
                data = json.load(fh)
            except json.JSONDecodeError as exc:
                raise ValueError(
                    f"PolicyRegistry: failed to parse JSON from "
                    f"'{self._path}': {exc}"
                ) from exc
        raw_entries = data.get("entries", [])
        for item in raw_entries:
            try:
                entry = PolicyEntry(
                    echelon=str(item["echelon"]),
                    version=str(item["version"]),
                    path=str(item["path"]),
                    run_id=item.get("run_id"),
                )
            except (KeyError, TypeError) as exc:
                raise ValueError(
                    f"PolicyRegistry: malformed entry in '{self._path}': "
                    f"{item!r}{exc}"
                ) from exc
            self._entries.append(entry)
        log.info(
            "PolicyRegistry: loaded %d entries from %s",
            len(self._entries),
            self._path,
        )

    def __repr__(self) -> str:  # pragma: no cover
        return (
            f"PolicyRegistry(path={self._path!r}, "
            f"n_entries={len(self._entries)})"
        )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _find_index(self, echelon: str, version: str) -> Optional[int]:
        """Return the index of the matching entry or ``None``."""
        for i, e in enumerate(self._entries):
            if e.echelon == echelon and e.version == version:
                return i
        return None

can_save property

True when the registry has a backing file path.

get(echelon, version)

Return the `PolicyEntry` for *echelon* / *version*.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `echelon` | `Union[Echelon, str]` | Echelon name or `Echelon` enum. | *required* |
| `version` | `str` | Version string. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `PolicyEntry` | The matching entry. |

Raises:

| Type | Description |
| --- | --- |
| `KeyError` | If no matching entry is found. |

Source code in training/policy_registry.py
def get(self, echelon: Union[Echelon, str], version: str) -> PolicyEntry:
    """Return the :class:`PolicyEntry` for *echelon* / *version*.

    Parameters
    ----------
    echelon:
        Echelon name or :class:`Echelon` enum.
    version:
        Version string.

    Returns
    -------
    PolicyEntry

    Raises
    ------
    KeyError
        If no matching entry is found.
    """
    echelon_str = _echelon_str(echelon)
    idx = self._find_index(echelon_str, version)
    if idx is None:
        raise KeyError(
            f"No policy registered for echelon='{echelon_str}' "
            f"version='{version}'."
        )
    return self._entries[idx]

list(echelon=None)

Return all registered entries, optionally filtered by echelon.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `echelon` | `Union[Echelon, str, None]` | When provided, only entries for this echelon are returned. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `list[PolicyEntry]` | A fresh list; mutating it does not affect the registry. |

Source code in training/policy_registry.py
def list(
    self, echelon: Union[Echelon, str, None] = None
) -> List[PolicyEntry]:
    """Return all registered entries, optionally filtered by echelon.

    Parameters
    ----------
    echelon:
        When provided, only entries for this echelon are returned.

    Returns
    -------
    list[PolicyEntry]
        A fresh list; mutating it does not affect the registry.
    """
    if echelon is None:
        return list(self._entries)
    echelon_str = _echelon_str(echelon)
    return [e for e in self._entries if e.echelon == echelon_str]

load(echelon, version, device='cpu', **mappo_kwargs)

Load and return a frozen policy for the given echelon / version.

The checkpoint format is inferred from the echelon:

* `battalion` → MAPPO `.pt` checkpoint loaded via `load_and_freeze_mappo`. *mappo_kwargs* (`obs_dim`, `action_dim`, `state_dim`, and optionally `n_agents`) are forwarded to that function; the first three are required for battalion policies.
* `brigade` / `division` → SB3 PPO `.zip` checkpoint loaded via `load_and_freeze_sb3`.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `echelon` | `Union[Echelon, str]` | Echelon name or `Echelon` enum. | *required* |
| `version` | `str` | Version string. | *required* |
| `device` | `str` | PyTorch device string. | `'cpu'` |
| `**mappo_kwargs` | `Any` | Extra keyword arguments forwarded to `load_and_freeze_mappo` when *echelon* is `"battalion"`. Must include `obs_dim`, `action_dim`, `state_dim`, and optionally `n_agents`. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `Any` | A frozen `MAPPOPolicy` (battalion) or a frozen SB3 `PPO` model (brigade / division). |

Raises:

| Type | Description |
| --- | --- |
| `KeyError` | If no entry is registered for this echelon+version. |
| `FileNotFoundError` | If the checkpoint file does not exist. |

Source code in training/policy_registry.py
def load(
    self,
    echelon: Union[Echelon, str],
    version: str,
    device: str = "cpu",
    **mappo_kwargs: Any,
) -> Any:
    """Load and return a frozen policy for the given echelon / version.

    The checkpoint format is inferred from the echelon:

    * ``battalion`` → MAPPO ``.pt`` checkpoint loaded via
      :func:`~training.utils.freeze_policy.load_and_freeze_mappo`.
      *mappo_kwargs* (``obs_dim``, ``action_dim``, ``state_dim``, and
      optionally ``n_agents``) are forwarded to that function; the
      first three are **required** for battalion policies.
    * ``brigade`` / ``division`` → SB3 PPO ``.zip`` checkpoint loaded
      via :func:`~training.utils.freeze_policy.load_and_freeze_sb3`.

    Parameters
    ----------
    echelon:
        Echelon name or :class:`Echelon` enum.
    version:
        Version string.
    device:
        PyTorch device string (default ``"cpu"``).
    **mappo_kwargs:
        Extra keyword arguments forwarded to
        :func:`~training.utils.freeze_policy.load_and_freeze_mappo` when
        *echelon* is ``"battalion"``.  Must include ``obs_dim``,
        ``action_dim``, ``state_dim``, and optionally ``n_agents``.

    Returns
    -------
    Any
        A frozen :class:`~models.mappo_policy.MAPPOPolicy` (battalion)
        or a frozen SB3 ``PPO`` model (brigade / division).

    Raises
    ------
    KeyError
        If no entry is registered for this echelon+version.
    FileNotFoundError
        If the checkpoint file does not exist.
    """
    entry = self.get(echelon, version)
    echelon_enum = Echelon.from_str(entry.echelon)

    checkpoint_path = Path(entry.path)

    if echelon_enum == Echelon.BATTALION:
        required = ("obs_dim", "action_dim", "state_dim")
        missing = [k for k in required if k not in mappo_kwargs]
        if missing:
            raise ValueError(
                f"Loading a battalion policy requires keyword arguments: "
                f"{', '.join(missing)}. "
                "Pass them as keyword arguments to load()."
            )

    from training.utils.freeze_policy import (
        load_and_freeze_mappo,
        load_and_freeze_sb3,
    )

    if echelon_enum == Echelon.BATTALION:
        return load_and_freeze_mappo(
            checkpoint_path=checkpoint_path,
            device=device,
            **mappo_kwargs,
        )

    # brigade or division → SB3 PPO
    return load_and_freeze_sb3(checkpoint_path=checkpoint_path, device=device)

register(echelon, version, path, run_id=None, overwrite=False)

Register a new policy checkpoint.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `echelon` | `Union[Echelon, str]` | The HRL echelon: `"battalion"`, `"brigade"`, or `"division"` (or the `Echelon` enum value). | *required* |
| `version` | `str` | Caller-assigned version string, e.g. `"v2_final"`. | *required* |
| `path` | `Union[str, Path]` | File-system path to the checkpoint. | *required* |
| `run_id` | `Optional[str]` | Optional W&B run ID linked to this checkpoint. | `None` |
| `overwrite` | `bool` | When `True`, an existing entry with the same echelon+version is replaced. When `False` (default), a `ValueError` is raised if the entry already exists. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `PolicyEntry` | The newly created entry. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If an entry with the same echelon+version already exists and *overwrite* is `False`, or if *echelon* is invalid. |

Source code in training/policy_registry.py
def register(
    self,
    echelon: Union[Echelon, str],
    version: str,
    path: Union[str, Path],
    run_id: Optional[str] = None,
    overwrite: bool = False,
) -> PolicyEntry:
    """Register a new policy checkpoint.

    Parameters
    ----------
    echelon:
        The HRL echelon: ``"battalion"``, ``"brigade"``, or
        ``"division"`` (or the :class:`Echelon` enum value).
    version:
        Caller-assigned version string, e.g. ``"v2_final"``.
    path:
        File-system path to the checkpoint.
    run_id:
        Optional W&B run ID linked to this checkpoint.
    overwrite:
        When ``True`` an existing entry with the same echelon+version
        is replaced.  When ``False`` (default) a :exc:`ValueError` is
        raised if the entry already exists.

    Returns
    -------
    PolicyEntry
        The newly created entry.

    Raises
    ------
    ValueError
        If an entry with the same echelon+version already exists and
        *overwrite* is ``False``, or if *echelon* is invalid.
    """
    echelon_str = _echelon_str(echelon)
    path_str = str(path)

    existing_idx = self._find_index(echelon_str, version)
    if existing_idx is not None:
        if not overwrite:
            raise ValueError(
                f"Entry already exists for echelon='{echelon_str}' "
                f"version='{version}'. Pass overwrite=True to replace it."
            )
        self._entries.pop(existing_idx)

    entry = PolicyEntry(
        echelon=echelon_str,
        version=version,
        path=path_str,
        run_id=run_id,
    )
    self._entries.append(entry)
    log.info(
        "PolicyRegistry: registered %s/%s%s (run_id=%s)",
        echelon_str,
        version,
        path_str,
        run_id,
    )
    return entry

remove(echelon, version)

Remove an entry from the registry.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `echelon` | `Union[Echelon, str]` | Echelon name or `Echelon` enum. | *required* |
| `version` | `str` | Version string. | *required* |

Raises:

| Type | Description |
| --- | --- |
| `KeyError` | If no matching entry is found. |

Source code in training/policy_registry.py
def remove(self, echelon: Union[Echelon, str], version: str) -> None:
    """Remove an entry from the registry.

    Parameters
    ----------
    echelon:
        Echelon name or :class:`Echelon` enum.
    version:
        Version string.

    Raises
    ------
    KeyError
        If no matching entry is found.
    """
    echelon_str = _echelon_str(echelon)
    idx = self._find_index(echelon_str, version)
    if idx is None:
        raise KeyError(
            f"No policy registered for echelon='{echelon_str}' "
            f"version='{version}'."
        )
    self._entries.pop(idx)
    log.info(
        "PolicyRegistry: removed %s/%s", echelon_str, version
    )

save()

Persist the registry to its JSON manifest file.

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the registry was created without a file path. |

Source code in training/policy_registry.py
def save(self) -> None:
    """Persist the registry to its JSON manifest file.

    Raises
    ------
    ValueError
        If the registry was created without a file path.
    """
    if self._path is None:
        raise ValueError(
            "Cannot save: this PolicyRegistry was created without a "
            "file path.  Pass a path to the constructor."
        )
    self._path.parent.mkdir(parents=True, exist_ok=True)
    data: dict = {
        "entries": [
            {
                "echelon": e.echelon,
                "version": e.version,
                "path": e.path,
                "run_id": e.run_id,
            }
            for e in self._entries
        ]
    }
    with open(self._path, "w", encoding="utf-8") as fh:
        json.dump(data, fh, indent=2, sort_keys=True)
    log.info("PolicyRegistry: saved %d entries to %s", len(self._entries), self._path)

training.policy_registry.Echelon

Bases: str, Enum

Supported HRL echelon levels.

Source code in training/policy_registry.py
class Echelon(str, Enum):
    """Supported HRL echelon levels."""

    BATTALION = "battalion"
    BRIGADE = "brigade"
    DIVISION = "division"

    @classmethod
    def from_str(cls, value: str) -> "Echelon":
        """Case-insensitive lookup from string.

        Parameters
        ----------
        value:
            Echelon name, e.g. ``"battalion"``, ``"Brigade"``, or an
            :class:`Echelon` enum value.

        Returns
        -------
        Echelon

        Raises
        ------
        ValueError
            If *value* is not a valid echelon name.
        """
        try:
            return cls(value.lower())
        except ValueError:
            valid = ", ".join(e.value for e in cls)
            raise ValueError(
                f"Invalid echelon '{value}'. Must be one of: {valid}."
            )

from_str(value) classmethod

Case-insensitive lookup from string.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `value` | `str` | Echelon name, e.g. `"battalion"`, `"Brigade"`, or an `Echelon` enum value. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Echelon` | The matching enum member. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If *value* is not a valid echelon name. |

Source code in training/policy_registry.py
@classmethod
def from_str(cls, value: str) -> "Echelon":
    """Case-insensitive lookup from string.

    Parameters
    ----------
    value:
        Echelon name, e.g. ``"battalion"``, ``"Brigade"``, or an
        :class:`Echelon` enum value.

    Returns
    -------
    Echelon

    Raises
    ------
    ValueError
        If *value* is not a valid echelon name.
    """
    try:
        return cls(value.lower())
    except ValueError:
        valid = ", ".join(e.value for e in cls)
        raise ValueError(
            f"Invalid echelon '{value}'. Must be one of: {valid}."
        )
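
Because `Echelon` subclasses `str`, its members compare equal to their plain string values, and `from_str` accepts any casing. A self-contained sketch mirroring the source above:

```python
from enum import Enum

class Echelon(str, Enum):
    """Mirrors training/policy_registry.py for illustration."""
    BATTALION = "battalion"
    BRIGADE = "brigade"
    DIVISION = "division"

    @classmethod
    def from_str(cls, value: str) -> "Echelon":
        # Case-insensitive lookup; re-raises with the list of valid names.
        try:
            return cls(value.lower())
        except ValueError:
            valid = ", ".join(e.value for e in cls)
            raise ValueError(f"Invalid echelon '{value}'. Must be one of: {valid}.")

brigade = Echelon.from_str("Brigade")       # case-insensitive lookup
is_str_equal = Echelon.BATTALION == "battalion"  # str subclassing at work
```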

training.policy_registry.PolicyEntry

Bases: NamedTuple

Metadata record for a single registered policy checkpoint.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `echelon` | `str` | The HRL echelon this checkpoint belongs to. |
| `version` | `str` | Caller-assigned version string, e.g. `"v2_final"` or `"step_500k"`. |
| `path` | `str` | File-system path to the checkpoint file. |
| `run_id` | `Optional[str]` | Optional W&B run ID associated with this checkpoint. |

Source code in training/policy_registry.py
class PolicyEntry(NamedTuple):
    """Metadata record for a single registered policy checkpoint.

    Attributes
    ----------
    echelon:
        The HRL echelon this checkpoint belongs to.
    version:
        Caller-assigned version string, e.g. ``"v2_final"`` or ``"step_500k"``.
    path:
        File-system path to the checkpoint file.
    run_id:
        Optional W&B run ID associated with this checkpoint.
    """

    echelon: str
    version: str
    path: str
    run_id: Optional[str]

    def __str__(self) -> str:
        run = self.run_id or "—"
        return (
            f"{self.echelon:<10s}  {self.version:<20s}  "
            f"{self.path:<50s}  run_id={run}"
        )
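
Each field in `__str__` is left-padded to a fixed column width, so printing several entries in a loop yields an aligned table. Reproducing the NamedTuple from the source above:

```python
from typing import NamedTuple, Optional

class PolicyEntry(NamedTuple):
    """Mirrors training/policy_registry.py for illustration."""
    echelon: str
    version: str
    path: str
    run_id: Optional[str]

    def __str__(self) -> str:
        run = self.run_id or "—"  # dash placeholder when no W&B run is linked
        return (
            f"{self.echelon:<10s}  {self.version:<20s}  "
            f"{self.path:<50s}  run_id={run}"
        )

entry = PolicyEntry("battalion", "v2_final", "checkpoints/battalion_v2.pt", None)
line = str(entry)  # columns padded to 10 / 20 / 50 characters
```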

Elo Ratings

training.elo.EloRegistry

Persistent Elo rating registry backed by a JSON file.

Stores per-agent ratings and game counts. Scripted baseline opponents (`"scripted_l1"` … `"scripted_l5"`, `"random"`) have fixed seed ratings defined in `BASELINE_RATINGS` and are never modified.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `path` | `Union[str, Path, None]` | Path to the JSON file used for persistence. The parent directory is created automatically on `save`. Pass `None` to create an in-memory registry that cannot be saved to disk. | `'checkpoints/elo_registry.json'` |
Source code in training/elo.py
class EloRegistry:
    """Persistent Elo rating registry backed by a JSON file.

    Stores per-agent ratings and game counts.  Scripted baseline opponents
    (``"scripted_l1"`` … ``"scripted_l5"``, ``"random"``) have fixed seed
    ratings defined in :data:`BASELINE_RATINGS` and are **never modified**.

    Parameters
    ----------
    path:
        Path to the JSON file used for persistence.  The parent directory
        is created automatically on :meth:`save`.  Pass ``None`` to create
        an in-memory registry that cannot be saved to disk.
    """

    def __init__(
        self,
        path: Union[str, Path, None] = "checkpoints/elo_registry.json",
    ) -> None:
        self._path: Path | None = Path(path) if path is not None else None
        self._ratings: dict[str, float] = {}
        self._game_counts: dict[str, int] = {}
        if self._path is not None and self._path.exists():
            self._load()

    # ------------------------------------------------------------------
    # Read helpers
    # ------------------------------------------------------------------

    def get_rating(self, name: str) -> float:
        """Return the current Elo rating for *name*.

        Falls back to :data:`BASELINE_RATINGS` for scripted opponents, then
        to :data:`DEFAULT_RATING` for completely unknown agents.
        """
        if name in self._ratings:
            return self._ratings[name]
        return BASELINE_RATINGS.get(name, DEFAULT_RATING)

    def get_game_count(self, name: str) -> int:
        """Return the total number of rated games played by *name*."""
        return self._game_counts.get(name, 0)

    def all_ratings(self) -> dict[str, float]:
        """Return a copy of all *stored* ratings (excludes pure baselines)."""
        return dict(self._ratings)

    @property
    def can_save(self) -> bool:
        """``True`` when the registry has a backing file and can be persisted."""
        return self._path is not None

    # ------------------------------------------------------------------
    # Update
    # ------------------------------------------------------------------

    def update(
        self,
        agent: str,
        opponent: str,
        outcome: float,
        n_games: int = 1,
    ) -> float:
        """Update the Elo rating of *agent* after a batch of *n_games*.

        The *outcome* is the average score per game:

        * ``1.0`` — all wins
        * ``0.5`` — all draws
        * ``0.0`` — all losses

        Scripted baseline opponents' ratings are **never modified**.

        Parameters
        ----------
        agent:
            Identifier for the agent whose rating is updated.
        opponent:
            Identifier of the opponent played against.
        outcome:
            Average per-game score in ``[0, 1]``.
        n_games:
            Number of games in this batch (used to advance the game-count
            counter; the K-factor is evaluated at the *pre-update* count).

        Returns
        -------
        float
            Elo rating delta for *agent* (positive = rating increased).

        Raises
        ------
        ValueError
            If *outcome* is outside ``[0, 1]``, *n_games* < 1, or *agent*
            is a key in :data:`BASELINE_RATINGS` (baselines are immutable).
        """
        if not 0.0 <= outcome <= 1.0:
            raise ValueError(
                f"outcome must be in [0, 1], got {outcome!r}."
            )
        if n_games < 1:
            raise ValueError(
                f"n_games must be >= 1, got {n_games!r}."
            )
        if agent in BASELINE_RATINGS:
            raise ValueError(
                f"Cannot update rating for baseline opponent '{agent}'. "
                "Baseline ratings are fixed and cannot be modified."
            )

        r_agent = self.get_rating(agent)
        r_opponent = self.get_rating(opponent)
        n_so_far = self.get_game_count(agent)
        k = k_factor(n_so_far)

        expected = expected_score(r_agent, r_opponent)
        delta = k * (outcome - expected)

        self._ratings[agent] = r_agent + delta
        self._game_counts[agent] = n_so_far + n_games
        return delta

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save(self) -> None:
        """Persist current ratings and game counts to the JSON file.

        Raises
        ------
        ValueError
            If the registry was created without a file path (``path=None``).
        """
        if self._path is None:
            raise ValueError(
                "Cannot save: this EloRegistry was created without a file path. "
                "Pass a path to the constructor to enable persistence."
            )
        self._path.parent.mkdir(parents=True, exist_ok=True)
        data: dict = {
            "ratings": self._ratings,
            "game_counts": self._game_counts,
        }
        with open(self._path, "w", encoding="utf-8") as fh:
            json.dump(data, fh, indent=2, sort_keys=True)

    def _load(self) -> None:
        """Load ratings from the existing JSON file."""
        with open(self._path, encoding="utf-8") as fh:  # type: ignore[arg-type]
            try:
                data = json.load(fh)
            except json.JSONDecodeError as exc:
                raise ValueError(
                    f"EloRegistry: failed to parse JSON from '{self._path}': {exc}"
                ) from exc
        try:
            self._ratings = {str(k): float(v) for k, v in data.get("ratings", {}).items()}
            self._game_counts = {
                str(k): int(v) for k, v in data.get("game_counts", {}).items()
            }
        except (TypeError, ValueError) as exc:
            raise ValueError(
                f"EloRegistry: invalid data types in '{self._path}': {exc}"
            ) from exc

    def __repr__(self) -> str:  # pragma: no cover
        return (
            f"EloRegistry(path={self._path!r}, "
            f"n_agents={len(self._ratings)})"
        )

can_save property

True when the registry has a backing file and can be persisted.

all_ratings()

Return a copy of all stored ratings (excludes pure baselines).

Source code in training/elo.py
def all_ratings(self) -> dict[str, float]:
    """Return a copy of all *stored* ratings (excludes pure baselines)."""
    return dict(self._ratings)

get_game_count(name)

Return the total number of rated games played by name.

Source code in training/elo.py
def get_game_count(self, name: str) -> int:
    """Return the total number of rated games played by *name*."""
    return self._game_counts.get(name, 0)

get_rating(name)

Return the current Elo rating for name.

Falls back to `BASELINE_RATINGS` for scripted opponents, then to `DEFAULT_RATING` for completely unknown agents.

Source code in training/elo.py
def get_rating(self, name: str) -> float:
    """Return the current Elo rating for *name*.

    Falls back to :data:`BASELINE_RATINGS` for scripted opponents, then
    to :data:`DEFAULT_RATING` for completely unknown agents.
    """
    if name in self._ratings:
        return self._ratings[name]
    return BASELINE_RATINGS.get(name, DEFAULT_RATING)

save()

Persist current ratings and game counts to the JSON file.

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the registry was created without a file path (`path=None`). |

Source code in training/elo.py
def save(self) -> None:
    """Persist current ratings and game counts to the JSON file.

    Raises
    ------
    ValueError
        If the registry was created without a file path (``path=None``).
    """
    if self._path is None:
        raise ValueError(
            "Cannot save: this EloRegistry was created without a file path. "
            "Pass a path to the constructor to enable persistence."
        )
    self._path.parent.mkdir(parents=True, exist_ok=True)
    data: dict = {
        "ratings": self._ratings,
        "game_counts": self._game_counts,
    }
    with open(self._path, "w", encoding="utf-8") as fh:
        json.dump(data, fh, indent=2, sort_keys=True)

update(agent, opponent, outcome, n_games=1)

Update the Elo rating of *agent* after a batch of *n_games*.

The *outcome* is the average score per game:

* `1.0` — all wins
* `0.5` — all draws
* `0.0` — all losses

Scripted baseline opponents' ratings are never modified.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `agent` | `str` | Identifier for the agent whose rating is updated. | *required* |
| `opponent` | `str` | Identifier of the opponent played against. | *required* |
| `outcome` | `float` | Average per-game score in `[0, 1]`. | *required* |
| `n_games` | `int` | Number of games in this batch (used to advance the game-count counter; the K-factor is evaluated at the pre-update count). | `1` |

Returns:

| Type | Description |
| --- | --- |
| `float` | Elo rating delta for *agent* (positive = rating increased). |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If *outcome* is outside `[0, 1]`, *n_games* < 1, or *agent* is a key in `BASELINE_RATINGS` (baselines are immutable). |

Source code in training/elo.py
def update(
    self,
    agent: str,
    opponent: str,
    outcome: float,
    n_games: int = 1,
) -> float:
    """Update the Elo rating of *agent* after a batch of *n_games*.

    The *outcome* is the average score per game:

    * ``1.0`` — all wins
    * ``0.5`` — all draws
    * ``0.0`` — all losses

    Scripted baseline opponents' ratings are **never modified**.

    Parameters
    ----------
    agent:
        Identifier for the agent whose rating is updated.
    opponent:
        Identifier of the opponent played against.
    outcome:
        Average per-game score in ``[0, 1]``.
    n_games:
        Number of games in this batch (used to advance the game-count
        counter; the K-factor is evaluated at the *pre-update* count).

    Returns
    -------
    float
        Elo rating delta for *agent* (positive = rating increased).

    Raises
    ------
    ValueError
        If *outcome* is outside ``[0, 1]``, *n_games* < 1, or *agent*
        is a key in :data:`BASELINE_RATINGS` (baselines are immutable).
    """
    if not 0.0 <= outcome <= 1.0:
        raise ValueError(
            f"outcome must be in [0, 1], got {outcome!r}."
        )
    if n_games < 1:
        raise ValueError(
            f"n_games must be >= 1, got {n_games!r}."
        )
    if agent in BASELINE_RATINGS:
        raise ValueError(
            f"Cannot update rating for baseline opponent '{agent}'. "
            "Baseline ratings are fixed and cannot be modified."
        )

    r_agent = self.get_rating(agent)
    r_opponent = self.get_rating(opponent)
    n_so_far = self.get_game_count(agent)
    k = k_factor(n_so_far)

    expected = expected_score(r_agent, r_opponent)
    delta = k * (outcome - expected)

    self._ratings[agent] = r_agent + delta
    self._game_counts[agent] = n_so_far + n_games
    return delta
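
`k_factor` and `expected_score` are not shown on this page. Assuming the standard Elo expected-score formula (a hedged sketch; the project's actual K-factor schedule lives in `training/elo.py`), one `update` call amounts to:

```python
K = 32.0  # illustrative constant; the real k_factor() varies with games played

def expected_score(r_a: float, r_b: float) -> float:
    """Standard Elo expected score of player A against player B."""
    return 1.0 / (1.0 + 10.0 ** ((r_b - r_a) / 400.0))

r_agent, r_opponent = 1200.0, 1000.0
outcome = 0.6  # average per-game score over the batch, in [0, 1]

expected = expected_score(r_agent, r_opponent)  # ~0.76: the agent is favoured
delta = K * (outcome - expected)                # negative: below expectation
r_agent += delta                                # rating drops slightly
```

Note that an agent can win more than half its games (`outcome=0.6`) and still lose rating when it was expected to do even better.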

training.elo.TeamEloRegistry

Bases: EloRegistry

Elo registry specialised for multi-agent team ratings.

Extends `EloRegistry` with team-specific baseline ratings (`TEAM_BASELINE_RATINGS`). All base-class methods work identically; team baselines are protected against modification just like the single-agent `BASELINE_RATINGS`.

Typical usage::

from training.elo import TeamEloRegistry

registry = TeamEloRegistry(path="checkpoints/team_elo.json")

# After a self-play evaluation round:
delta = registry.update(
    agent="mappo_blue",
    opponent="self_play_pool",
    outcome=0.6,
    n_games=20,
)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `path` | `Union[str, Path, None]` | Path to the JSON persistence file. Pass `None` for an in-memory registry that cannot be saved to disk. | `'checkpoints/elo_registry.json'` |
Source code in training/elo.py
class TeamEloRegistry(EloRegistry):
    """Elo registry specialised for multi-agent team ratings.

    Extends :class:`EloRegistry` with team-specific baseline ratings
    (:data:`TEAM_BASELINE_RATINGS`).  All base-class methods work
    identically; team baselines are protected against modification just
    like the single-agent :data:`BASELINE_RATINGS`.

    Typical usage::

        from training.elo import TeamEloRegistry

        registry = TeamEloRegistry(path="checkpoints/team_elo.json")

        # After a self-play evaluation round:
        delta = registry.update(
            agent="mappo_blue",
            opponent="self_play_pool",
            outcome=0.6,
            n_games=20,
        )

    Parameters
    ----------
    path:
        Path to the JSON persistence file.  Pass ``None`` for an
        in-memory registry that cannot be saved to disk.
    """

    def get_rating(self, name: str) -> float:
        """Return Elo rating for *name*, checking team baselines.

        Look-up order:

        1. Stored ratings (updated agents).
        2. :data:`TEAM_BASELINE_RATINGS` (multi-agent team baselines).
        3. :data:`BASELINE_RATINGS` (single-agent scripted baselines).
        4. :data:`DEFAULT_RATING` fallback.
        """
        if name in self._ratings:
            return self._ratings[name]
        if name in TEAM_BASELINE_RATINGS:
            return TEAM_BASELINE_RATINGS[name]
        return BASELINE_RATINGS.get(name, DEFAULT_RATING)

    def update(
        self,
        agent: str,
        opponent: str,
        outcome: float,
        n_games: int = 1,
    ) -> float:
        """Update team Elo, protecting both BASELINE_RATINGS and TEAM_BASELINE_RATINGS.

        See :meth:`EloRegistry.update` for parameter and return-value
        documentation.

        Raises
        ------
        ValueError
            If *agent* is in :data:`TEAM_BASELINE_RATINGS` (in addition to
            the parent-class guard against :data:`BASELINE_RATINGS`).
        """
        if agent in TEAM_BASELINE_RATINGS:
            raise ValueError(
                f"Cannot update rating for team baseline '{agent}'. "
                "Baseline ratings are fixed and cannot be modified."
            )
        return super().update(agent, opponent, outcome, n_games)

    def __repr__(self) -> str:  # pragma: no cover
        return (
            f"TeamEloRegistry(path={self._path!r}, "
            f"n_agents={len(self._ratings)})"
        )

get_rating(name)

Return Elo rating for name, checking team baselines.

Look-up order:

  1. Stored ratings (updated agents).
  2. :data:TEAM_BASELINE_RATINGS (multi-agent team baselines).
  3. :data:BASELINE_RATINGS (single-agent scripted baselines).
  4. :data:DEFAULT_RATING fallback.
Source code in training/elo.py
def get_rating(self, name: str) -> float:
    """Return Elo rating for *name*, checking team baselines.

    Look-up order:

    1. Stored ratings (updated agents).
    2. :data:`TEAM_BASELINE_RATINGS` (multi-agent team baselines).
    3. :data:`BASELINE_RATINGS` (single-agent scripted baselines).
    4. :data:`DEFAULT_RATING` fallback.
    """
    if name in self._ratings:
        return self._ratings[name]
    if name in TEAM_BASELINE_RATINGS:
        return TEAM_BASELINE_RATINGS[name]
    return BASELINE_RATINGS.get(name, DEFAULT_RATING)
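The four-step lookup order can be sketched standalone; the baseline tables and rating values below are hypothetical stand-ins, not the project's real rating data::

```python
# Hypothetical baseline tables illustrating get_rating's lookup order.
TEAM_BASELINE_RATINGS = {"scripted_team": 1300.0}
BASELINE_RATINGS = {"scripted_bot": 1100.0}
DEFAULT_RATING = 1200.0

def get_rating(stored: dict, name: str) -> float:
    if name in stored:                                 # 1. stored ratings
        return stored[name]
    if name in TEAM_BASELINE_RATINGS:                  # 2. team baselines
        return TEAM_BASELINE_RATINGS[name]
    return BASELINE_RATINGS.get(name, DEFAULT_RATING)  # 3. + 4. fallback

stored = {"mappo_blue": 1425.0}
assert get_rating(stored, "mappo_blue") == 1425.0      # stored rating wins
assert get_rating(stored, "scripted_team") == 1300.0   # team baseline
assert get_rating(stored, "unknown") == 1200.0         # default fallback
```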

update(agent, opponent, outcome, n_games=1)

Update team Elo, protecting both BASELINE_RATINGS and TEAM_BASELINE_RATINGS.

See :meth:EloRegistry.update for parameter and return-value documentation.

Raises:

ValueError: If agent is in :data:TEAM_BASELINE_RATINGS (in addition to the parent-class guard against :data:BASELINE_RATINGS).

Source code in training/elo.py
def update(
    self,
    agent: str,
    opponent: str,
    outcome: float,
    n_games: int = 1,
) -> float:
    """Update team Elo, protecting both BASELINE_RATINGS and TEAM_BASELINE_RATINGS.

    See :meth:`EloRegistry.update` for parameter and return-value
    documentation.

    Raises
    ------
    ValueError
        If *agent* is in :data:`TEAM_BASELINE_RATINGS` (in addition to
        the parent-class guard against :data:`BASELINE_RATINGS`).
    """
    if agent in TEAM_BASELINE_RATINGS:
        raise ValueError(
            f"Cannot update rating for team baseline '{agent}'. "
            "Baseline ratings are fixed and cannot be modified."
        )
    return super().update(agent, opponent, outcome, n_games)
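The parent-class update itself is not shown here, but the delta it returns presumably follows the standard Elo formula. A sketch under that assumption (the K factor of 32 is also an assumption, not taken from the source)::

```python
# Standard Elo update: expected score from the rating gap, then a
# K-scaled correction toward the observed outcome.
def elo_delta(r_agent: float, r_opponent: float, outcome: float,
              n_games: int = 1, k: float = 32.0) -> float:
    expected = 1.0 / (1.0 + 10.0 ** ((r_opponent - r_agent) / 400.0))
    return k * n_games * (outcome - expected)

# Equal ratings and a 0.6 average outcome over 20 games give a
# positive delta of k * 20 * (0.6 - 0.5) = 64.
delta = elo_delta(1200.0, 1200.0, outcome=0.6, n_games=20)
```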

Artifacts

training.artifacts.CheckpointManifest

Append-only JSONL checkpoint manifest for local artifact indexing.

Source code in training/artifacts.py
class CheckpointManifest:
    """Append-only JSONL checkpoint manifest for local artifact indexing."""

    def __init__(self, path: Path) -> None:
        self.path = path

    def _read_rows(self) -> list[dict]:
        if not self.path.exists():
            return []
        rows: list[dict] = []
        for raw_line in self.path.read_text(encoding="utf-8").splitlines():
            line = raw_line.strip()
            if not line:
                continue
            try:
                payload = json.loads(line)
            except json.JSONDecodeError:
                continue
            if isinstance(payload, dict):
                rows.append(payload)
        return rows

    def known_paths(self) -> set[str]:
        """Return all path strings already present in the manifest."""
        return {str(row.get("path", "")) for row in self._read_rows()}

    def has_entry(
        self,
        artifact_path: Path | str,
        *,
        artifact_type: str,
        step: Optional[int],
    ) -> bool:
        """Return whether an identical artifact event is already present."""
        path_value = str(artifact_path)
        for row in self._read_rows():
            if str(row.get("path", "")) != path_value:
                continue
            if str(row.get("type", "")) != str(artifact_type):
                continue
            row_step = row.get("step")
            if row_step == step:
                return True
        return False

    def latest_entry_for_path(self, artifact_path: Path | str) -> Optional[dict]:
        """Return the latest manifest row for a given path, if any."""
        path_value = str(artifact_path)
        matches = [
            row for row in self._read_rows() if str(row.get("path", "")) == path_value
        ]
        if not matches:
            return None

        def _sort_key(row: dict) -> tuple[int, int]:
            step = row.get("step")
            timestamp = row.get("timestamp")
            return (
                int(step) if isinstance(step, int) else -1,
                int(timestamp) if isinstance(timestamp, int) else -1,
            )

        return max(matches, key=_sort_key)

    def append(self, row: dict) -> None:
        """Append a single JSON object row to manifest storage."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        with self.path.open("a", encoding="utf-8") as fh:
            fh.write(json.dumps(row, sort_keys=True) + "\n")

    def register(
        self,
        artifact_path: Path,
        *,
        artifact_type: str,
        seed: int,
        curriculum_level: int,
        run_id: Optional[str],
        config_hash: str,
        step: Optional[int],
    ) -> bool:
        """Register one artifact path if it is not already indexed."""
        path_value = str(artifact_path)
        if self.has_entry(path_value, artifact_type=artifact_type, step=step):
            return False
        self.append(
            {
                "timestamp": int(time.time()),
                "type": str(artifact_type),
                "path": path_value,
                "seed": int(seed),
                "curriculum_level": int(curriculum_level),
                "run_id": run_id or None,
                "config_hash": str(config_hash),
                "step": int(step) if step is not None else None,
            }
        )
        return True

    def prune_periodic(
        self,
        checkpoint_dir: Path,
        prefix: str,
        keep_last: int,
    ) -> list[Path]:
        """Delete old periodic checkpoints on disk, keeping the *keep_last* newest.

        Only files that are both present in the manifest *and* exist on disk are
        considered.  The ``keep_last`` most recently registered rows (by step,
        then timestamp) are retained; all older ones are deleted.

        Returns the list of paths that were deleted.
        """
        rows = self._read_rows()
        # Collect unique (step, path) pairs for the given prefix & type.
        candidates: list[tuple[int, Path]] = []
        seen: set[str] = set()
        for row in rows:
            if row.get("type") != "periodic":
                continue
            path_str = str(row.get("path", ""))
            if not path_str or path_str in seen:
                continue
            candidate = Path(path_str)
            if not candidate.is_absolute():
                candidate = checkpoint_dir / candidate
            if not candidate.name.startswith(prefix + "_"):
                continue
            if not candidate.exists():
                continue
            step_value = row.get("step")
            sort_step = int(step_value) if isinstance(step_value, int) else -1
            candidates.append((sort_step, candidate))
            seen.add(path_str)

        # Sort descending — highest step first.
        candidates.sort(key=lambda t: t[0], reverse=True)
        to_delete = candidates[keep_last:]
        deleted: list[Path] = []
        for _, p in to_delete:
            try:
                p.unlink(missing_ok=True)
                deleted.append(p)
            except OSError:
                pass
        return deleted

    def prune_self_play_snapshots(
        self,
        pool_dir: Path,
        keep_last: int,
    ) -> list[Path]:
        """Delete old self-play snapshots on disk, keeping the *keep_last* newest.

        Returns the list of paths that were deleted.
        """
        rows = self._read_rows()
        candidates: list[tuple[int, Path]] = []
        seen: set[str] = set()
        for row in rows:
            if row.get("type") != "self_play_snapshot":
                continue
            path_str = str(row.get("path", ""))
            if not path_str or path_str in seen:
                continue
            candidate = Path(path_str)
            if not candidate.is_absolute():
                candidate = pool_dir / candidate
            if not candidate.exists():
                continue
            step_value = row.get("step")
            sort_step = int(step_value) if isinstance(step_value, int) else -1
            candidates.append((sort_step, candidate))
            seen.add(path_str)

        candidates.sort(key=lambda t: t[0], reverse=True)
        to_delete = candidates[keep_last:]
        deleted: list[Path] = []
        for _, p in to_delete:
            try:
                p.unlink(missing_ok=True)
                deleted.append(p)
            except OSError:
                pass
        return deleted

    def latest_periodic(self, checkpoint_dir: Path, prefix: str) -> Optional[Path]:
        """Return latest periodic checkpoint from manifest, if available."""
        rows = self._read_rows()
        best_step = -1
        best_path: Optional[Path] = None
        for row in rows:
            if row.get("type") != "periodic":
                continue
            path_value = str(row.get("path", ""))
            if not path_value:
                continue
            candidate = Path(path_value)
            if not candidate.is_absolute():
                candidate = checkpoint_dir / candidate
            if not candidate.name.startswith(prefix + "_"):
                continue
            step_value = row.get("step")
            if not isinstance(step_value, int):
                continue
            if not candidate.exists():
                continue
            if step_value > best_step:
                best_step = step_value
                best_path = candidate
        return best_path
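The append/read cycle above boils down to one JSON object per line. A minimal standalone round trip of the same JSONL shape (the file location and row contents here are illustrative)::

```python
import json
import tempfile
from pathlib import Path

# Write two manifest rows append-only, one JSON object per line.
path = Path(tempfile.mkdtemp()) / "manifest.jsonl"
rows = [
    {"type": "periodic", "path": "ppo_battalion_100000_steps.zip", "step": 100000},
    {"type": "periodic", "path": "ppo_battalion_200000_steps.zip", "step": 200000},
]
with path.open("a", encoding="utf-8") as fh:
    for row in rows:
        fh.write(json.dumps(row, sort_keys=True) + "\n")

# Re-read, skipping blank lines, exactly as _read_rows does.
loaded = [json.loads(line) for line in path.read_text(encoding="utf-8").splitlines() if line.strip()]
known = {r["path"] for r in loaded}
```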

append(row)

Append a single JSON object row to manifest storage.

Source code in training/artifacts.py
def append(self, row: dict) -> None:
    """Append a single JSON object row to manifest storage."""
    self.path.parent.mkdir(parents=True, exist_ok=True)
    with self.path.open("a", encoding="utf-8") as fh:
        fh.write(json.dumps(row, sort_keys=True) + "\n")

has_entry(artifact_path, *, artifact_type, step)

Return whether an identical artifact event is already present.

Source code in training/artifacts.py
def has_entry(
    self,
    artifact_path: Path | str,
    *,
    artifact_type: str,
    step: Optional[int],
) -> bool:
    """Return whether an identical artifact event is already present."""
    path_value = str(artifact_path)
    for row in self._read_rows():
        if str(row.get("path", "")) != path_value:
            continue
        if str(row.get("type", "")) != str(artifact_type):
            continue
        row_step = row.get("step")
        if row_step == step:
            return True
    return False

known_paths()

Return all path strings already present in the manifest.

Source code in training/artifacts.py
def known_paths(self) -> set[str]:
    """Return all path strings already present in the manifest."""
    return {str(row.get("path", "")) for row in self._read_rows()}

latest_entry_for_path(artifact_path)

Return the latest manifest row for a given path, if any.

Source code in training/artifacts.py
def latest_entry_for_path(self, artifact_path: Path | str) -> Optional[dict]:
    """Return the latest manifest row for a given path, if any."""
    path_value = str(artifact_path)
    matches = [
        row for row in self._read_rows() if str(row.get("path", "")) == path_value
    ]
    if not matches:
        return None

    def _sort_key(row: dict) -> tuple[int, int]:
        step = row.get("step")
        timestamp = row.get("timestamp")
        return (
            int(step) if isinstance(step, int) else -1,
            int(timestamp) if isinstance(timestamp, int) else -1,
        )

    return max(matches, key=_sort_key)

latest_periodic(checkpoint_dir, prefix)

Return latest periodic checkpoint from manifest, if available.

Source code in training/artifacts.py
def latest_periodic(self, checkpoint_dir: Path, prefix: str) -> Optional[Path]:
    """Return latest periodic checkpoint from manifest, if available."""
    rows = self._read_rows()
    best_step = -1
    best_path: Optional[Path] = None
    for row in rows:
        if row.get("type") != "periodic":
            continue
        path_value = str(row.get("path", ""))
        if not path_value:
            continue
        candidate = Path(path_value)
        if not candidate.is_absolute():
            candidate = checkpoint_dir / candidate
        if not candidate.name.startswith(prefix + "_"):
            continue
        step_value = row.get("step")
        if not isinstance(step_value, int):
            continue
        if not candidate.exists():
            continue
        if step_value > best_step:
            best_step = step_value
            best_path = candidate
    return best_path

prune_periodic(checkpoint_dir, prefix, keep_last)

Delete old periodic checkpoints on disk, keeping the keep_last newest.

Only files that are both present in the manifest and exist on disk are considered. The keep_last most recently registered rows (by step, then timestamp) are retained; all older ones are deleted.

Returns the list of paths that were deleted.

Source code in training/artifacts.py
def prune_periodic(
    self,
    checkpoint_dir: Path,
    prefix: str,
    keep_last: int,
) -> list[Path]:
    """Delete old periodic checkpoints on disk, keeping the *keep_last* newest.

    Only files that are both present in the manifest *and* exist on disk are
    considered.  The ``keep_last`` most recently registered rows (by step,
    then timestamp) are retained; all older ones are deleted.

    Returns the list of paths that were deleted.
    """
    rows = self._read_rows()
    # Collect unique (step, path) pairs for the given prefix & type.
    candidates: list[tuple[int, Path]] = []
    seen: set[str] = set()
    for row in rows:
        if row.get("type") != "periodic":
            continue
        path_str = str(row.get("path", ""))
        if not path_str or path_str in seen:
            continue
        candidate = Path(path_str)
        if not candidate.is_absolute():
            candidate = checkpoint_dir / candidate
        if not candidate.name.startswith(prefix + "_"):
            continue
        if not candidate.exists():
            continue
        step_value = row.get("step")
        sort_step = int(step_value) if isinstance(step_value, int) else -1
        candidates.append((sort_step, candidate))
        seen.add(path_str)

    # Sort descending — highest step first.
    candidates.sort(key=lambda t: t[0], reverse=True)
    to_delete = candidates[keep_last:]
    deleted: list[Path] = []
    for _, p in to_delete:
        try:
            p.unlink(missing_ok=True)
            deleted.append(p)
        except OSError:
            pass
    return deleted
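The keep-last selection at the heart of the method is just a descending sort by step followed by a slice. Isolated, with made-up steps and filenames::

```python
# Sort candidates by step descending, keep the first keep_last,
# mark the rest for deletion (mirrors prune_periodic's slicing).
candidates = [(100_000, "a.zip"), (300_000, "c.zip"), (200_000, "b.zip")]
candidates.sort(key=lambda t: t[0], reverse=True)

keep_last = 2
to_delete = candidates[keep_last:]   # only the oldest entry remains here
```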

prune_self_play_snapshots(pool_dir, keep_last)

Delete old self-play snapshots on disk, keeping the keep_last newest.

Returns the list of paths that were deleted.

Source code in training/artifacts.py
def prune_self_play_snapshots(
    self,
    pool_dir: Path,
    keep_last: int,
) -> list[Path]:
    """Delete old self-play snapshots on disk, keeping the *keep_last* newest.

    Returns the list of paths that were deleted.
    """
    rows = self._read_rows()
    candidates: list[tuple[int, Path]] = []
    seen: set[str] = set()
    for row in rows:
        if row.get("type") != "self_play_snapshot":
            continue
        path_str = str(row.get("path", ""))
        if not path_str or path_str in seen:
            continue
        candidate = Path(path_str)
        if not candidate.is_absolute():
            candidate = pool_dir / candidate
        if not candidate.exists():
            continue
        step_value = row.get("step")
        sort_step = int(step_value) if isinstance(step_value, int) else -1
        candidates.append((sort_step, candidate))
        seen.add(path_str)

    candidates.sort(key=lambda t: t[0], reverse=True)
    to_delete = candidates[keep_last:]
    deleted: list[Path] = []
    for _, p in to_delete:
        try:
            p.unlink(missing_ok=True)
            deleted.append(p)
        except OSError:
            pass
    return deleted

register(artifact_path, *, artifact_type, seed, curriculum_level, run_id, config_hash, step)

Register one artifact path if it is not already indexed.

Source code in training/artifacts.py
def register(
    self,
    artifact_path: Path,
    *,
    artifact_type: str,
    seed: int,
    curriculum_level: int,
    run_id: Optional[str],
    config_hash: str,
    step: Optional[int],
) -> bool:
    """Register one artifact path if it is not already indexed."""
    path_value = str(artifact_path)
    if self.has_entry(path_value, artifact_type=artifact_type, step=step):
        return False
    self.append(
        {
            "timestamp": int(time.time()),
            "type": str(artifact_type),
            "path": path_value,
            "seed": int(seed),
            "curriculum_level": int(curriculum_level),
            "run_id": run_id or None,
            "config_hash": str(config_hash),
            "step": int(step) if step is not None else None,
        }
    )
    return True

training.artifacts.checkpoint_name_prefix(*, seed, curriculum_level, enable_v2)

Return periodic checkpoint prefix for the active naming mode.

Source code in training/artifacts.py
def checkpoint_name_prefix(*, seed: int, curriculum_level: int, enable_v2: bool) -> str:
    """Return periodic checkpoint prefix for the active naming mode."""
    if not enable_v2:
        return "ppo_battalion"
    return f"ppo_battalion_s{int(seed)}_c{int(curriculum_level)}"

training.artifacts.checkpoint_final_stem(*, seed, curriculum_level, enable_v2)

Return final checkpoint stem (without .zip) for the active naming mode.

Source code in training/artifacts.py
def checkpoint_final_stem(*, seed: int, curriculum_level: int, enable_v2: bool) -> str:
    """Return final checkpoint stem (without .zip) for the active naming mode."""
    if not enable_v2:
        return "ppo_battalion_final"
    return f"ppo_battalion_s{int(seed)}_c{int(curriculum_level)}_final"

training.artifacts.checkpoint_best_filename(*, seed, curriculum_level, enable_v2)

Return best checkpoint filename for the active naming mode.

Source code in training/artifacts.py
def checkpoint_best_filename(*, seed: int, curriculum_level: int, enable_v2: bool) -> str:
    """Return best checkpoint filename for the active naming mode."""
    if not enable_v2:
        return "best_model.zip"
    return f"ppo_battalion_s{int(seed)}_c{int(curriculum_level)}_best.zip"
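Taken together, the three naming helpers produce the following names in each mode (function bodies copied from the source above)::

```python
# Checkpoint naming helpers for the v1 (legacy) and v2 naming modes.
def checkpoint_name_prefix(*, seed: int, curriculum_level: int, enable_v2: bool) -> str:
    if not enable_v2:
        return "ppo_battalion"
    return f"ppo_battalion_s{int(seed)}_c{int(curriculum_level)}"

def checkpoint_final_stem(*, seed: int, curriculum_level: int, enable_v2: bool) -> str:
    if not enable_v2:
        return "ppo_battalion_final"
    return f"ppo_battalion_s{int(seed)}_c{int(curriculum_level)}_final"

def checkpoint_best_filename(*, seed: int, curriculum_level: int, enable_v2: bool) -> str:
    if not enable_v2:
        return "best_model.zip"
    return f"ppo_battalion_s{int(seed)}_c{int(curriculum_level)}_best.zip"

assert checkpoint_name_prefix(seed=7, curriculum_level=2, enable_v2=True) == "ppo_battalion_s7_c2"
assert checkpoint_final_stem(seed=7, curriculum_level=2, enable_v2=False) == "ppo_battalion_final"
assert checkpoint_best_filename(seed=7, curriculum_level=2, enable_v2=True) == "ppo_battalion_s7_c2_best.zip"
```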

training.artifacts.parse_step_from_checkpoint_name(path)

Extract timesteps from a periodic checkpoint file name.

Source code in training/artifacts.py
def parse_step_from_checkpoint_name(path: Path) -> Optional[int]:
    """Extract timesteps from a periodic checkpoint file name."""
    match = _STEP_PATTERN.search(path.name)
    if match is None:
        return None
    try:
        return int(match.group(1))
    except ValueError:
        return None
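The `_STEP_PATTERN` regex itself is not reproduced above. Assuming it follows the common Stable-Baselines3 CheckpointCallback convention of `<prefix>_<steps>_steps.zip`, the parsing behaves like this sketch (the pattern is an assumption, not taken from the source)::

```python
import re
from pathlib import Path
from typing import Optional

# Assumed pattern: a "_<digits>_steps" segment in the file name.
_STEP_PATTERN = re.compile(r"_(\d+)_steps")

def parse_step_from_checkpoint_name(path: Path) -> Optional[int]:
    match = _STEP_PATTERN.search(path.name)
    if match is None:
        return None
    return int(match.group(1))

step = parse_step_from_checkpoint_name(Path("ppo_battalion_s7_c2_100000_steps.zip"))
```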

Benchmarks

training.wfm1_benchmark.WFM1Benchmark

Run WFM-1 zero-shot, fine-tuned, and specialist evaluations.

Parameters:

config (WFM1BenchmarkConfig, required): Benchmark configuration.
Source code in training/wfm1_benchmark.py
class WFM1Benchmark:
    """Run WFM-1 zero-shot, fine-tuned, and specialist evaluations.

    Parameters
    ----------
    config:
        Benchmark configuration.
    """

    def __init__(self, config: WFM1BenchmarkConfig) -> None:
        self.config = config

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def run(
        self,
        wfm1_policy: Optional[Any] = None,
        specialist_policies: Optional[Dict[str, Any]] = None,
    ) -> WFM1BenchmarkSummary:
        """Run all evaluations and return a summary.

        Parameters
        ----------
        wfm1_policy:
            Trained :class:`~models.wfm1.WFM1Policy` (or ``None`` for a
            scripted random-walk baseline).
        specialist_policies:
            Optional mapping from scenario name → pre-trained specialist
            policy.  When ``None``, a scripted baseline is used.

        Returns
        -------
        :class:`WFM1BenchmarkSummary`
        """
        cfg = self.config
        n = min(cfg.n_scenarios, len(HELD_OUT_SCENARIOS))
        scenarios = [
            WFM1BenchmarkScenario.from_dict(d) for d in HELD_OUT_SCENARIOS[:n]
        ]

        results: List[WFM1BenchmarkResult] = []

        for scenario in scenarios:
            env = self._make_env(scenario)
            card = scenario.to_scenario_card()

            # 1. Zero-shot
            t0 = time.perf_counter()
            zs_episodes = self._evaluate(wfm1_policy, env, card, scenario.echelon)
            results.append(
                WFM1BenchmarkResult(
                    scenario_name=scenario.name,
                    condition="zero_shot",
                    **self._aggregate(zs_episodes),
                    elapsed_seconds=time.perf_counter() - t0,
                )
            )

            # 2. Fine-tuned (adapter only)
            ft_policy = self._finetune(wfm1_policy, env, card, scenario.echelon)
            t0 = time.perf_counter()
            ft_episodes = self._evaluate(ft_policy, env, card, scenario.echelon)
            results.append(
                WFM1BenchmarkResult(
                    scenario_name=scenario.name,
                    condition="finetuned",
                    **self._aggregate(ft_episodes),
                    elapsed_seconds=time.perf_counter() - t0,
                    finetune_steps_used=cfg.finetune_steps,
                )
            )

            # 3. Specialist
            spec_policy = (
                specialist_policies.get(scenario.name)
                if specialist_policies
                else None
            )
            t0 = time.perf_counter()
            sp_episodes = self._evaluate(spec_policy, env, None, scenario.echelon)
            results.append(
                WFM1BenchmarkResult(
                    scenario_name=scenario.name,
                    condition="specialist",
                    **self._aggregate(sp_episodes),
                    elapsed_seconds=time.perf_counter() - t0,
                )
            )

            env.close()
            log.info("Scenario %s done.", scenario.name)

        return WFM1BenchmarkSummary(results=results, config=cfg)

    # ------------------------------------------------------------------
    # Internal: environment factory
    # ------------------------------------------------------------------

    def _make_env(self, scenario: WFM1BenchmarkScenario) -> Any:
        """Create an evaluation environment for *scenario*.

        Builds a GIS terrain when ``scenario.terrain_type`` is non-zero
        (a recognised GIS battle site), otherwise uses procedurally generated
        random terrain.  The terrain dimensions are passed through to
        :class:`~envs.battalion_env.BattalionEnv` so that unit-position
        normalisation is consistent with the terrain grid.
        """
        _GIS_SITE_NAMES: Dict[int, str] = {
            TERRAIN_GIS_WATERLOO: "waterloo",
            TERRAIN_GIS_AUSTERLITZ: "austerlitz",
            TERRAIN_GIS_BORODINO: "borodino",
            TERRAIN_GIS_SALAMANCA: "salamanca",
        }

        try:
            from envs.battalion_env import BattalionEnv
            from envs.sim.terrain import TerrainMap

            if scenario.terrain_type in _GIS_SITE_NAMES:
                from data.gis.terrain_importer import GISTerrainBuilder

                terrain = GISTerrainBuilder(
                    site=_GIS_SITE_NAMES[scenario.terrain_type],
                    rows=40,
                    cols=40,
                ).build()
            else:
                terrain = TerrainMap.generate_random(
                    rng=np.random.default_rng(scenario.seed),
                    width=10_000.0,
                    height=10_000.0,
                    rows=40,
                    cols=40,
                    num_hills=4,
                    num_forests=3,
                )

            return BattalionEnv(
                terrain=terrain,
                randomize_terrain=False,
                map_width=terrain.width,
                map_height=terrain.height,
            )
        except Exception as exc:
            log.debug(
                "Could not build real env for scenario %r: %s — using synthetic.",
                scenario.name,
                exc,
                exc_info=True,
            )

        # Synthetic fallback
        n_entities = scenario.n_blue + scenario.n_red
        return _SyntheticBenchmarkEnv(
            n_entities=n_entities,
            seed=scenario.seed,
            ep_length=self.config.max_steps_per_episode,
        )

    # ------------------------------------------------------------------
    # Internal: episode runner
    # ------------------------------------------------------------------

    def _evaluate(
        self,
        policy: Optional[Any],
        env: Any,
        card: Optional[ScenarioCard],
        echelon: int,
    ) -> List[Dict[str, Any]]:
        """Run *n_eval_episodes* and return raw per-episode stats."""
        cfg = self.config
        results = []
        use_wfm1 = isinstance(policy, WFM1Policy)
        use_predict = not use_wfm1 and hasattr(policy, "predict")

        import torch

        # Infer policy device once so all tensors are placed correctly.
        if use_wfm1:
            try:
                _policy_device = next(policy.parameters()).device
            except (StopIteration, AttributeError):
                _policy_device = torch.device("cpu")
        else:
            _policy_device = torch.device("cpu")

        for ep_idx in range(cfg.n_eval_episodes):
            obs, _ = env.reset(seed=ep_idx + 9000)
            terminated = truncated = False
            steps = 0
            won = None

            while not (terminated or truncated) and steps < cfg.max_steps_per_episode:
                if use_wfm1:
                    tokens_np, pm_np = self._obs_to_tokens(obs)
                    t = torch.as_tensor(
                        tokens_np[np.newaxis], dtype=torch.float32
                    ).to(_policy_device)
                    pm = (
                        torch.as_tensor(pm_np[np.newaxis], dtype=torch.bool).to(_policy_device)
                        if pm_np is not None
                        else None
                    )
                    _card_vec = (
                        card.to_tensor(device=_policy_device) if card is not None else None
                    )
                    with torch.no_grad():
                        action, _ = policy.act(
                            t, pad_mask=pm, echelon=echelon,
                            card_vec=_card_vec,
                            deterministic=True,
                        )
                    action = action.squeeze(0).cpu().numpy()
                elif use_predict:
                    action, _ = policy.predict(obs, deterministic=True)
                else:
                    action = self._scripted_action(obs)

                obs, _rew, terminated, truncated, info = env.step(action)
                steps += 1

            # Determine winner from info dict (best-effort)
            if isinstance(info, dict):
                if info.get("blue_routed"):
                    won = False
                elif info.get("red_routed"):
                    won = True
                else:
                    won = None
            results.append({"won": won, "steps": steps})

        return results

    def _obs_to_tokens(
        self, obs: Any
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Convert observation to entity token array."""
        if isinstance(obs, np.ndarray):
            if obs.ndim == 2 and obs.shape[-1] == ENTITY_TOKEN_DIM:
                return obs, None
            n = (obs.size + ENTITY_TOKEN_DIM - 1) // ENTITY_TOKEN_DIM
            padded = np.zeros(n * ENTITY_TOKEN_DIM, dtype=np.float32)
            padded[: obs.size] = obs.ravel()
            return padded.reshape(n, ENTITY_TOKEN_DIM), None
        return np.zeros((8, ENTITY_TOKEN_DIM), dtype=np.float32), None

    @staticmethod
    def _scripted_action(obs: Any) -> np.ndarray:
        """Simple deterministic scripted action (advance forward)."""
        return np.array([1.0, 0.0, 0.0], dtype=np.float32)

    # ------------------------------------------------------------------
    # Internal: fine-tuning
    # ------------------------------------------------------------------

    def _finetune(
        self,
        policy: Optional[Any],
        env: Any,
        card: Optional[ScenarioCard],
        echelon: int,
    ) -> Optional[Any]:
        """Adapter-only fine-tuning of WFM-1 for ≤ finetune_steps steps."""
        if not isinstance(policy, WFM1Policy):
            return policy  # not WFM-1 — return unchanged

        import copy
        import torch

        ft_policy: WFM1Policy = copy.deepcopy(policy)
        ft_policy.freeze_base()
        opt = torch.optim.Adam(ft_policy.adapter_parameters(), lr=3e-4)
        steps = 0

        try:
            ft_device = next(ft_policy.parameters()).device
        except StopIteration:
            ft_device = torch.device("cpu")

        obs, _ = env.reset(seed=7777)
        tokens_np, pm_np = self._obs_to_tokens(obs)

        while steps < self.config.finetune_steps:
            # Use deterministic=True so the BC target is the distribution mean,
            # not a stochastic sample — prevents the adapter from chasing noise.
            t = torch.as_tensor(
                tokens_np[np.newaxis], dtype=torch.float32
            ).to(ft_device)
            pm = (
                torch.as_tensor(pm_np[np.newaxis], dtype=torch.bool).to(ft_device)
                if pm_np is not None
                else None
            )
            with torch.no_grad():
                action, _ = ft_policy.act(
                    t, pad_mask=pm, echelon=echelon, card=card,
                    deterministic=True,
                )
            action_np = action.squeeze(0).cpu().numpy()

            obs_new, _rew, terminated, truncated, _info = env.step(action_np)
            done = terminated or truncated

            # Behaviour-cloning loss: match own predictions (self-distillation)
            batch = {
                "tokens": t,
                "actions": action.detach(),
                "pad_mask": pm,
                "echelon": torch.tensor(echelon),
            }
            if card is not None:
                batch["card_vec"] = card.to_tensor(device=ft_device)
            loss = ft_policy.finetune_loss(batch)
            opt.zero_grad()
            loss.backward()
            opt.step()
            steps += 1

            if done:
                obs, _ = env.reset(seed=7777 + steps)
                tokens_np, pm_np = self._obs_to_tokens(obs)
            else:
                obs = obs_new
                tokens_np, pm_np = self._obs_to_tokens(obs)

        ft_policy.unfreeze_base()
        return ft_policy

    # ------------------------------------------------------------------
    # Internal: statistics aggregation
    # ------------------------------------------------------------------

    @staticmethod
    def _aggregate(episodes: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Compute win_rate, mean_steps, std_steps, n_episodes."""
        wins = [e["won"] for e in episodes if e["won"] is not None]
        steps = [e["steps"] for e in episodes]
        win_rate = float(np.mean(wins)) if wins else 0.0
        return {
            "win_rate": win_rate,
            "mean_steps": float(np.mean(steps)) if steps else 0.0,
            "std_steps": float(np.std(steps)) if steps else 0.0,
            "n_episodes": len(episodes),
        }
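The aggregation step is small enough to check by hand. The following self-contained mirror of `_aggregate` (a copy for illustration, not an import of the class) shows how draws (`won is None`) are excluded from the win rate but still counted in `n_episodes`:

```python
import numpy as np

def aggregate(episodes):
    # Mirrors WFM1Benchmark._aggregate above: draws (won=None) are
    # excluded from the win rate but still counted in n_episodes.
    wins = [e["won"] for e in episodes if e["won"] is not None]
    steps = [e["steps"] for e in episodes]
    return {
        "win_rate": float(np.mean(wins)) if wins else 0.0,
        "mean_steps": float(np.mean(steps)) if steps else 0.0,
        "std_steps": float(np.std(steps)) if steps else 0.0,
        "n_episodes": len(episodes),
    }

stats = aggregate([
    {"won": True, "steps": 120},
    {"won": False, "steps": 300},
    {"won": None, "steps": 500},   # draw: not a win, not a loss
])
print(stats["win_rate"])    # 0.5 (draw excluded)
print(stats["n_episodes"])  # 3
```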

run(wfm1_policy=None, specialist_policies=None)

Run all evaluations and return a summary.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `wfm1_policy` | `Optional[Any]` | Trained `models.wfm1.WFM1Policy` (or `None` for a scripted random-walk baseline). | `None` |
| `specialist_policies` | `Optional[Dict[str, Any]]` | Optional mapping from scenario name → pre-trained specialist policy. When `None`, a scripted baseline is used. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `WFM1BenchmarkSummary` | Aggregated results for all scenarios and conditions. |
Source code in training/wfm1_benchmark.py
def run(
    self,
    wfm1_policy: Optional[Any] = None,
    specialist_policies: Optional[Dict[str, Any]] = None,
) -> WFM1BenchmarkSummary:
    """Run all evaluations and return a summary.

    Parameters
    ----------
    wfm1_policy:
        Trained :class:`~models.wfm1.WFM1Policy` (or ``None`` for a
        scripted random-walk baseline).
    specialist_policies:
        Optional mapping from scenario name → pre-trained specialist
        policy.  When ``None``, a scripted baseline is used.

    Returns
    -------
    :class:`WFM1BenchmarkSummary`
    """
    cfg = self.config
    n = min(cfg.n_scenarios, len(HELD_OUT_SCENARIOS))
    scenarios = [
        WFM1BenchmarkScenario.from_dict(d) for d in HELD_OUT_SCENARIOS[:n]
    ]

    results: List[WFM1BenchmarkResult] = []

    for scenario in scenarios:
        env = self._make_env(scenario)
        card = scenario.to_scenario_card()

        # 1. Zero-shot
        t0 = time.perf_counter()
        zs_episodes = self._evaluate(wfm1_policy, env, card, scenario.echelon)
        results.append(
            WFM1BenchmarkResult(
                scenario_name=scenario.name,
                condition="zero_shot",
                **self._aggregate(zs_episodes),
                elapsed_seconds=time.perf_counter() - t0,
            )
        )

        # 2. Fine-tuned (adapter only)
        ft_policy = self._finetune(wfm1_policy, env, card, scenario.echelon)
        t0 = time.perf_counter()
        ft_episodes = self._evaluate(ft_policy, env, card, scenario.echelon)
        results.append(
            WFM1BenchmarkResult(
                scenario_name=scenario.name,
                condition="finetuned",
                **self._aggregate(ft_episodes),
                elapsed_seconds=time.perf_counter() - t0,
                finetune_steps_used=cfg.finetune_steps,
            )
        )

        # 3. Specialist
        spec_policy = (
            specialist_policies.get(scenario.name)
            if specialist_policies
            else None
        )
        t0 = time.perf_counter()
        sp_episodes = self._evaluate(spec_policy, env, None, scenario.echelon)
        results.append(
            WFM1BenchmarkResult(
                scenario_name=scenario.name,
                condition="specialist",
                **self._aggregate(sp_episodes),
                elapsed_seconds=time.perf_counter() - t0,
            )
        )

        env.close()
        log.info("Scenario %s done.", scenario.name)

    return WFM1BenchmarkSummary(results=results, config=cfg)

training.wfm1_benchmark.WFM1BenchmarkConfig dataclass

Configuration for the WFM-1 benchmark.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `n_eval_episodes` | `int` | Number of evaluation episodes per scenario per condition. |
| `n_scenarios` | `int` | Number of held-out scenarios to evaluate (≤ 20). |
| `finetune_steps` | `int` | Maximum fine-tuning budget (adapter only) for the "WFM-1 fine-tuned" condition. |
| `max_steps_per_episode` | `int` | Episode step limit. |
| `specialist_train_steps` | `int` | How many steps to train each specialist baseline (set to 0 to use a scripted baseline instead — strongly recommended for CI). |
| `zero_shot_win_rate_threshold` | `float` | Minimum acceptable zero-shot win rate (acceptance criterion). |
| `finetune_recovery_fraction` | `float` | Fraction of specialist performance that fine-tuned WFM-1 must reach (acceptance criterion). |

Source code in training/wfm1_benchmark.py
@dataclass
class WFM1BenchmarkConfig:
    """Configuration for the WFM-1 benchmark.

    Attributes
    ----------
    n_eval_episodes:
        Number of evaluation episodes per scenario per condition.
    n_scenarios:
        Number of held-out scenarios to evaluate (≤ 20).
    finetune_steps:
        Maximum fine-tuning budget (adapter only) for the "WFM-1 fine-tuned"
        condition.
    max_steps_per_episode:
        Episode step limit.
    specialist_train_steps:
        How many steps to train each specialist baseline (set to 0 to use a
        scripted baseline instead — strongly recommended for CI).
    zero_shot_win_rate_threshold:
        Minimum acceptable zero-shot win rate (acceptance criterion).
    finetune_recovery_fraction:
        Fraction of specialist performance that fine-tuned WFM-1 must reach
        (acceptance criterion).
    """

    n_eval_episodes: int = 50
    n_scenarios: int = 20
    finetune_steps: int = 10_000
    max_steps_per_episode: int = 500
    specialist_train_steps: int = 0  # 0 = scripted baseline
    zero_shot_win_rate_threshold: float = 0.55
    finetune_recovery_fraction: float = 0.80
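The defaults above are sized for a full benchmark run and are far too heavy for CI smoke tests. A hedged sketch of a lightweight configuration (the import path is taken from the "Source code in training/wfm1_benchmark.py" header; adjust if your layout differs):

```python
from training.wfm1_benchmark import WFM1BenchmarkConfig

# Lightweight settings for a CI smoke run.
ci_cfg = WFM1BenchmarkConfig(
    n_eval_episodes=5,         # default 50
    n_scenarios=2,             # default 20
    finetune_steps=500,        # default 10_000
    specialist_train_steps=0,  # scripted baseline (recommended for CI)
)
```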

training.wfm1_benchmark.WFM1BenchmarkSummary dataclass

Aggregated WFM-1 benchmark results across all scenarios.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `results` | `List[WFM1BenchmarkResult]` | All individual (scenario, condition) results. |
| `config` | `WFM1BenchmarkConfig` | The `WFM1BenchmarkConfig` used. |

Source code in training/wfm1_benchmark.py
@dataclass
class WFM1BenchmarkSummary:
    """Aggregated WFM-1 benchmark results across all scenarios.

    Attributes
    ----------
    results:
        All individual (scenario, condition) results.
    config:
        The :class:`WFM1BenchmarkConfig` used.
    """

    results: List[WFM1BenchmarkResult]
    config: WFM1BenchmarkConfig

    # ------------------------------------------------------------------
    # Aggregated statistics
    # ------------------------------------------------------------------

    def _filter(self, condition: str) -> List[WFM1BenchmarkResult]:
        return [r for r in self.results if r.condition == condition]

    @property
    def mean_zero_shot_win_rate(self) -> float:
        """Mean zero-shot win rate across all scenarios."""
        zs = self._filter("zero_shot")
        return float(np.mean([r.win_rate for r in zs])) if zs else 0.0

    @property
    def mean_finetuned_win_rate(self) -> float:
        """Mean fine-tuned win rate across all scenarios."""
        ft = self._filter("finetuned")
        return float(np.mean([r.win_rate for r in ft])) if ft else 0.0

    @property
    def mean_specialist_win_rate(self) -> float:
        """Mean specialist win rate across all scenarios."""
        sp = self._filter("specialist")
        return float(np.mean([r.win_rate for r in sp])) if sp else 0.0

    @property
    def finetune_recovery(self) -> float:
        """Fine-tuned performance as a fraction of specialist performance.

        Returns 0.0 when no specialist baseline is available.
        """
        sp = self.mean_specialist_win_rate
        if sp < 1e-6:
            return 0.0
        return self.mean_finetuned_win_rate / sp

    # ------------------------------------------------------------------
    # Acceptance criteria
    # ------------------------------------------------------------------

    @property
    def meets_zero_shot_criterion(self) -> bool:
        """Zero-shot win rate ≥ threshold (default 55 %)."""
        return (
            self.mean_zero_shot_win_rate
            >= self.config.zero_shot_win_rate_threshold
        )

    @property
    def meets_finetune_criterion(self) -> bool:
        """Fine-tuned win rate ≥ 80 % of specialist win rate."""
        return self.finetune_recovery >= self.config.finetune_recovery_fraction

    @property
    def all_criteria_met(self) -> bool:
        return self.meets_zero_shot_criterion and self.meets_finetune_criterion

    # ------------------------------------------------------------------
    # Rendering
    # ------------------------------------------------------------------

    def __str__(self) -> str:
        lines = [
            "WFM-1 Benchmark Summary",
            f"  Scenarios evaluated  : {len(self._filter('zero_shot'))}",
            f"  Zero-shot win rate   : {self.mean_zero_shot_win_rate:.1%} "
            f"(threshold ≥ {self.config.zero_shot_win_rate_threshold:.0%}) "
            f"{'✅' if self.meets_zero_shot_criterion else '❌'}",
            f"  Fine-tuned win rate  : {self.mean_finetuned_win_rate:.1%}",
            f"  Specialist win rate  : {self.mean_specialist_win_rate:.1%}",
            f"  Fine-tune recovery   : {self.finetune_recovery:.1%} "
            f"(threshold ≥ {self.config.finetune_recovery_fraction:.0%}) "
            f"{'✅' if self.meets_finetune_criterion else '❌'}",
        ]
        return "\n".join(lines)

    def write_markdown(self, path: Optional[str | Path] = None) -> Path:
        """Write a Markdown report to *path*.

        Defaults to ``docs/wfm1_benchmark.md``.
        """
        if path is None:
            out = _PROJECT_ROOT / "docs" / "wfm1_benchmark.md"
        else:
            out = Path(path)
        out.parent.mkdir(parents=True, exist_ok=True)
        out.write_text(_render_markdown(self), encoding="utf-8")
        return out
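The acceptance arithmetic is easy to sanity-check in isolation. This standalone snippet mirrors `finetune_recovery` and the two criteria with illustrative win rates (the numbers are made up; the thresholds are the config defaults):

```python
zero_shot = 0.58    # mean zero-shot win rate (illustrative)
finetuned = 0.66    # mean fine-tuned win rate (illustrative)
specialist = 0.75   # mean specialist win rate (illustrative)

# finetune_recovery: fraction of specialist performance recovered,
# guarding against a missing specialist baseline (see property above).
recovery = finetuned / specialist if specialist >= 1e-6 else 0.0

meets_zero_shot = zero_shot >= 0.55  # zero_shot_win_rate_threshold
meets_finetune = recovery >= 0.80    # finetune_recovery_fraction
print(f"{recovery:.2f}", meets_zero_shot, meets_finetune)  # 0.88 True True
```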

finetune_recovery property

Fine-tuned performance as a fraction of specialist performance.

Returns 0.0 when no specialist baseline is available.

mean_finetuned_win_rate property

Mean fine-tuned win rate across all scenarios.

mean_specialist_win_rate property

Mean specialist win rate across all scenarios.

mean_zero_shot_win_rate property

Mean zero-shot win rate across all scenarios.

meets_finetune_criterion property

Fine-tuned win rate ≥ 80 % of specialist win rate.

meets_zero_shot_criterion property

Zero-shot win rate ≥ threshold (default 55 %).

write_markdown(path=None)

Write a Markdown report to `path`. Defaults to `docs/wfm1_benchmark.md`.

Source code in training/wfm1_benchmark.py
def write_markdown(self, path: Optional[str | Path] = None) -> Path:
    """Write a Markdown report to *path*.

    Defaults to ``docs/wfm1_benchmark.md``.
    """
    if path is None:
        out = _PROJECT_ROOT / "docs" / "wfm1_benchmark.md"
    else:
        out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(_render_markdown(self), encoding="utf-8")
    return out

training.transfer_benchmark.TransferBenchmark

Run procedural-baseline, zero-shot, and fine-tuned evaluations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `TransferEvalConfig` | Evaluation configuration. | *required* |

Notes

When `policy` is `None` (or a path to a non-existent file) a simple scripted policy is used: blue units advance at full speed toward the nearest red unit. This is sufficient for acceptance-criterion testing in CI without requiring a trained checkpoint.

Source code in training/transfer_benchmark.py
class TransferBenchmark:
    """Run procedural-baseline, zero-shot, and fine-tuned evaluations.

    Parameters
    ----------
    config:
        Evaluation configuration.

    Notes
    -----
    When *policy* is ``None`` (or a path to a non-existent file) a simple
    scripted policy is used: blue units advance at full speed toward the
    nearest red unit.  This is sufficient for acceptance-criterion testing
    in CI without requiring a trained checkpoint.
    """

    def __init__(self, config: TransferEvalConfig) -> None:
        self.config = config

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def run(
        self,
        policy: Optional[Any] = None,
        gis_data_dir: Optional[str | Path] = None,
    ) -> TransferSummary:
        """Run all three evaluation conditions and return a summary.

        Parameters
        ----------
        policy:
            Either ``None`` (scripted fallback), a path to a Stable-Baselines3
            checkpoint (``.zip``), or an already-loaded SB3-compatible policy
            object with a ``predict(obs)`` method.
        gis_data_dir:
            Optional directory containing ``{site}.tif`` / ``{site}.osm``
            files.  When absent the synthetic GIS fallback is used.

        Returns
        -------
        TransferSummary
        """
        loaded_policy = self._resolve_policy(policy)

        procedural_terrain = self._build_procedural_terrain()
        gis_terrain = self._build_gis_terrain(gis_data_dir)

        # 1. Procedural baseline
        t0 = time.perf_counter()
        proc_results = self._evaluate(loaded_policy, procedural_terrain)
        proc = TransferResult(
            condition="procedural_baseline",
            **self._aggregate(proc_results),
            elapsed_seconds=time.perf_counter() - t0,
        )

        # 2. Zero-shot on GIS terrain (no adaptation)
        t0 = time.perf_counter()
        zs_results = self._evaluate(loaded_policy, gis_terrain)
        zero_shot = TransferResult(
            condition="zero_shot_gis",
            **self._aggregate(zs_results),
            elapsed_seconds=time.perf_counter() - t0,
        )

        # 3. Fine-tune then evaluate (fine-tuning is a no-op when policy is
        #    None or when the policy lacks a ``learn`` method)
        t0 = time.perf_counter()
        ft_policy, steps_used = self._finetune(loaded_policy, gis_terrain)
        ft_results = self._evaluate(ft_policy, gis_terrain)
        finetuned = TransferResult(
            condition="finetuned_gis",
            **self._aggregate(ft_results),
            elapsed_seconds=time.perf_counter() - t0,
            finetune_steps_used=steps_used,
        )

        return TransferSummary(
            site=self.config.site,
            procedural=proc,
            zero_shot=zero_shot,
            finetuned=finetuned,
            config=self.config,
        )

    def write_markdown(
        self,
        summary: TransferSummary,
        path: Optional[str | Path] = None,
    ) -> Path:
        """Write the benchmark summary to a Markdown file.

        Parameters
        ----------
        summary:
            Results from :meth:`run`.
        path:
            Output path.  Defaults to
            ``docs/transfer_benchmark_{site}.md``.

        Returns
        -------
        Path
            Absolute path to the written file.
        """
        if path is None:
            docs_dir = _REPO_ROOT / "docs"
            docs_dir.mkdir(parents=True, exist_ok=True)
            path = docs_dir / f"transfer_benchmark_{summary.site}.md"
        out = Path(path)
        out.parent.mkdir(parents=True, exist_ok=True)
        out.write_text(_render_transfer_markdown(summary), encoding="utf-8")
        return out

    # ------------------------------------------------------------------
    # Internal: terrain builders
    # ------------------------------------------------------------------

    def _build_procedural_terrain(self):
        """Build a procedural (random) TerrainMap for the baseline."""
        from envs.sim.terrain import TerrainMap

        cfg = self.config
        return TerrainMap.generate_random(
            rng=np.random.default_rng(cfg.procedural_seed),
            width=10_000.0,
            height=10_000.0,
            rows=cfg.rows,
            cols=cfg.cols,
            num_hills=cfg.n_procedural_hills,
            num_forests=cfg.n_procedural_forests,
        )

    def _build_gis_terrain(self, gis_data_dir: Optional[Any]):
        """Build a GIS TerrainMap for the target site."""
        from data.gis.terrain_importer import GISTerrainBuilder

        return GISTerrainBuilder(
            site=self.config.site,
            rows=self.config.rows,
            cols=self.config.cols,
            srtm_path=(
                Path(gis_data_dir) / f"{self.config.site}.tif"
                if gis_data_dir is not None
                else None
            ),
            osm_path=(
                Path(gis_data_dir) / f"{self.config.site}.osm"
                if gis_data_dir is not None
                else None
            ),
        ).build()

    # ------------------------------------------------------------------
    # Internal: policy helpers
    # ------------------------------------------------------------------

    def _resolve_policy(self, policy: Optional[Any]) -> Optional[Any]:
        """Load a policy from a checkpoint path if needed."""
        if policy is None:
            return None
        if isinstance(policy, (str, Path)):
            path = Path(policy)
            if not path.exists():
                return None
            try:
                from stable_baselines3 import PPO  # type: ignore

                return PPO.load(path)
            except ImportError:
                # Allow running without Stable-Baselines3 by falling back.
                return None
            except Exception as exc:
                # If the file exists but cannot be loaded, surface the error
                # instead of silently downgrading to the baseline.
                raise RuntimeError(
                    f"Failed to load policy checkpoint from {path}"
                ) from exc
        return policy

    def _finetune(
        self,
        policy: Optional[Any],
        terrain,
    ) -> Tuple[Optional[Any], int]:
        """Fine-tune *policy* on *terrain* for up to config.finetune_steps.

        Creates a :class:`~envs.battalion_env.BattalionEnv` with the target
        GIS terrain and attaches it to the policy (via ``set_env``) before
        calling ``learn``.  This ensures that fine-tuning is actually
        performed on the target map rather than on whatever environment the
        policy was originally trained with.

        Returns the (possibly updated) policy and the number of gradient
        steps actually taken.  When *policy* is ``None`` or lacks a
        ``learn`` method, returns ``(policy, 0)`` unchanged.
        """
        if policy is None or not hasattr(policy, "learn"):
            return policy, 0
        try:
            from envs.battalion_env import BattalionEnv

            ft_env = BattalionEnv(terrain=terrain, randomize_terrain=False)
            if hasattr(policy, "set_env"):
                policy.set_env(ft_env)
            policy.learn(total_timesteps=self.config.finetune_steps)
            ft_env.close()
            return policy, self.config.finetune_steps
        except Exception:
            return policy, 0

    # ------------------------------------------------------------------
    # Internal: episode runner
    # ------------------------------------------------------------------

    # Constant used to offset the eval RNG seed from the terrain seed so
    # unit placements are not correlated with the procedural terrain blobs.
    _EVAL_SEED_OFFSET: int = 0xDEAD

    def _evaluate(
        self,
        policy: Optional[Any],
        terrain,
    ) -> List[Dict]:
        """Run *n_eval_episodes* episodes on *terrain* and return raw stats.

        Each episode dict has ``{"winner": int | None, "steps": int}``.

        When *policy* has a ``predict`` method (i.e., a Stable-Baselines3–
        compatible learned policy), evaluation is driven via
        :class:`~envs.battalion_env.BattalionEnv` with the target terrain
        fixed and terrain randomisation disabled.  The winner is inferred
        from the ``blue_routed`` / ``red_routed`` fields in the final
        ``info`` dict.

        When *policy* is ``None`` (or any object without a ``predict``
        method), the :class:`~envs.sim.engine.SimEngine` deterministic
        scripted baseline is used.  Each episode uses a separate RNG seeded
        from :attr:`_EVAL_SEED_OFFSET` so results are fully reproducible.
        Note: without a morale config the scripted engine will typically run
        to ``max_steps`` as a draw; use a learned policy for meaningful
        win-rate measurements.
        """
        from envs.sim.engine import SimEngine
        from envs.sim.battalion import Battalion

        cfg = self.config
        results: List[Dict] = []
        base_seed = int(cfg.procedural_seed ^ self._EVAL_SEED_OFFSET)

        # Treat any object with a .predict method as a learned policy; otherwise
        # fall back to the built-in scripted SimEngine baseline.
        use_learned_policy = hasattr(policy, "predict")

        if use_learned_policy:
            # Import here to avoid a hard dependency when only the scripted
            # baseline is used (e.g., in lightweight CI contexts).
            from envs.battalion_env import BattalionEnv

        for ep_idx in range(cfg.n_eval_episodes):
            if use_learned_policy:
                # Evaluate the learned policy in a Gymnasium environment with a
                # fixed terrain and deterministic seeding.
                env = BattalionEnv(terrain=terrain, randomize_terrain=False)
                obs, _info = env.reset(seed=base_seed + ep_idx)

                terminated = False
                truncated = False
                steps = 0
                last_info: Dict[str, Any] = {}

                while not (terminated or truncated):
                    action, _ = policy.predict(obs, deterministic=True)
                    obs, _reward, terminated, truncated, last_info = env.step(action)
                    steps += 1

                env.close()

                # Infer winner from the final info dict.
                blue_routed = last_info.get("blue_routed", False)
                red_routed = last_info.get("red_routed", False)
                winner: Optional[int]
                if terminated and red_routed and not blue_routed:
                    winner = 0  # blue wins
                elif terminated and blue_routed and not red_routed:
                    winner = 1  # red wins
                else:
                    winner = None  # draw or truncated

                results.append({"winner": winner, "steps": steps})
            else:
                # Scripted baseline: use SimEngine with a per-episode RNG for
                # reproducible evaluations.
                episode_rng = np.random.default_rng(base_seed + ep_idx)
                blue = Battalion(
                    x=terrain.width * 0.25,
                    y=terrain.height * 0.50,
                    theta=0.0,
                    strength=1.0,
                    team=0,
                )
                red = Battalion(
                    x=terrain.width * 0.75,
                    y=terrain.height * 0.50,
                    theta=float(np.pi),
                    strength=1.0,
                    team=1,
                )
                engine = SimEngine(
                    blue,
                    red,
                    terrain=terrain,
                    max_steps=cfg.max_steps_per_episode,
                    rng=episode_rng,
                )
                result = engine.run()
                results.append({"winner": result.winner, "steps": result.steps})

        return results

    @staticmethod
    def _aggregate(results: List[Dict]) -> Dict:
        """Compute win_rate, mean_steps, std_steps, n_episodes."""
        n = len(results)
        if n == 0:
            return {
                "win_rate": 0.0,
                "mean_steps": 0.0,
                "std_steps": 0.0,
                "n_episodes": 0,
            }
        wins = sum(1 for r in results if r["winner"] == 0)
        steps_arr = np.array([r["steps"] for r in results], dtype=np.float32)
        return {
            "win_rate": float(wins) / n,
            "mean_steps": float(steps_arr.mean()),
            "std_steps": float(steps_arr.std()),
            "n_episodes": n,
        }
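The winner-inference step in `_evaluate` is worth seeing in isolation: a blue win requires a natural termination with red routed and blue intact, and everything else (mutual rout, truncation) is a draw. A standalone mirror:

```python
def infer_winner(terminated: bool, info: dict):
    # Mirrors the inference in TransferBenchmark._evaluate above.
    blue_routed = info.get("blue_routed", False)
    red_routed = info.get("red_routed", False)
    if terminated and red_routed and not blue_routed:
        return 0   # blue wins
    if terminated and blue_routed and not red_routed:
        return 1   # red wins
    return None    # draw, mutual rout, or truncated episode

print(infer_winner(True, {"red_routed": True}))                       # 0
print(infer_winner(True, {"blue_routed": True, "red_routed": True}))  # None
print(infer_winner(False, {"red_routed": True}))                      # None (truncated)
```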

run(policy=None, gis_data_dir=None)

Run all three evaluation conditions and return a summary.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `policy` | `Optional[Any]` | Either `None` (scripted fallback), a path to a Stable-Baselines3 checkpoint (`.zip`), or an already-loaded SB3-compatible policy object with a `predict(obs)` method. | `None` |
| `gis_data_dir` | `Optional[str \| Path]` | Optional directory containing `{site}.tif` / `{site}.osm` files. When absent the synthetic GIS fallback is used. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `TransferSummary` | Aggregated results for all three evaluation conditions. |
Source code in training/transfer_benchmark.py
def run(
    self,
    policy: Optional[Any] = None,
    gis_data_dir: Optional[str | Path] = None,
) -> TransferSummary:
    """Run all three evaluation conditions and return a summary.

    Parameters
    ----------
    policy:
        Either ``None`` (scripted fallback), a path to a Stable-Baselines3
        checkpoint (``.zip``), or an already-loaded SB3-compatible policy
        object with a ``predict(obs)`` method.
    gis_data_dir:
        Optional directory containing ``{site}.tif`` / ``{site}.osm``
        files.  When absent the synthetic GIS fallback is used.

    Returns
    -------
    TransferSummary
    """
    loaded_policy = self._resolve_policy(policy)

    procedural_terrain = self._build_procedural_terrain()
    gis_terrain = self._build_gis_terrain(gis_data_dir)

    # 1. Procedural baseline
    t0 = time.perf_counter()
    proc_results = self._evaluate(loaded_policy, procedural_terrain)
    proc = TransferResult(
        condition="procedural_baseline",
        **self._aggregate(proc_results),
        elapsed_seconds=time.perf_counter() - t0,
    )

    # 2. Zero-shot on GIS terrain (no adaptation)
    t0 = time.perf_counter()
    zs_results = self._evaluate(loaded_policy, gis_terrain)
    zero_shot = TransferResult(
        condition="zero_shot_gis",
        **self._aggregate(zs_results),
        elapsed_seconds=time.perf_counter() - t0,
    )

    # 3. Fine-tune then evaluate (fine-tuning is a no-op when policy is
    #    None or when the policy lacks a ``learn`` method)
    t0 = time.perf_counter()
    ft_policy, steps_used = self._finetune(loaded_policy, gis_terrain)
    ft_results = self._evaluate(ft_policy, gis_terrain)
    finetuned = TransferResult(
        condition="finetuned_gis",
        **self._aggregate(ft_results),
        elapsed_seconds=time.perf_counter() - t0,
        finetune_steps_used=steps_used,
    )

    return TransferSummary(
        site=self.config.site,
        procedural=proc,
        zero_shot=zero_shot,
        finetuned=finetuned,
        config=self.config,
    )

write_markdown(summary, path=None)

Write the benchmark summary to a Markdown file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `summary` | `TransferSummary` | Results from `run()`. | *required* |
| `path` | `Optional[str \| Path]` | Output path. Defaults to `docs/transfer_benchmark_{site}.md`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Path` | Absolute path to the written file. |

Source code in training/transfer_benchmark.py
def write_markdown(
    self,
    summary: TransferSummary,
    path: Optional[str | Path] = None,
) -> Path:
    """Write the benchmark summary to a Markdown file.

    Parameters
    ----------
    summary:
        Results from :meth:`run`.
    path:
        Output path.  Defaults to
        ``docs/transfer_benchmark_{site}.md``.

    Returns
    -------
    Path
        Absolute path to the written file.
    """
    if path is None:
        docs_dir = _REPO_ROOT / "docs"
        docs_dir.mkdir(parents=True, exist_ok=True)
        path = docs_dir / f"transfer_benchmark_{summary.site}.md"
    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(_render_transfer_markdown(summary), encoding="utf-8")
    return out

training.transfer_benchmark.TransferEvalConfig dataclass

Configuration for a single transfer evaluation run.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `site` | `str` | GIS battle-site identifier — one of `"waterloo"`, `"austerlitz"`, `"borodino"`, `"salamanca"`. |
| `n_eval_episodes` | `int` | Number of episodes to run for each evaluation condition. |
| `max_steps_per_episode` | `int` | Episode step budget. |
| `finetune_steps` | `int` | Maximum fine-tuning steps allowed. The acceptance criterion requires recovery within `500_000` steps. |
| `rows`, `cols` | `int` | Terrain grid resolution used by the GIS importer. |
| `procedural_seed` | `int` | RNG seed used to generate the procedural baseline terrain. |
| `n_procedural_hills`, `n_procedural_forests` | `int` | Procedural terrain complexity parameters. |

Source code in training/transfer_benchmark.py
@dataclass(frozen=True)
class TransferEvalConfig:
    """Configuration for a single transfer evaluation run.

    Attributes
    ----------
    site:
        GIS battle-site identifier — one of ``"waterloo"``,
        ``"austerlitz"``, ``"borodino"``, ``"salamanca"``.
    n_eval_episodes:
        Number of episodes to run for each evaluation condition.
    max_steps_per_episode:
        Episode step budget.
    finetune_steps:
        Maximum fine-tuning steps allowed.  The acceptance criterion
        requires recovery within ``500_000`` steps.
    rows, cols:
        Terrain grid resolution used by the GIS importer.
    procedural_seed:
        RNG seed used to generate the procedural baseline terrain.
    n_procedural_hills, n_procedural_forests:
        Procedural terrain complexity parameters.
    """

    site: str = "waterloo"
    n_eval_episodes: int = 50
    max_steps_per_episode: int = 500
    finetune_steps: int = 500_000
    rows: int = 40
    cols: int = 40
    procedural_seed: int = 42
    n_procedural_hills: int = 4
    n_procedural_forests: int = 3
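Because the config is a frozen dataclass, a run is parameterised entirely at construction time. A hedged sketch (the import path follows the "Source code in training/transfer_benchmark.py" header; adjust if your layout differs):

```python
from training.transfer_benchmark import TransferEvalConfig

# Evaluate Austerlitz with a lighter budget than the defaults.
cfg = TransferEvalConfig(
    site="austerlitz",
    n_eval_episodes=10,  # default 50
    finetune_steps=0,    # skip the fine-tuning condition's budget
)
# cfg.site = "waterloo"  # would raise FrozenInstanceError (frozen=True)
```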

training.transfer_benchmark.TransferSummary dataclass

Aggregated transfer benchmark results.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `site` | `str` | GIS battle-site that was evaluated. |
| `procedural` | `TransferResult` | Evaluation on procedural terrain (the training distribution). |
| `zero_shot` | `TransferResult` | Evaluation on GIS terrain with no adaptation. |
| `finetuned` | `TransferResult` | Evaluation on GIS terrain after fine-tuning. |
| `config` | `TransferEvalConfig` | The `TransferEvalConfig` used. |

Source code in training/transfer_benchmark.py
@dataclass
class TransferSummary:
    """Aggregated transfer benchmark results.

    Attributes
    ----------
    site:
        GIS battle-site that was evaluated.
    procedural:
        Evaluation on procedural terrain (the training distribution).
    zero_shot:
        Evaluation on GIS terrain with no adaptation.
    finetuned:
        Evaluation on GIS terrain after fine-tuning.
    config:
        The :class:`TransferEvalConfig` used.
    """

    site: str
    procedural: TransferResult
    zero_shot: TransferResult
    finetuned: TransferResult
    config: TransferEvalConfig

    # ------------------------------------------------------------------
    # Acceptance criteria
    # ------------------------------------------------------------------

    @property
    def zero_shot_drop(self) -> float:
        """Win-rate drop from procedural → zero-shot (positive = drop)."""
        return self.procedural.win_rate - self.zero_shot.win_rate

    @property
    def finetuned_drop(self) -> float:
        """Win-rate drop from procedural → fine-tuned (positive = drop)."""
        return self.procedural.win_rate - self.finetuned.win_rate

    @property
    def meets_zero_shot_criterion(self) -> bool:
        """Zero-shot drop < 20 percentage points."""
        return self.zero_shot_drop < 0.20

    @property
    def meets_finetune_criterion(self) -> bool:
        """Fine-tuned drop < 5 pp AND fine-tuning used ≤ 500 k steps."""
        return (
            self.finetuned_drop < 0.05
            and self.finetuned.finetune_steps_used <= self.config.finetune_steps
        )

    @property
    def all_criteria_met(self) -> bool:
        """Both acceptance criteria are satisfied."""
        return self.meets_zero_shot_criterion and self.meets_finetune_criterion

    def __str__(self) -> str:
        lines = [
            f"Transfer Benchmark — site: {self.site}",
            f"  Procedural baseline : {self.procedural.win_rate:.1%} win-rate",
            f"  Zero-shot GIS       : {self.zero_shot.win_rate:.1%} win-rate "
            f"(drop {self.zero_shot_drop:+.1%}) "
            f"{'✅' if self.meets_zero_shot_criterion else '❌'}",
            f"  Fine-tuned GIS      : {self.finetuned.win_rate:.1%} win-rate "
            f"(drop {self.finetuned_drop:+.1%}, "
            f"{self.finetuned.finetune_steps_used:,} steps) "
            f"{'✅' if self.meets_finetune_criterion else '❌'}",
        ]
        return "\n".join(lines)

all_criteria_met property

Both acceptance criteria are satisfied.

finetuned_drop property

Win-rate drop from procedural → fine-tuned (positive = drop).

meets_finetune_criterion property

Fine-tuned drop < 5 pp AND fine-tuning used ≤ 500 k steps.

meets_zero_shot_criterion property

Zero-shot drop < 20 percentage points.

zero_shot_drop property

Win-rate drop from procedural → zero-shot (positive = drop).
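The drop and criterion properties above reduce to simple arithmetic on win-rates. A minimal sketch with made-up numbers (the real code operates on `TransferResult` objects; the values here are purely illustrative):

```python
# Illustrative win-rates — invented values, not real benchmark output.
procedural_win_rate = 0.72   # procedural terrain (training distribution)
zero_shot_win_rate = 0.58    # GIS terrain, no adaptation
finetuned_win_rate = 0.70    # GIS terrain after fine-tuning

# Positive drop means performance got worse on GIS terrain.
zero_shot_drop = procedural_win_rate - zero_shot_win_rate   # ~0.14
finetuned_drop = procedural_win_rate - finetuned_win_rate   # ~0.02

meets_zero_shot = zero_shot_drop < 0.20   # drop under 20 percentage points
meets_finetune = finetuned_drop < 0.05    # drop under 5 percentage points
all_criteria_met = meets_zero_shot and meets_finetune
print(all_criteria_met)
```

With these numbers both criteria pass; note that `meets_finetune_criterion` additionally checks that fine-tuning stayed within the configured step budget, which this sketch omits.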

training.historical_benchmark.HistoricalBenchmark

Run all 50+ historical scenarios and collect comparison results.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `battles_path` | `str \| Path` | Path to the JSON battle database. Defaults to `data/historical/battles.json` relative to the repository root. | `_BATTLES_JSON` |
| `seed` | `int` | Random seed passed to `SimEngine` for reproducible results. | `42` |
Source code in training/historical_benchmark.py
class HistoricalBenchmark:
    """Run all 50+ historical scenarios and collect comparison results.

    Parameters
    ----------
    battles_path:
        Path to the JSON battle database.  Defaults to
        ``data/historical/battles.json`` relative to the repository root.
    seed:
        Random seed passed to :class:`~envs.sim.engine.SimEngine` for
        reproducible results.
    """

    def __init__(
        self,
        battles_path: str | Path = _BATTLES_JSON,
        seed: int = 42,
    ) -> None:
        self.battles_path = Path(battles_path)
        self.seed = seed

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def run(self) -> BenchmarkSummary:
        """Run the full benchmark and return a :class:`BenchmarkSummary`.

        Each scenario is run as a 1v1 simulation using the first blue
        battalion against the first red battalion (matching the existing
        test pattern in ``tests/test_historical_scenarios.py``).
        """
        importer = BatchScenarioImporter(self.battles_path)
        records = importer.load_records()
        scenarios = [r.to_scenario() for r in records]

        entries: List[BenchmarkEntry] = []
        overall_start = time.perf_counter()

        for rec, scenario in zip(records, scenarios):
            entry = self._run_scenario(rec.battle_id, rec.source, scenario)
            entries.append(entry)

        total_elapsed = time.perf_counter() - overall_start

        # Aggregate
        passed_entries = [e for e in entries if e.passed]
        fidelity_scores = [e.fidelity_score for e in passed_entries]
        winner_matches = [e.winner_matches for e in passed_entries]

        mean_fidelity = float(np.mean(fidelity_scores)) if fidelity_scores else 0.0
        std_fidelity = float(np.std(fidelity_scores)) if fidelity_scores else 0.0
        winner_match_rate = (
            float(np.mean([float(w) for w in winner_matches]))
            if winner_matches else 0.0
        )

        return BenchmarkSummary(
            total=len(entries),
            passed=len(passed_entries),
            winner_match_rate=winner_match_rate,
            mean_fidelity=mean_fidelity,
            std_fidelity=std_fidelity,
            total_elapsed_seconds=total_elapsed,
            entries=entries,
        )

    def write_markdown(
        self,
        summary: BenchmarkSummary,
        output_path: str | Path = _BENCHMARK_MD,
    ) -> Path:
        """Write the benchmark results to a Markdown file.

        Parameters
        ----------
        summary:
            The :class:`BenchmarkSummary` returned by :meth:`run`.
        output_path:
            Destination file path.  Parent directories are created if
            they do not already exist.

        Returns
        -------
        Path
            The path to the written file.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(
            _render_markdown(summary),
            encoding="utf-8",
        )
        return output_path

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    def _run_scenario(self, battle_id: str, source: str, scenario) -> BenchmarkEntry:
        """Run a single scenario and return a :class:`BenchmarkEntry`."""
        start = time.perf_counter()
        entry = BenchmarkEntry(
            battle_id=battle_id,
            scenario_name=scenario.name,
            date=scenario.date,
            source=source,
            historical_winner=scenario.historical_outcome.winner,
        )
        try:
            blue_battalions, red_battalions = scenario.build_battalions()
            terrain = scenario.build_terrain()
            rng = np.random.default_rng(self.seed)
            result: EpisodeResult = SimEngine(
                blue_battalions[0],
                red_battalions[0],
                terrain=terrain,
                rng=rng,
            ).run()
            comparator = OutcomeComparator(scenario.historical_outcome)
            entry.comparison = comparator.compare(result)
        except Exception as exc:  # noqa: BLE001
            entry.error = str(exc)
        finally:
            entry.elapsed_seconds = time.perf_counter() - start
        return entry

run()

Run the full benchmark and return a BenchmarkSummary.

Each scenario is run as a 1v1 simulation using the first blue battalion against the first red battalion (matching the existing test pattern in tests/test_historical_scenarios.py).
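The aggregation step at the end of `run()` can be sketched in isolation. Below, entries are stand-in tuples rather than real `BenchmarkEntry` objects, and the fidelity values are invented:

```python
import numpy as np

# Stand-ins for BenchmarkEntry: (passed, fidelity_score, winner_matches).
entries = [
    (True, 0.81, True),
    (True, 0.64, False),
    (False, 0.0, False),   # errored scenario — excluded from aggregates
    (True, 0.90, True),
]

# Only scenarios that ran to completion contribute to the statistics.
passed = [e for e in entries if e[0]]
fidelity_scores = [e[1] for e in passed]
winner_matches = [e[2] for e in passed]

mean_fidelity = float(np.mean(fidelity_scores)) if fidelity_scores else 0.0
std_fidelity = float(np.std(fidelity_scores)) if fidelity_scores else 0.0
winner_match_rate = (
    float(np.mean([float(w) for w in winner_matches])) if winner_matches else 0.0
)
print(f"{len(passed)}/{len(entries)} passed, "
      f"winner match {winner_match_rate:.0%}, "
      f"fidelity {mean_fidelity:.2f} ± {std_fidelity:.2f}")
```

The `if … else 0.0` guards mirror the real code's handling of the degenerate case where every scenario errors out and the aggregate lists are empty.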

Source code in training/historical_benchmark.py
def run(self) -> BenchmarkSummary:
    """Run the full benchmark and return a :class:`BenchmarkSummary`.

    Each scenario is run as a 1v1 simulation using the first blue
    battalion against the first red battalion (matching the existing
    test pattern in ``tests/test_historical_scenarios.py``).
    """
    importer = BatchScenarioImporter(self.battles_path)
    records = importer.load_records()
    scenarios = [r.to_scenario() for r in records]

    entries: List[BenchmarkEntry] = []
    overall_start = time.perf_counter()

    for rec, scenario in zip(records, scenarios):
        entry = self._run_scenario(rec.battle_id, rec.source, scenario)
        entries.append(entry)

    total_elapsed = time.perf_counter() - overall_start

    # Aggregate
    passed_entries = [e for e in entries if e.passed]
    fidelity_scores = [e.fidelity_score for e in passed_entries]
    winner_matches = [e.winner_matches for e in passed_entries]

    mean_fidelity = float(np.mean(fidelity_scores)) if fidelity_scores else 0.0
    std_fidelity = float(np.std(fidelity_scores)) if fidelity_scores else 0.0
    winner_match_rate = (
        float(np.mean([float(w) for w in winner_matches]))
        if winner_matches else 0.0
    )

    return BenchmarkSummary(
        total=len(entries),
        passed=len(passed_entries),
        winner_match_rate=winner_match_rate,
        mean_fidelity=mean_fidelity,
        std_fidelity=std_fidelity,
        total_elapsed_seconds=total_elapsed,
        entries=entries,
    )

write_markdown(summary, output_path=_BENCHMARK_MD)

Write the benchmark results to a Markdown file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `summary` | `BenchmarkSummary` | The `BenchmarkSummary` returned by `run`. | required |
| `output_path` | `str \| Path` | Destination file path. Parent directories are created if they do not already exist. | `_BENCHMARK_MD` |

Returns:

| Type | Description |
| --- | --- |
| `Path` | The path to the written file. |

Source code in training/historical_benchmark.py
def write_markdown(
    self,
    summary: BenchmarkSummary,
    output_path: str | Path = _BENCHMARK_MD,
) -> Path:
    """Write the benchmark results to a Markdown file.

    Parameters
    ----------
    summary:
        The :class:`BenchmarkSummary` returned by :meth:`run`.
    output_path:
        Destination file path.  Parent directories are created if
        they do not already exist.

    Returns
    -------
    Path
        The path to the written file.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(
        _render_markdown(summary),
        encoding="utf-8",
    )
    return output_path
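The file handling in `write_markdown` (create parent directories on demand, write UTF-8) is plain `pathlib`. A standalone sketch using a temporary directory, with a dummy renderer standing in for the module-private `_render_markdown` helper:

```python
import tempfile
from pathlib import Path


def render_markdown(summary: dict) -> str:
    # Dummy stand-in for _render_markdown — the real helper formats
    # a full BenchmarkSummary, including the per-scenario entries.
    return f"# Benchmark\n\npassed: {summary['passed']}/{summary['total']}\n"


with tempfile.TemporaryDirectory() as tmp:
    # Nested parent directories are created on demand, as in write_markdown.
    output_path = Path(tmp) / "reports" / "benchmark.md"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(
        render_markdown({"passed": 48, "total": 52}),
        encoding="utf-8",
    )
    text = output_path.read_text(encoding="utf-8")
    print(text)
```

Because `mkdir` is called with `exist_ok=True`, repeated benchmark runs can safely overwrite the same report path without raising.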