
API Reference

Auto-generated API documentation for all stable public interfaces. Each package exposes its symbols at the top-level namespace, so you only need a single import per package.

from envs     import BattalionEnv, RewardWeights, Formation
from models   import BattalionMlpPolicy, MAPPOPolicy, WFM1Policy
from training import train, TrainingConfig, evaluate, OpponentPool
from analysis import COAGenerator, SaliencyAnalyzer, compute_gradient_saliency

Packages

  • envs — simulation environments, reward shaping, sim primitives, HRL options
  • models — MLP, MAPPO, transformer, recurrent, and WFM-1 policies
  • training — training runners, callbacks, evaluation, self-play, Elo, benchmarks
  • analysis — COA generation and policy saliency
  • benchmarks — WargamesBench runner

Top-level envs symbols

envs

Wargames Training — environments public API.

Stable interfaces for all simulation environments, reward shaping, simulation primitives, and scenario utilities. Import from this module rather than from individual submodules to stay insulated from internal restructuring.

Environments

:class:BattalionEnv Single-battalion 1v1 Gymnasium environment (continuous action space).

:class:BrigadeEnv Brigade-level multi-battalion environment (MultiDiscrete actions).

:class:DivisionEnv Division-level multi-brigade environment (MultiDiscrete actions).

:class:CorpsEnv Corps-level operational environment with supply and road networks.

:class:CavalryCorpsEnv Corps environment extended with cavalry reconnaissance.

:class:ArtilleryCorpsEnv Corps environment extended with artillery and fortification mechanics.

:class:MultiBattalionEnv Vectorised multi-battalion PettingZoo parallel environment.

Reward shaping

:class:RewardWeights Dataclass of per-component reward multipliers.

:class:RewardComponents Named-tuple of per-step reward component values.

:func:compute_reward Compute a single-step reward given state deltas and weights.

Environment configuration types

:class:LogisticsConfig / :class:LogisticsState / :class:SupplyWagon Logistics and supply management.

:class:MoraleConfig Morale dynamics parameters.

:class:Formation Unit formation enum (LINE / COLUMN / SQUARE / SKIRMISH).

:class:WeatherConfig / :class:WeatherState Weather simulation parameters.

:class:RedPolicy Protocol for scripted / learned Red opponent policies.

Simulation primitives

:class:SimEngine Headless deterministic 1v1 simulation engine.

:class:EpisodeResult Structured result from a :class:SimEngine episode.

HRL options framework

:class:MacroAction Discrete high-level action enum.

:class:Option Temporal abstraction option (initiation set, policy, termination).

:func:make_default_options Build the default option set for the battalion environment.

:class:SMDPWrapper Semi-MDP wrapper that executes options as macro-actions.

Corps constants

:data:CORPS_OBS_DIM, :data:N_CORPS_SECTORS, :data:N_OBJECTIVES, :data:CORPS_MAP_WIDTH, :data:CORPS_MAP_HEIGHT, :data:N_ROAD_FEATURES

BattalionEnv

Bases: Env

1v1 battalion RL environment.

The agent controls the Blue battalion; Red is driven either by a built-in scripted opponent (controlled by curriculum_level) or by an optional red_policy (any object with a predict method, e.g. a Stable-Baselines3 PPO model). When red_policy is provided it takes precedence over the scripted opponent and curriculum_level is ignored for movement and fire decisions.

Curriculum levels (scripted opponent only)

Level  Red opponent behaviour
1      Stationary — Red does not move or fire.
2      Turning only — Red faces Blue but stays put.
3      Advance only — Red turns and advances; no fire.
4      Soft fire — Red turns, advances, fires at 50 % intensity.
5      Full combat — Red turns, advances, fires at 100 % intensity (default).

Observation space — Box(shape=(17,), dtype=float32) (formations disabled)

Index  Feature                      Range
0      blue x / map_width           [0, 1]
1      blue y / map_height          [0, 1]
2      cos(blue θ)                  [-1, 1]
3      sin(blue θ)                  [-1, 1]
4      blue strength                [0, 1]
5      blue morale                  [0, 1]
6      distance to red / diagonal   [0, 1]
7      cos(bearing to red)          [-1, 1]
8      sin(bearing to red)          [-1, 1]
9      red strength                 [0, 1]
10     red morale                   [0, 1]
11     step / max_steps             [0, 1]
12     blue elevation (normalised)  [0, 1]
13     blue cover                   [0, 1]
14     red elevation (normalised)   [0, 1]
15     red cover                    [0, 1]
16     line-of-sight (0 = blocked)  [0, 1]

When enable_formations is True, two extra dimensions are appended:

Index  Feature                                Range
17     blue formation / (NUM_FORMATIONS − 1)  [0, 1]
18     blue transitioning (1 = yes)           [0, 1]

When enable_logistics is True, three more dimensions follow the terrain / formations dims:

Index  Feature             Range
N+0    blue ammo level     [0, 1]
N+1    blue food level     [0, 1]
N+2    blue fatigue level  [0, 1]

(where N = 17 with formations disabled or 19 with formations enabled).

When enable_weather is True, two more dimensions are appended after any logistics dims:

Index  Feature                       Range
M+0    weather condition id          [0, 1]
M+1    combined visibility fraction  [0, 1]

(where M = N + 3 when logistics are enabled, or N when disabled).
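The incremental layout above can be summarised in a small helper. This is an illustrative sketch (the function is not part of the API): base 17 dims, plus 2 for formations, 3 for logistics, and 2 for weather.

```python
def observation_dim(enable_formations=False, enable_logistics=False,
                    enable_weather=False):
    """Flat observation length under the documented feature flags."""
    dim = 17                  # base terrain-aware layout (indices 0-16)
    if enable_formations:
        dim += 2              # formation index + transitioning flag
    if enable_logistics:
        dim += 3              # ammo, food, fatigue
    if enable_weather:
        dim += 2              # weather condition id + visibility fraction
    return dim
```

For example, with all three flags enabled the observation is 24-dimensional.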

Action space — Box(shape=(3,), dtype=float32) (formations disabled)

Index  Name    Range    Effect
0      move    [-1, 1]  Scales max_speed; positive = forward
1      rotate  [-1, 1]  Scales max_turn_rate; positive = CCW
2      fire    [0, 1]   Fire intensity this step

When enable_formations is True, a fourth dimension is added:

Index  Name       Range   Effect
3      formation  [0, 3]  Desired formation (rounded to nearest int)
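The documented action semantics — clipping move/rotate to [-1, 1], fire to [0, 1], and rounding the optional formation dimension — can be sketched as a plain function. This is illustrative only; the environment applies these semantics internally.

```python
def interpret_action(action, enable_formations=False):
    """Map a raw action vector to the documented control semantics."""
    move = max(-1.0, min(1.0, action[0]))     # scales max_speed
    rotate = max(-1.0, min(1.0, action[1]))   # scales max_turn_rate
    fire = max(0.0, min(1.0, action[2]))      # fire intensity this step
    out = {"move": move, "rotate": rotate, "fire": fire}
    if enable_formations:
        # 4th dim: desired formation index, rounded to nearest int in [0, 3]
        out["formation"] = int(round(max(0.0, min(3.0, action[3]))))
    return out
```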

Parameters:

Name Type Description Default
map_width float

Map dimensions in metres (default 1 km × 1 km).

MAP_WIDTH
map_height float

Map dimensions in metres (default 1 km × 1 km).

MAP_HEIGHT
max_steps int

Episode length cap (default 500).

MAX_STEPS
terrain Optional[TerrainMap]

Optional :class:~envs.sim.terrain.TerrainMap. When supplied, randomize_terrain is forced to False and this fixed map is used for every episode. Defaults to a flat open plain.

None
randomize_terrain bool

When True (default) and no fixed terrain is supplied, a new procedural terrain is generated from the seeded RNG at the start of each episode. Set to False to keep a static flat plain.

True
hill_speed_factor float

Movement speed multiplier applied to units on maximum-elevation terrain. Must be in (0, 1]. A value of 0.5 means units on the highest hill travel at half their normal speed; 1.0 disables the hill penalty entirely.

0.5
curriculum_level int

Scripted Red opponent difficulty (1–5). Level 1 is the easiest (stationary target); level 5 is full combat. Defaults to 5. Ignored when red_policy is provided.

5
reward_weights Optional[RewardWeights]

:class:~envs.reward.RewardWeights instance with per-component multipliers. Defaults to RewardWeights() (standard shaped reward with the legacy coefficients).

None
red_policy Optional[RedPolicy]

Optional policy object for driving the Red battalion. Must expose a predict(obs, deterministic=False) -> (action, state) method (satisfied by any SB3 model or policy). When supplied, the scripted opponent is bypassed. Use :meth:set_red_policy to swap the policy at runtime (e.g. from a training callback).

None
render_mode Optional[str]

Render mode. None disables rendering. "human" opens a pygame window and renders the simulation in real time; each call to :meth:render displays the current frame.

None
enable_formations bool

When True, the formation system is activated. The action space gains a 4th dimension (desired formation index 0–3) and the observation space gains two extra dimensions (current formation normalised and a transitioning flag). Formation modifiers are applied to firepower, movement speed, and morale resilience (the morale hit from casualties is divided by the unit's morale_resilience before being fed into the morale check). Defaults to False to preserve full backward compatibility.

False
enable_logistics bool

When True, the supply, ammunition, and fatigue model is activated. The observation space gains three extra dimensions (blue ammo, food, fatigue — all normalised to [0, 1]). Ammunition is consumed when firing (weapon jams at zero); fatigue accumulates from movement and combat, penalising speed and accuracy; and battalions can resupply by halting near a friendly supply wagon. Pass a :class:~envs.sim.logistics.LogisticsConfig via logistics_config to tune the rates; if None a default config is used. Defaults to False to preserve full backward compatibility.

False
logistics_config Optional[LogisticsConfig]

Optional :class:~envs.sim.logistics.LogisticsConfig instance. Used when enable_logistics is True. Defaults to LogisticsConfig() if not supplied.

None
enable_weather bool

When True, the weather and time-of-day model is activated. A :class:~envs.sim.weather.WeatherState is sampled at each reset() (or fixed via weather_config). The observation space gains two extra dimensions: the normalised weather condition id and the combined visibility fraction. Weather modifiers are applied to LOS range, fire accuracy, movement speed, and morale. Defaults to False to preserve full backward compatibility.

False
weather_config Optional[WeatherConfig]

Optional :class:~envs.sim.weather.WeatherConfig instance. Used when enable_weather is True. Defaults to WeatherConfig() if not supplied.

None

close()

Clean up resources, including the pygame window if open.

render()

Render the current environment state.

When render_mode="human" a pygame window is opened on the first call and kept alive for subsequent calls. The window is closed by :meth:close. When render_mode is None this is a no-op.

reset(*, seed=None, options=None)

Reset the environment and return the initial observation.

When randomize_terrain is True (the default when no fixed terrain map is passed to __init__), a new procedural terrain is generated from the seeded RNG on every call. Passing the same seed therefore always produces the same terrain layout and unit positions.

Blue spawns in the western half of the map facing roughly east; Red spawns in the eastern half facing roughly west.

set_red_policy(policy)

Swap the Red opponent policy at runtime.

Parameters:

Name Type Description Default
policy Optional[RedPolicy]

New policy to use for Red, or None to revert to the scripted opponent.

required

step(action)

Advance the environment by one step.

Parameters:

Name Type Description Default
action ndarray

Array of shape (3,): [move, rotate, fire] — or (4,) with a trailing formation value when enable_formations is True.

required

Returns:

Type Description
(observation, reward, terminated, truncated, info)

RewardWeights dataclass

Configurable multipliers for each reward component.

All weights are applied by multiplying against their respective raw component value before summing into the total reward. Set a weight to 0.0 to disable that component entirely.

Parameters:

Name Type Description Default
delta_enemy_strength float

Multiplier applied to the fraction of enemy strength destroyed in a step (dmg_b2r). Encourages the agent to deal damage.

5.0
delta_own_strength float

Multiplier applied to the fraction of own strength lost in a step (dmg_r2b). The contribution is negated before summing.

5.0
survival_bonus float

Per-step bonus scaled by Blue's current strength. Set to 0.0 (default) to disable; a small positive value (e.g. 0.005) rewards staying alive longer.

0.0
win_bonus float

Terminal reward added when Blue wins (Red routed or destroyed).

10.0
loss_penalty float

Terminal reward added when Blue loses (Blue routed or destroyed). Should be negative.

-10.0
time_penalty float

Constant added every step. A small negative value (e.g. -0.01) discourages unnecessary stalling.

-0.01

enemy_routed_bonus = 0.0 class-attribute instance-attribute

Bonus added each step that the enemy is in a routing state.

A positive value (e.g. 2.0) encourages pursuit of routing enemies and exploitation of broken units. Set to 0.0 (default) to disable.

own_routing_penalty = 0.0 class-attribute instance-attribute

Penalty added each step that the agent's own unit is routing.

A negative value (e.g. -2.0) discourages allowing the agent's battalion to be broken. Set to 0.0 (default) to disable.
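How the weights combine into a per-step reward can be sketched as a weighted sum. This is an illustrative stand-in, not the actual compute_reward implementation; the local Weights dataclass simply mirrors the fields documented above.

```python
from dataclasses import dataclass

@dataclass
class Weights:                          # mirrors RewardWeights' documented fields
    delta_enemy_strength: float = 5.0
    delta_own_strength: float = 5.0
    survival_bonus: float = 0.0
    win_bonus: float = 10.0
    loss_penalty: float = -10.0
    time_penalty: float = -0.01
    enemy_routed_bonus: float = 0.0
    own_routing_penalty: float = 0.0

def step_reward(w, dmg_b2r, dmg_r2b, blue_strength,
                enemy_routing=False, own_routing=False,
                won=False, lost=False):
    """Weighted sum of the per-step components documented above."""
    r = w.delta_enemy_strength * dmg_b2r    # enemy strength destroyed
    r -= w.delta_own_strength * dmg_r2b     # own strength lost (negated)
    r += w.survival_bonus * blue_strength   # per-step, scaled by own strength
    r += w.time_penalty                     # constant per-step cost
    if enemy_routing:
        r += w.enemy_routed_bonus
    if own_routing:
        r += w.own_routing_penalty
    if won:
        r += w.win_bonus                    # terminal bonuses
    if lost:
        r += w.loss_penalty
    return r
```

Setting any weight to 0.0 removes that component from the sum, matching the note above.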

Formation

Bases: IntEnum

Discrete formation states for a Napoleonic infantry battalion.

The integer values are stable indices used in observation vectors and action spaces; do not change them without updating downstream code.

Formation  Int  Historical role
LINE       0    Two- or three-rank line — maximum firepower front.
COLUMN     1    Attack or march column — highest movement speed.
SQUARE     2    Hollow square — impenetrable to unsupported cavalry.
SKIRMISH   3    Extended order — loose screen, independent aimed fire.
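A self-contained sketch of the enum and of the observation normalisation (formation / (NUM_FORMATIONS − 1)) described earlier; the real definition lives in envs.

```python
from enum import IntEnum

class Formation(IntEnum):   # same stable indices as the table above
    LINE = 0      # maximum firepower front
    COLUMN = 1    # highest movement speed
    SQUARE = 2    # anti-cavalry square
    SKIRMISH = 3  # loose skirmish screen

NUM_FORMATIONS = len(Formation)

def encode(formation: Formation) -> float:
    """Observation encoding: formation index normalised into [0, 1]."""
    return formation / (NUM_FORMATIONS - 1)
```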

SimEngine

Run a 1v1 battalion episode to completion.

Each :meth:step:

  1. Resets per-step damage accumulators on both :class:CombatState objects.
  2. Computes fire damage from both sides simultaneously (before applying either, so there are no ordering effects on damage calculation).
  3. Applies terrain cover to reduce incoming damage at each unit's position.
  4. Applies casualties and updates strength.
  5. Runs a morale check for each unit, potentially triggering routing.

The episode ends (:meth:is_over returns True) when:

  • Either side's :class:CombatState reports is_routing = True, or
  • Either side's strength falls to :data:DESTROYED_THRESHOLD or below, or
  • step_count reaches max_steps.
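The three termination conditions above amount to a small predicate. A pure-Python sketch, where destroyed_threshold stands in for the :data:DESTROYED_THRESHOLD constant (its real value may differ):

```python
def episode_over(blue_strength, red_strength,
                 blue_routing, red_routing,
                 step_count, max_steps,
                 destroyed_threshold=0.1):
    """Mirror of SimEngine.is_over's documented termination conditions."""
    if blue_routing or red_routing:                       # either side routed
        return True
    if (blue_strength <= destroyed_threshold
            or red_strength <= destroyed_threshold):      # either side destroyed
        return True
    return step_count >= max_steps                        # step cap reached
```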

Parameters:

Name Type Description Default
blue Battalion

The Blue battalion; modified in-place by each step.

required
red Battalion

The Red battalion; modified in-place by each step.

required
terrain Optional[TerrainMap]

Optional :class:~envs.sim.terrain.TerrainMap. Defaults to a flat 1 km × 1 km open plain.

None
max_steps int

Hard cap on episode length (default 500, matching acceptance criterion AC-1 of Epic E1.2).

500
rng Optional[Generator]

Seeded random generator. Defaults to a fresh unseeded generator. Pass a seeded generator for reproducible results.

None

is_over()

Return True when the episode should end.

run()

Run the episode to completion and return a result summary.

step()

Advance one simulation step.

Returns:

Type Description
dict with keys:

  • blue_damage_dealt — actual damage Blue dealt to Red this step.
  • red_damage_dealt — actual damage Red dealt to Blue this step.
  • blue_routing — whether Blue is routing after this step.
  • red_routing — whether Red is routing after this step.

Top-level models symbols

models

Neural network models for wargames_training.

Exports

MLP policy (SB3-compatible) :class:~models.mlp_policy.BattalionMlpPolicy — MLP actor-critic for :class:~envs.battalion_env.BattalionEnv.

MAPPO multi-agent policy :class:~models.mappo_policy.MAPPOActor — stochastic Gaussian actor. :class:~models.mappo_policy.MAPPOCritic — centralised value critic. :class:~models.mappo_policy.MAPPOPolicy — combined actor-critic policy.

Entity encoder (transformer) :data:~models.entity_encoder.ENTITY_TOKEN_DIM — token dimensionality. :class:~models.entity_encoder.SpatialPositionalEncoding — 2D Fourier encoding. :class:~models.entity_encoder.EntityEncoder — multi-head transformer encoder. :class:~models.entity_encoder.EntityActorCriticPolicy — entity-based actor-critic.

Recurrent policy (LSTM) :class:~models.recurrent_policy.LSTMHiddenState — LSTM hidden-state container. :class:~models.recurrent_policy.RecurrentEntityEncoder — entity encoder + LSTM. :class:~models.recurrent_policy.RecurrentActorCriticPolicy — LSTM actor-critic. :class:~models.recurrent_policy.RecurrentRolloutBuffer — BPTT rollout buffer.

WFM-1 foundation model :class:~models.wfm1.WFM1Policy — multi-echelon foundation model. :class:~models.wfm1.ScenarioCard — scenario descriptor for WFM-1. :class:~models.wfm1.EchelonEncoder — per-echelon entity encoder. :class:~models.wfm1.CrossEchelonTransformer — cross-echelon transformer.

BattalionMlpPolicy

Bases: ActorCriticPolicy

MLP actor-critic policy for BattalionEnv.

A configurable fully-connected network with separate actor and critic heads, designed for the flat Box observation space of :class:~envs.battalion_env.BattalionEnv.

Architecture (default)::

obs(12) → Linear(128) → Tanh → Linear(128) → Tanh
        ↓                                         ↓
actor head → action(3) + log_std        critic head → value(1)

Parameters:

Name Type Description Default
observation_space Space

Gymnasium observation space.

required
action_space Space

Gymnasium action space.

required
lr_schedule Schedule

Learning-rate schedule passed in by the SB3 algorithm.

required
net_arch Optional[List[int]]

Hidden-layer sizes shared by actor and critic. Defaults to [128, 128].

None
activation_fn Type[Module]

Activation function class applied after each hidden layer. Defaults to :class:torch.nn.Tanh.

Tanh
**kwargs Any

Forwarded to :class:~stable_baselines3.common.policies.ActorCriticPolicy.

{}

MAPPOPolicy

Bases: Module

MAPPO policy: shared actor(s) plus a centralized critic.

Supports two parameter-sharing modes for the actor:

  • share_parameters=True (default) — all n_agents agents use the same :class:MAPPOActor weights. Memory scales as O(actor_params + critic_params) instead of O(n_agents * actor_params).
  • share_parameters=False — each agent gets its own :class:MAPPOActor; useful for ablation studies.

In both cases there is a single :class:MAPPOCritic that all agents share.
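The memory trade-off between the two modes is simple arithmetic. A sketch with purely hypothetical parameter counts (the real counts depend on obs_dim, action_dim, and the hidden sizes):

```python
def total_params(actor_params, critic_params, n_agents, share_parameters=True):
    """Total parameters under the two sharing modes documented above."""
    if share_parameters:
        return actor_params + critic_params           # one shared actor
    return n_agents * actor_params + critic_params    # one actor per agent

# Hypothetical example: a 50k-parameter actor, a 60k critic, 8 agents.
shared = total_params(50_000, 60_000, 8, share_parameters=True)
separate = total_params(50_000, 60_000, 8, share_parameters=False)
```

With sharing, cost stays O(actor_params + critic_params); without it, the actor term grows linearly in n_agents.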

Parameters:

Name Type Description Default
obs_dim int

Per-agent local observation dimensionality.

required
action_dim int

Per-agent action dimensionality (continuous).

required
state_dim int

Global state dimensionality.

required
n_agents int

Number of controlled agents (used only when share_parameters=False to build separate actor heads).

1
share_parameters bool

Whether all agents share one actor. Defaults to True.

True
actor_hidden_sizes Tuple[int, ...]

Hidden layer sizes for the actor trunk.

(128, 64)
critic_hidden_sizes Tuple[int, ...]

Hidden layer sizes for the critic trunk.

(128, 64)

act(obs, agent_idx=0, deterministic=False)

Sample actions for a batch of observations.

A batch dimension is added internally when obs is 1-D so that the return shapes are always (batch, action_dim) and (batch,) regardless of whether the caller passes a single observation or a batch.

Parameters:

Name Type Description Default
obs Tensor

Local observations of shape (batch, obs_dim) or (obs_dim,) for a single observation.

required
agent_idx int

Index of the agent whose actor should be used. Ignored when share_parameters=True.

0
deterministic bool

When True returns the distribution mean instead of sampling.

False

Returns:

Name Type Description
actions torch.Tensor — shape ``(batch, action_dim)``
log_probs torch.Tensor — shape ``(batch,)``

evaluate_actions(obs, actions, state, agent_idx=0)

Evaluate actions under the current policy for the PPO update.

Parameters:

Name Type Description Default
obs Tensor

Local observations of shape (batch, obs_dim).

required
actions Tensor

Actions of shape (batch, action_dim).

required
state Tensor

Global states of shape (batch, state_dim).

required
agent_idx int

Actor index (ignored when sharing parameters).

0

Returns:

Name Type Description
log_probs torch.Tensor — shape ``(batch,)``
entropy torch.Tensor — shape ``(batch,)``
values torch.Tensor — shape ``(batch,)``

get_actor(agent_idx=0)

Return the actor for agent_idx.

When share_parameters=True the same actor is returned regardless of agent_idx.

get_value(state)

Return value estimate(s) for the given global state(s).

Parameters:

Name Type Description Default
state Tensor

Global state tensor of shape (state_dim,) or (batch, state_dim).

required

Returns:

Name Type Description
values torch.Tensor — scalar or shape ``(batch,)``

parameter_count()

Return a dict with actor and critic parameter counts.

EntityEncoder

Bases: Module

Multi-head self-attention encoder over a variable-length entity sequence.

Architecture::

entity tokens (B, N, token_dim)
      ↓
token_embed: Linear(token_dim → d_model)
      + SpatialPositionalEncoding(x, y)  [optional]
      ↓
TransformerEncoder (n_layers × TransformerEncoderLayer)
   └── MultiheadAttention (n_heads, d_model)
   └── FFN (dim_feedforward = 4 × d_model)
   └── LayerNorm + residual
      ↓
mean-pool over non-padded entities → (B, d_model)
      ↓
output projection: Linear(d_model → d_model)  [identity-initialised]

Padding is handled via src_key_padding_mask: a boolean tensor of shape (B, N) where True marks padded (ignored) positions.

Parameters:

Name Type Description Default
token_dim int

Dimensionality of each input entity token. Defaults to ENTITY_TOKEN_DIM (16).

ENTITY_TOKEN_DIM
d_model int

Internal transformer dimension.

64
n_heads int

Number of attention heads. Must evenly divide d_model.

4
n_layers int

Number of transformer encoder layers.

2
dim_feedforward Optional[int]

Feed-forward sublayer width. Defaults to 4 * d_model.

None
dropout float

Dropout probability inside the transformer.

0.0
use_spatial_pe bool

When True, add a 2-D Fourier positional encoding derived from the position field (x, y) of each token.

True
n_freq_bands int

Number of Fourier frequency bands used by :class:SpatialPositionalEncoding when use_spatial_pe=True.

8

output_dim property

Dimensionality of the pooled output vector.

forward(tokens, pad_mask=None, return_attention=False)

Encode a batch of entity sequences.

Parameters:

Name Type Description Default
tokens Tensor

Entity token tensor of shape (B, N, token_dim).

required
pad_mask Optional[Tensor]

Boolean mask of shape (B, N). True marks padded (ignored) positions. If None, all positions are attended.

None
return_attention bool

When True, also return the averaged attention weights from the last transformer layer as a second return value of shape (B, N, N).

False

Returns:

Name Type Description
encoding torch.Tensor — shape ``(B, d_model)``

Mean-pooled sequence encoding.

attn_weights torch.Tensor — shape ``(B, N, N)``

Averaged attention weights from the last layer. Only returned when return_attention=True.

make_padding_mask(n_valid, max_n) staticmethod

Create a padding mask from per-sample entity counts.

Parameters:

Name Type Description Default
n_valid Tensor

Integer tensor of shape (B,) with the number of real (non-padded) entities in each sample.

required
max_n int

Total number of positions in the padded sequence.

required

Returns:

Name Type Description
pad_mask torch.BoolTensor — shape ``(B, max_n)``

True where the position is padded (ignored).
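The mask construction is straightforward: position i in sample b is padding iff i >= n_valid[b]. A pure-Python sketch (the real method takes and returns torch tensors):

```python
def make_padding_mask(n_valid, max_n):
    """True where the position is padded (ignored), per sample.

    Pure-Python stand-in for EntityEncoder.make_padding_mask, which
    operates on a (B,) integer tensor and returns a (B, max_n) BoolTensor.
    """
    return [[pos >= n for pos in range(max_n)] for n in n_valid]

# Batch of two samples with 2 and 4 real entities, padded to length 5:
mask = make_padding_mask([2, 4], 5)
```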

RecurrentActorCriticPolicy

Bases: Module

Actor-critic policy with LSTM temporal memory.

Both actor and critic run through the same :class:RecurrentEntityEncoder (weight-sharing optional). Hidden states are passed in/out explicitly so the caller controls episode boundaries and checkpointing.

Parameters:

Name Type Description Default
token_dim int

Entity token dimensionality.

ENTITY_TOKEN_DIM
action_dim int

Continuous action space dimensionality.

3
d_model int

Transformer internal dimension.

64
n_heads int

Attention heads in the entity encoder.

4
n_layers int

Transformer encoder layers.

2
lstm_hidden_size int

LSTM hidden state dimensionality.

128
lstm_num_layers int

Number of stacked LSTM layers.

1
actor_hidden_sizes Tuple[int, ...]

MLP hidden sizes on top of the LSTM output for the actor head.

(128, 64)
critic_hidden_sizes Tuple[int, ...]

MLP hidden sizes on top of the LSTM output for the critic head.

(128, 64)
shared_encoder bool

When True (default), actor and critic share the same :class:RecurrentEntityEncoder weights.

True
dropout float

Dropout probability.

0.0
use_spatial_pe bool

Enable 2-D Fourier positional encoding.

True

act(tokens, hx, pad_mask=None, deterministic=False)

Sample (or select deterministically) an action for a single step.

Parameters:

Name Type Description Default
tokens Tensor

Shape (B, N, token_dim) or (N, token_dim) for single sample.

required
hx LSTMHiddenState

Current LSTM hidden state.

required
pad_mask Optional[Tensor]

Padding mask.

None
deterministic bool

Return the distribution mean instead of sampling.

False

Returns:

Name Type Description
actions torch.Tensor — shape ``(B, action_dim)``
log_probs torch.Tensor — shape ``(B,)``
new_hx :class:`LSTMHiddenState` — updated hidden state

evaluate_actions(tokens_seq, hx, actions_seq, pad_mask_seq=None)

Evaluate log-probs, entropy and values for a sequence of actions.

Used during PPO update to compute the ratio π_new / π_old.

Parameters:

Name Type Description Default
tokens_seq Tensor

Entity token sequences of shape (B, T, N, token_dim).

required
hx LSTMHiddenState

Initial LSTM hidden state at the start of the sequence.

required
actions_seq Tensor

Actions to evaluate, shape (B, T, action_dim).

required
pad_mask_seq Optional[Tensor]

Padding mask of shape (B, T, N) or None.

None

Returns:

Name Type Description
log_probs torch.Tensor — shape ``(B, T)``
entropy torch.Tensor — shape ``(B, T)``
values torch.Tensor — shape ``(B, T)``

get_distribution(tokens, hx, pad_mask=None)

Compute action distribution for a single timestep.

Parameters:

Name Type Description Default
tokens Tensor

Shape (B, N, token_dim) or (N, token_dim) for unbatched.

required
hx LSTMHiddenState

Current LSTM hidden state.

required
pad_mask Optional[Tensor]

Boolean padding mask (B, N) or (N,) for unbatched.

None

Returns:

Name Type Description
dist :class:`~torch.distributions.Normal`
new_hx :class:`LSTMHiddenState`

get_value(tokens, hx, pad_mask=None)

Compute value estimates for a single timestep.

Parameters:

Name Type Description Default
tokens Tensor

Shape (B, N, token_dim) or (N, token_dim) for single sample.

required
hx LSTMHiddenState

Current LSTM hidden state.

required
pad_mask Optional[Tensor]

Padding mask.

None

Returns:

Name Type Description
values torch.Tensor — shape ``(B,)``
new_hx :class:`LSTMHiddenState`

initial_state(batch_size=1, device=None)

Return a zero-initialised LSTM hidden state.

Call at the start of each episode to reset temporal memory.

Parameters:

Name Type Description Default
batch_size int

Number of parallel environments / samples.

1
device Optional[device]

Target device.

None

load_checkpoint(path, device=None, **kwargs) classmethod

Restore a policy from a checkpoint created by :meth:save_checkpoint.

Parameters:

Name Type Description Default
path str | Path

Path to the checkpoint file.

required
device Optional[device]

Device to load weights onto.

None
**kwargs

Constructor arguments — must match those used when the checkpoint was saved.

{}

Returns:

Name Type Description
policy :class:`RecurrentActorCriticPolicy`

parameter_count()

Return a dict with actor and critic parameter counts.

save_checkpoint(path)

Persist model weights to path (torch.save format).

Parameters:

Name Type Description Default
path str | Path

Destination file path. Parent directories are created if needed.

required

WFM1Policy

Bases: Module

WFM-1 hierarchical transformer policy.

A single policy that operates across battalion, brigade, division, and corps echelons. Multi-echelon information is fused via cross-echelon attention. A lightweight ScenarioCard FiLM adapter conditions the policy on scenario metadata, enabling efficient fine-tuning.

Architecture::

For each active echelon e ∈ {battalion, brigade, division, corps}:
  tokens_e (B, N_e, token_dim)
        ↓
  EchelonEncoder(echelon=e) → enc_e (B, d_model)

Stack: echelon_encs = [enc_e₁, enc_e₂, …] → shape (B, E, d_model)
        ↓
CrossEchelonTransformer → fused (B, d_model)
        ↓
FiLM(ScenarioCard): fused → fused × γ + β
        ↓
actor head  → Gaussian action distribution
critic head → scalar value
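The FiLM step in the diagram is an elementwise affine modulation of the fused encoding. A pure-Python sketch; in WFM-1, γ and β are produced by the ScenarioCard adapter MLP, whereas here they are plain lists for illustration:

```python
def film(fused, gamma, beta):
    """Feature-wise linear modulation: out[i] = fused[i] * gamma[i] + beta[i]."""
    return [f * g + b for f, g, b in zip(fused, gamma, beta)]
```

Because γ and β are the only scenario-dependent parameters, fine-tuning just the adapter reconditions the frozen base model cheaply.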

Parameters:

Name Type Description Default
token_dim int

Entity token dimensionality. Defaults to :data:~models.entity_encoder.ENTITY_TOKEN_DIM.

ENTITY_TOKEN_DIM
action_dim int

Continuous action space dimensionality.

3
d_model int

Transformer hidden dimension.

128
n_heads int

Attention heads for both echelon encoder and cross-echelon transformer.

8
n_echelon_layers int

Transformer depth for each :class:EchelonEncoder.

4
n_cross_layers int

Transformer depth for :class:CrossEchelonTransformer.

2
actor_hidden_sizes Tuple[int, ...]

MLP hidden sizes for the actor head.

(256, 128)
critic_hidden_sizes Tuple[int, ...]

MLP hidden sizes for the critic head.

(256, 128)
dropout float

Dropout probability (transformer only).

0.0
use_spatial_pe bool

Enable 2-D Fourier positional encoding in the echelon encoders.

True
share_echelon_encoders bool

When True (default), all four echelon levels share a single :class:EchelonEncoder instance. When False each echelon has independent weights.

True
card_hidden_size int

Width of the FiLM adapter MLP.

64

act(tokens, pad_mask=None, echelon=ECHELON_BATTALION, card=None, card_vec=None, tokens_per_echelon=None, pad_masks=None, deterministic=False)

Sample actions (no gradient).

Can be called with a single-echelon tokens tensor or with a full tokens_per_echelon dict for multi-echelon inputs.

Parameters:

Name Type Description Default
tokens Optional[Tensor]

Single-echelon entity tokens (B, N, token_dim). Ignored when tokens_per_echelon is provided.

required
pad_mask Optional[Tensor]

Padding mask for tokens.

None
echelon int

Active echelon level for single-echelon mode.

ECHELON_BATTALION
card Optional[ScenarioCard]

Scenario conditioning card.

None
card_vec Optional[Tensor]

Pre-computed scenario card tensor (alternative to card).

None
tokens_per_echelon Optional[Dict[int, Tensor]]

Multi-echelon input dict; overrides tokens when given.

None
pad_masks Optional[Dict[int, Tensor]]

Padding masks for multi-echelon mode.

None
deterministic bool

When True, return the distribution mean instead of a sample.

False

Returns:

Name Type Description
actions torch.Tensor — shape ``(B, action_dim)``
log_probs torch.Tensor — shape ``(B,)``

adapter_parameters()

Return only the FiLM adapter parameters.

Call this to get a parameter group for lightweight scenario-specific fine-tuning (adapter-only gradient updates leave the base transformer frozen).

Returns:

Type Description
list of :class:`torch.nn.Parameter`

base_parameters()

Return all non-adapter (base model) parameters.

evaluate_actions(tokens, actions, pad_mask=None, echelon=ECHELON_BATTALION, card=None, card_vec=None)

Evaluate log-probs, entropy, and values for given actions.

Used during PPO update.

Returns:

Name Type Description
log_probs torch.Tensor — shape ``(B,)``
entropy torch.Tensor — shape ``(B,)``
values torch.Tensor — shape ``(B,)``

finetune_loss(batch)

Compute a supervised fine-tuning loss from a demonstration batch.

The batch dict must contain:

* "tokens" — shape (B, N, token_dim)
* "actions" — shape (B, action_dim)

Optional keys:

* "pad_mask" — shape (B, N)
* "echelon" — scalar int (default: ECHELON_BATTALION)
* "card_vec" — shape (B, card_raw_dim) or (card_raw_dim,)

Returns:

Name Type Description
loss torch.Tensor — scalar behaviour-cloning MSE loss
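A behaviour-cloning MSE loss of this shape can be sketched with NumPy; the policy forward pass below is a dummy stand-in (the real loss runs the transformer on the tokens), but the batch layout and the mean-squared-error reduction match the description above.

```python
import numpy as np

def bc_mse_loss(batch):
    """Behaviour-cloning MSE: squared error between predicted and demo actions,
    averaged over the batch. The 'policy' here is a dummy stand-in."""
    tokens, demo_actions = batch["tokens"], batch["actions"]
    # Stand-in for the policy forward pass: pool tokens, project to action_dim
    predicted = tokens.mean(axis=1)[:, : demo_actions.shape[1]]
    return np.mean((predicted - demo_actions) ** 2)

batch = {
    "tokens": np.ones((8, 5, 16)),   # (B, N, token_dim)
    "actions": np.ones((8, 4)),      # (B, action_dim)
}
loss = bc_mse_loss(batch)            # perfect match -> 0.0
```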

freeze_base()

Freeze all base-model parameters; only the adapter remains trainable.
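The freeze/adapter contract can be illustrated with a toy stand-in (names and structure here are purely illustrative, including the assumption that adapter parameters are distinguishable from base parameters): after `freeze_base()`, only the adapter parameters remain trainable, which is exactly the group `adapter_parameters()` hands to the fine-tuning optimizer.

```python
from dataclasses import dataclass

@dataclass
class Param:
    name: str
    requires_grad: bool = True

class TinyPolicy:
    """Toy stand-in for the freeze_base / adapter_parameters contract."""
    def __init__(self):
        self.params = [Param("encoder.weight"), Param("head.weight"),
                       Param("adapter.fc.weight")]
    def adapter_parameters(self):
        return [p for p in self.params if p.name.startswith("adapter")]
    def base_parameters(self):
        return [p for p in self.params if not p.name.startswith("adapter")]
    def freeze_base(self):
        for p in self.base_parameters():
            p.requires_grad = False          # adapter params stay trainable

policy = TinyPolicy()
policy.freeze_base()
trainable = [p.name for p in policy.params if p.requires_grad]
# only "adapter.fc.weight" remains trainable
```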

get_value(tokens, pad_mask=None, echelon=ECHELON_BATTALION, card=None, card_vec=None, tokens_per_echelon=None, pad_masks=None)

Compute scalar value estimates.

Returns:

Name Type Description
values torch.Tensor — shape ``(B,)``

load_checkpoint(path, map_location=None, **kwargs) classmethod

Load a WFM-1 checkpoint produced by :meth:save_checkpoint.

Extra keyword arguments override the saved configuration.

save_checkpoint(path)

Save model state dict to path (.pt).
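The "extra keyword arguments override the saved configuration" contract of `load_checkpoint` can be sketched with a JSON round-trip (a simplified stand-in; the real checkpoints are PyTorch `.pt` files carrying a state dict):

```python
import json
import tempfile
from pathlib import Path

def save_checkpoint(path, config, state):
    """Persist config + state together (toy JSON stand-in for a .pt file)."""
    Path(path).write_text(json.dumps({"config": config, "state": state}))

def load_checkpoint(path, **overrides):
    """Reload; keyword arguments override the saved configuration."""
    payload = json.loads(Path(path).read_text())
    config = {**payload["config"], **overrides}
    return config, payload["state"]

with tempfile.TemporaryDirectory() as d:
    ckpt = Path(d) / "wfm1.pt"
    save_checkpoint(ckpt, {"card_hidden_size": 64, "fourier_pe": True}, {"step": 1000})
    config, state = load_checkpoint(ckpt, card_hidden_size=128)
    # card_hidden_size is overridden to 128; fourier_pe keeps its saved value
```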

unfreeze_base()

Unfreeze all base-model parameters.

Top-level training symbols

training

Wargames Training — training public API.

Stable interfaces for training runners, evaluation utilities, callbacks, self-play, curriculum management, policy registry, artifacts, Elo ratings, and benchmarks. Import from this module to remain insulated from internal restructuring.

Training runners

:func:train — train a single-agent policy with PPO.
:class:TrainingConfig — configuration dataclass for :func:train.

Evaluation

:func:evaluate — quick win-rate evaluation.
:func:evaluate_detailed — detailed win/draw/loss statistics.
:func:run_episodes_with_model — low-level episode runner.
:class:EvaluationResult — structured evaluation result.

Callbacks

:class:WandbCallback — log rollout statistics to W&B.
:class:RewardBreakdownCallback — per-component reward logging.
:class:EloEvalCallback — Elo evaluation callback.
:class:ManifestCheckpointCallback — manifest-aware checkpoint saving.
:class:ManifestEvalCallback — manifest-aware best-model saving.

Self-play

:class:OpponentPool — pool of frozen single-agent policy snapshots.
:class:SelfPlayCallback — snapshot the current policy into the pool.
:class:WinRateVsPoolCallback — evaluate win-rate vs. the pool.
:func:evaluate_vs_pool — standalone pool win-rate helper.
:class:TeamOpponentPool — pool for MAPPO team self-play.
:func:evaluate_team_vs_pool — evaluate a MAPPO policy vs. pool.
:func:nash_exploitability_proxy — exploitability proxy metric.

Curriculum

:class:CurriculumScheduler — win-rate-based curriculum progression.
:class:CurriculumStage — curriculum stage enum.
:func:load_v1_weights_into_mappo — warm-start MAPPO from a v1 checkpoint.

Policy registry

:class:PolicyRegistry — versioned multi-echelon checkpoint registry.
:class:Echelon — echelon enum (battalion / brigade / division).
:class:PolicyEntry — registry entry named-tuple.

Artifacts

:class:CheckpointManifest — append-only JSONL checkpoint index.
:func:checkpoint_name_prefix — canonical periodic-checkpoint prefix.
:func:checkpoint_final_stem — canonical final-checkpoint stem.
:func:checkpoint_best_filename — canonical best-checkpoint filename.
:func:parse_step_from_checkpoint_name — extract step number from filename.

Elo

:class:EloRegistry — Elo rating registry.
:class:TeamEloRegistry — multi-agent team Elo registry.
:data:DEFAULT_RATING — default rating for unseen agents.
:data:BASELINE_RATINGS — fixed ratings for scripted baseline opponents.

Benchmarks

:class:WFM1Benchmark — WFM-1 zero-shot evaluation.
:class:WFM1BenchmarkConfig — WFM-1 benchmark configuration.
:class:WFM1BenchmarkResult — WFM-1 per-scenario result.
:class:WFM1BenchmarkSummary — WFM-1 aggregate results.
:class:TransferBenchmark — GIS terrain transfer benchmark.
:class:TransferEvalConfig — transfer benchmark configuration.
:class:TransferResult — transfer benchmark per-condition result.
:class:TransferSummary — transfer benchmark aggregate results.
:class:HistoricalBenchmark — historical battle fidelity benchmark.
:class:BenchmarkEntry — per-battle benchmark entry.
:class:BenchmarkSummary — historical benchmark aggregate summary.

TrainingConfig dataclass

Configuration for a single PPO training run on :class:~envs.battalion_env.BattalionEnv.

All fields are optional; the defaults match configs/default.yaml. Instances can be passed directly to :func:train or individual fields can be overridden via **kwargs at the call site.

Examples:

Minimal run with defaults::

from training import train, TrainingConfig
model = train(TrainingConfig(total_timesteps=500_000))

Override specific fields at call time without constructing a config::

model = train(total_timesteps=200_000, n_envs=4, enable_wandb=False)

OpponentPool

Fixed-size pool of frozen PPO policy snapshots.

Snapshots are stored as Stable-Baselines3 .zip files under pool_dir. The pool keeps at most max_size snapshots; when full, the oldest snapshot is evicted to make room for the newest one.

Parameters:

Name Type Description Default
pool_dir str | Path

Directory where snapshot .zip files are persisted. Created on first :meth:add if it does not already exist.

required
max_size int

Maximum number of snapshots to retain (default 10).

10

Attributes:

Name Type Description
pool_dir Path

Resolved path of the snapshot directory.

max_size int

Maximum number of snapshots retained in the pool.

size property

Current number of snapshots in the pool.

snapshot_paths property

Ordered list of snapshot paths (oldest first, read-only copy).

add(model, version)

Save model as a new snapshot and add it to the pool.

If the pool already contains max_size snapshots the oldest is removed from disk before saving the new one.

Parameters:

Name Type Description Default
model PPO

The current PPO model to snapshot.

required
version int

Monotonically increasing version number embedded in the snapshot file name for traceability.

required

Returns:

Type Description
Path

Path of the newly saved snapshot file (including .zip).

sample(rng=None)

Load and return a uniformly sampled snapshot as a PPO model.

Parameters:

Name Type Description Default
rng Optional[Generator]

Optional NumPy random generator for reproducible sampling. When None, the pool's internal shared RNG instance is used (seeded once at pool construction time).

None

Returns:

Type Description
PPO or None

A loaded PPO model, or None if the pool is empty.

sample_latest()

Load and return the most recently added snapshot.

Returns:

Type Description
PPO or None

The latest PPO model, or None if the pool is empty.
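The pool's FIFO eviction behaviour can be illustrated with a minimal stand-in (real snapshots are Stable-Baselines3 `.zip` files; here they are small text files, and the class is illustrative, not the library's implementation): once `max_size` snapshots exist, adding a new one deletes the oldest from disk first.

```python
import random
import tempfile
from pathlib import Path

class TinyPool:
    """Minimal FIFO snapshot pool mirroring the OpponentPool semantics."""
    def __init__(self, pool_dir, max_size=10):
        self.pool_dir = Path(pool_dir)
        self.max_size = max_size
        self.paths = []                          # oldest first
    def add(self, payload, version):
        self.pool_dir.mkdir(parents=True, exist_ok=True)
        if len(self.paths) >= self.max_size:
            self.paths.pop(0).unlink()           # evict oldest from disk
        path = self.pool_dir / f"opponent_v{version}.zip"
        path.write_text(payload)
        self.paths.append(path)
        return path
    def sample(self, rng=None):
        if not self.paths:
            return None                          # empty pool
        return (rng or random).choice(self.paths).read_text()

with tempfile.TemporaryDirectory() as d:
    pool = TinyPool(d, max_size=2)
    for v in range(3):
        pool.add(f"model-{v}", version=v)
    # max_size=2: v0 was evicted; only v1 and v2 remain
```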

EloRegistry

Persistent Elo rating registry backed by a JSON file.

Stores per-agent ratings and game counts. Scripted baseline opponents ("scripted_l1" through "scripted_l5", and "random") have fixed seed ratings defined in :data:BASELINE_RATINGS and are never modified.

Parameters:

Name Type Description Default
path Union[str, Path, None]

Path to the JSON file used for persistence. The parent directory is created automatically on :meth:save. Pass None to create an in-memory registry that cannot be saved to disk.

'checkpoints/elo_registry.json'

can_save property

True when the registry has a backing file and can be persisted.

all_ratings()

Return a copy of all stored ratings (excludes pure baselines).

get_game_count(name)

Return the total number of rated games played by name.

get_rating(name)

Return the current Elo rating for name.

Falls back to :data:BASELINE_RATINGS for scripted opponents, then to :data:DEFAULT_RATING for completely unknown agents.

save()

Persist current ratings and game counts to the JSON file.

Raises:

Type Description
ValueError

If the registry was created without a file path (path=None).

update(agent, opponent, outcome, n_games=1)

Update the Elo rating of agent after a batch of n_games.

The outcome is the average score per game:

  • 1.0 — all wins
  • 0.5 — all draws
  • 0.0 — all losses

Scripted baseline opponents' ratings are never modified.

Parameters:

Name Type Description Default
agent str

Identifier for the agent whose rating is updated.

required
opponent str

Identifier of the opponent played against.

required
outcome float

Average per-game score in [0, 1].

required
n_games int

Number of games in this batch (used to advance the game-count counter; the K-factor is evaluated at the pre-update count).

1

Returns:

Type Description
float

Elo rating delta for agent (positive = rating increased).

Raises:

Type Description
ValueError

If outcome is outside [0, 1], n_games < 1, or agent is a key in :data:BASELINE_RATINGS (baselines are immutable).
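The update above follows the standard Elo scheme; a standalone sketch of the arithmetic is below. The fixed `k=16.0` is an illustrative choice only: the real registry evaluates its K-factor at the pre-update game count, a schedule not reproduced here.

```python
def elo_update(r_agent, r_opponent, outcome, n_games=1, k=16.0):
    """Standard Elo update with `outcome` as the average per-game score,
    applied over a batch of n_games. Returns (new_rating, delta)."""
    if not 0.0 <= outcome <= 1.0 or n_games < 1:
        raise ValueError("outcome must be in [0, 1] and n_games >= 1")
    # Expected score from the 400-point logistic curve
    expected = 1.0 / (1.0 + 10 ** ((r_opponent - r_agent) / 400.0))
    delta = k * n_games * (outcome - expected)
    return r_agent + delta, delta

new_rating, delta = elo_update(1200.0, 1200.0, outcome=1.0, n_games=1)
# equal ratings, all wins -> delta = +8.0
```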