API Reference¶
Full auto-generated API documentation for all stable public interfaces. Each package exposes its symbols at the top-level namespace, so you only need a single import per package.
```python
from envs import BattalionEnv, RewardWeights, Formation
from models import BattalionMlpPolicy, MAPPOPolicy, WFM1Policy
from training import train, TrainingConfig, evaluate, OpponentPool
from analysis import COAGenerator, SaliencyAnalyzer, compute_gradient_saliency
```
Packages¶
- envs — simulation environments, reward shaping, sim primitives, HRL options
- models — MLP, MAPPO, transformer, recurrent, and WFM-1 policies
- training — training runners, callbacks, evaluation, self-play, Elo, benchmarks
- analysis — COA generation and policy saliency
- benchmarks — WargamesBench runner
Top-level envs symbols¶
envs
¶
Wargames Training — environments public API.
Stable interfaces for all simulation environments, reward shaping, simulation primitives, and scenario utilities. Import from this module rather than from individual submodules to stay insulated from internal restructuring.
Environments
:class:BattalionEnv
Single-battalion 1v1 Gymnasium environment (continuous action space).
:class:BrigadeEnv
Brigade-level multi-battalion environment (MultiDiscrete actions).
:class:DivisionEnv
Division-level multi-brigade environment (MultiDiscrete actions).
:class:CorpsEnv
Corps-level operational environment with supply and road networks.
:class:CavalryCorpsEnv
Corps environment extended with cavalry reconnaissance.
:class:ArtilleryCorpsEnv
Corps environment extended with artillery and fortification mechanics.
:class:MultiBattalionEnv
Vectorised multi-battalion PettingZoo parallel environment.
Reward shaping
:class:RewardWeights
Dataclass of per-component reward multipliers.
:class:RewardComponents
Named-tuple of per-step reward component values.
:func:compute_reward
Compute a single-step reward given state deltas and weights.
Environment configuration types
:class:LogisticsConfig / :class:LogisticsState / :class:SupplyWagon
Logistics and supply management.
:class:MoraleConfig
Morale dynamics parameters.
:class:Formation
Unit formation enum (LINE / COLUMN / SQUARE / SKIRMISH).
:class:WeatherConfig / :class:WeatherState
Weather simulation parameters.
:class:RedPolicy
Protocol for scripted / learned Red opponent policies.
Simulation primitives
:class:SimEngine
Headless deterministic 1v1 simulation engine.
:class:EpisodeResult
Structured result from a :class:SimEngine episode.
Corps constants
:data:CORPS_OBS_DIM, :data:N_CORPS_SECTORS, :data:N_OBJECTIVES,
:data:CORPS_MAP_WIDTH, :data:CORPS_MAP_HEIGHT, :data:N_ROAD_FEATURES
BattalionEnv
¶
Bases: Env
1v1 battalion RL environment.
The agent controls the Blue battalion; Red is driven either by a
built-in scripted opponent (controlled by curriculum_level) or by an
optional red_policy (any object with a predict method, e.g. a
Stable-Baselines3 PPO model). When red_policy is provided it takes
precedence over the scripted opponent and curriculum_level is ignored for
movement and fire decisions.
Curriculum levels (scripted opponent only)

| Level | Red opponent behaviour |
|---|---|
| 1 | Stationary — Red does not move or fire. |
| 2 | Turning only — Red faces Blue but stays put. |
| 3 | Advance only — Red turns and advances; no fire. |
| 4 | Soft fire — Red turns, advances, fires at 50 % intensity. |
| 5 | Full combat — Red turns, advances, fires at 100 % intensity (default). |
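The curriculum levels form a simple capability ladder. A minimal sketch of that mapping (hypothetical helper for illustration; the real scripted opponent lives inside the environment):

```python
def red_behaviour(level: int) -> dict:
    """Map a curriculum level (1-5) to scripted Red capabilities.

    Illustrative mirror of the level table above, not the library's
    actual implementation.
    """
    if not 1 <= level <= 5:
        raise ValueError(f"curriculum_level must be in 1..5, got {level}")
    return {
        "turns": level >= 2,     # level 2+: Red faces Blue
        "advances": level >= 3,  # level 3+: Red closes the distance
        # level 4 fires at 50 % intensity, level 5 at 100 %
        "fire_intensity": {4: 0.5, 5: 1.0}.get(level, 0.0),
    }
```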
Observation space — Box(shape=(17,), dtype=float32) (formations disabled)

| Index | Feature | Range |
|---|---|---|
| 0 | blue x / map_width | [0, 1] |
| 1 | blue y / map_height | [0, 1] |
| 2 | cos(blue θ) | [-1, 1] |
| 3 | sin(blue θ) | [-1, 1] |
| 4 | blue strength | [0, 1] |
| 5 | blue morale | [0, 1] |
| 6 | distance to red / diagonal | [0, 1] |
| 7 | cos(bearing to red) | [-1, 1] |
| 8 | sin(bearing to red) | [-1, 1] |
| 9 | red strength | [0, 1] |
| 10 | red morale | [0, 1] |
| 11 | step / max_steps | [0, 1] |
| 12 | blue elevation (normalised) | [0, 1] |
| 13 | blue cover | [0, 1] |
| 14 | red elevation (normalised) | [0, 1] |
| 15 | red cover | [0, 1] |
| 16 | line-of-sight (0=blocked) | [0, 1] |
When enable_formations is True, two extra dimensions are appended:
| Index | Feature | Range |
|---|---|---|
| 17 | blue formation / (NUM_FORMATIONS−1) | [0, 1] |
| 18 | blue transitioning (1=yes) | [0, 1] |
When enable_logistics is True, three more dimensions follow the
terrain / formations dims:
| Index | Feature | Range |
|---|---|---|
| N+0 | blue ammo level | [0, 1] |
| N+1 | blue food level | [0, 1] |
| N+2 | blue fatigue level | [0, 1] |
(where N = 17 with formations disabled or 19 with formations enabled).
When enable_weather is True, two more dimensions are appended after
any logistics dims:
| Index | Feature | Range |
|---|---|---|
| M+0 | weather condition id | [0, 1] |
| M+1 | combined visibility fraction | [0, 1] |
(where M = N + 3 when logistics are enabled, or N when disabled).
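The total observation length follows directly from the flag combination above. A small illustrative helper (not part of the envs API):

```python
def obs_dim(formations: bool = False, logistics: bool = False,
            weather: bool = False) -> int:
    """Observation vector length for BattalionEnv given feature flags.

    Base 17 features, +2 formation dims, +3 logistics dims,
    +2 weather dims, matching the index tables above.
    """
    dim = 17
    if formations:
        dim += 2
    if logistics:
        dim += 3
    if weather:
        dim += 2
    return dim
```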
Action space — Box(shape=(3,), dtype=float32) (formations disabled)

| Index | Name | Range | Effect |
|---|---|---|---|
| 0 | move | [-1, 1] | Scale max_speed; positive = forward |
| 1 | rotate | [-1, 1] | Scale max_turn_rate; positive = CCW |
| 2 | fire | [0, 1] | Fire intensity this step |
When enable_formations is True, a fourth dimension is added:
| Index | Name | Range | Effect |
|---|---|---|---|
| 3 | formation | [0, 3] | Desired formation (rounded to nearest int) |
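How the raw action vector maps onto unit behaviour can be sketched as follows. The `max_speed` / `max_turn_rate` names are assumptions for illustration; the real kinematics live inside the environment:

```python
def decode_action(action, max_speed=1.0, max_turn_rate=1.0):
    """Decode a 3-dim continuous action per the table above.

    Returns (forward_speed, turn_rate, fire_intensity): 'move' and
    'rotate' scale the unit's maxima, 'fire' is clipped to [0, 1].
    """
    move, rotate, fire = action[:3]
    speed = max(-1.0, min(1.0, move)) * max_speed       # positive = forward
    turn = max(-1.0, min(1.0, rotate)) * max_turn_rate  # positive = CCW
    intensity = max(0.0, min(1.0, fire))                # fire intensity this step
    return speed, turn, intensity
```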
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| map_width | float | Map width in metres (default 1 km × 1 km map). | MAP_WIDTH |
| map_height | float | Map height in metres (default 1 km × 1 km map). | MAP_WIDTH |
| max_steps | int | Episode length cap (default 500). | MAX_STEPS |
| terrain | Optional[TerrainMap] | Optional fixed :class:TerrainMap; when omitted, procedural terrain is generated on each reset. | None |
| randomize_terrain | bool | When True, a new procedural terrain is generated from the seeded RNG on every reset. | True |
| hill_speed_factor | float | Movement speed multiplier applied to units on maximum-elevation terrain. | 0.5 |
| curriculum_level | int | Scripted Red opponent difficulty (1–5). Level 1 is the easiest (stationary target); level 5 is full combat. | 5 |
| reward_weights | Optional[RewardWeights] | Optional :class:RewardWeights overriding the default reward multipliers. | None |
| red_policy | Optional[RedPolicy] | Optional policy object for driving the Red battalion. Must expose a predict method. | None |
| render_mode | Optional[str] | Render mode ("human" opens a pygame window). | None |
| enable_formations | bool | When True, appends formation dimensions to the observation and action spaces. | False |
| enable_logistics | bool | When True, appends ammo / food / fatigue dimensions to the observation space. | False |
| logistics_config | Optional[LogisticsConfig] | Optional :class:LogisticsConfig overriding logistics defaults. | None |
| enable_weather | bool | When True, appends weather dimensions to the observation space. | False |
| weather_config | Optional[WeatherConfig] | Optional :class:WeatherConfig overriding weather defaults. | None |
close()
¶
Clean up resources, including the pygame window if open.
render()
¶
Render the current environment state.
When render_mode="human" a pygame window is opened on the first
call and kept alive for subsequent calls. The window is closed by
:meth:close. When render_mode is None this is a no-op.
reset(*, seed=None, options=None)
¶
Reset the environment and return the initial observation.
When randomize_terrain is True (the default when no fixed
terrain map is passed to __init__), a new procedural terrain
is generated from the seeded RNG on every call. Passing the same
seed therefore always produces the same terrain layout and unit
positions.
Blue spawns in the western half of the map facing roughly east; Red spawns in the eastern half facing roughly west.
set_red_policy(policy)
¶
Swap the Red opponent policy at runtime.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| policy | Optional[RedPolicy] | New policy to use for Red, or None to revert to the scripted opponent. | required |
step(action)
¶
Advance the environment by one step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| action | ndarray | Array of shape (3,), or (4,) with formations enabled; see the action-space tables above. | required |

Returns:

| Type | Description |
|---|---|
| tuple | (observation, reward, terminated, truncated, info) |
RewardWeights
dataclass
¶
Configurable multipliers for each reward component.
All weights are applied by multiplying against their respective raw
component value before summing into the total reward. Set a weight
to 0.0 to disable that component entirely.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| delta_enemy_strength | float | Multiplier applied to the fraction of enemy strength destroyed in a step. | 5.0 |
| delta_own_strength | float | Multiplier applied to the fraction of own strength lost in a step. | 5.0 |
| survival_bonus | float | Per-step bonus scaled by Blue's current strength. Set to 0.0 to disable. | 0.0 |
| win_bonus | float | Terminal reward added when Blue wins (Red routed or destroyed). | 10.0 |
| loss_penalty | float | Terminal reward added when Blue loses (Blue routed or destroyed). Should be negative. | -10.0 |
| time_penalty | float | Constant added every step. A small negative value (e.g. the default -0.01) penalises long episodes. | -0.01 |
enemy_routed_bonus = 0.0
class-attribute
instance-attribute
¶
Bonus added each step that the enemy is in a routing state.
A positive value (e.g. 2.0) encourages pursuit of routing enemies and
exploitation of broken units. Set to 0.0 (default) to disable.
own_routing_penalty = 0.0
class-attribute
instance-attribute
¶
Penalty added each step that the agent's own unit is routing.
A negative value (e.g. -2.0) discourages allowing the agent's
battalion to be broken. Set to 0.0 (default) to disable.
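The weighting scheme is a plain weighted sum of raw component values. An illustrative mirror of how a few components combine (sign conventions assumed; the real compute_reward may have a different component list):

```python
from dataclasses import dataclass

@dataclass
class Weights:  # illustrative stand-in for RewardWeights
    delta_enemy_strength: float = 5.0
    delta_own_strength: float = 5.0
    win_bonus: float = 10.0
    loss_penalty: float = -10.0
    time_penalty: float = -0.01

def step_reward(w: Weights, enemy_lost: float, own_lost: float,
                won: bool = False, lost: bool = False) -> float:
    """Weighted sum of per-step components, as described above."""
    r = w.delta_enemy_strength * enemy_lost  # reward damage dealt
    r -= w.delta_own_strength * own_lost     # penalise damage taken
    r += w.time_penalty                      # constant per-step term
    if won:
        r += w.win_bonus
    if lost:
        r += w.loss_penalty                  # already negative
    return r
```

Setting any field to 0.0 disables that component, as noted above.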
Formation
¶
Bases: IntEnum
Discrete formation states for a Napoleonic infantry battalion.
The integer values are stable indices used in observation vectors and action spaces; do not change them without updating downstream code.
| Formation | Int | Historical role |
|---|---|---|
| LINE | 0 | Two- or three-rank line — maximum firepower front. |
| COLUMN | 1 | Attack or march column — highest movement speed. |
| SQUARE | 2 | Hollow square — impenetrable to unsupported cavalry. |
| SKIRMISH | 3 | Extended order — loose screen, independent aimed fire. |
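The stable integer mapping can be mirrored with a plain IntEnum (illustrative; import the real Formation from envs):

```python
from enum import IntEnum

class Formation(IntEnum):
    """Stable indices used in observation vectors and action spaces."""
    LINE = 0      # maximum firepower front
    COLUMN = 1    # highest movement speed
    SQUARE = 2    # anti-cavalry formation
    SKIRMISH = 3  # loose screen

# IntEnum members compare equal to their integer indices, so they can
# be written directly into observation vectors.
assert Formation.SQUARE == 2
```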
SimEngine
¶
Run a 1v1 battalion episode to completion.
Each :meth:step:

- Resets per-step damage accumulators on both :class:CombatState objects.
- Computes fire damage from both sides simultaneously (before applying either, so there are no ordering effects on damage calculation).
- Applies terrain cover to reduce incoming damage at each unit's position.
- Applies casualties and updates strength.
- Runs a morale check for each unit, potentially triggering routing.

The episode ends (:meth:is_over returns True) when:

- Either side's :class:CombatState reports is_routing = True, or
- Either side's strength falls to :data:DESTROYED_THRESHOLD or below, or
- step_count reaches max_steps.
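The termination logic amounts to three checks. A minimal sketch with a placeholder threshold value (the real constant is exported by envs as DESTROYED_THRESHOLD; its value is an assumption here):

```python
DESTROYED_THRESHOLD = 0.0  # placeholder; use the envs constant in practice

def episode_over(blue: dict, red: dict, step_count: int,
                 max_steps: int = 500) -> bool:
    """Illustrative mirror of SimEngine.is_over: routing, destruction, or timeout."""
    for side in (blue, red):
        if side["is_routing"]:                       # morale broke
            return True
        if side["strength"] <= DESTROYED_THRESHOLD:  # unit destroyed
            return True
    return step_count >= max_steps                   # hard episode cap
```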
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| blue | Battalion | The Blue battalion. Modified in-place by each step. | required |
| red | Battalion | The Red battalion. Modified in-place by each step. | required |
| terrain | Optional[TerrainMap] | Optional :class:TerrainMap applied during the episode. | None |
| max_steps | int | Hard cap on episode length (default 500, matching acceptance criterion AC-1 of Epic E1.2). | 500 |
| rng | Optional[Generator] | Seeded random generator. Defaults to a fresh unseeded generator. Pass a seeded generator for reproducible results. | None |
is_over()
¶
Return True when the episode should end.
run()
¶
Run the episode to completion and return a result summary.
step()
¶
Advance one simulation step.
Returns:

| Type | Description |
|---|---|
| dict | Per-step summary values. |
Top-level models symbols¶
models
¶
Neural network models for wargames_training.
Exports
MLP policy (SB3-compatible)
:class:~models.mlp_policy.BattalionMlpPolicy — MLP actor-critic for
:class:~envs.battalion_env.BattalionEnv.
MAPPO multi-agent policy
:class:~models.mappo_policy.MAPPOActor — stochastic Gaussian actor.
:class:~models.mappo_policy.MAPPOCritic — centralised value critic.
:class:~models.mappo_policy.MAPPOPolicy — combined actor-critic policy.
Entity encoder (transformer)
:data:~models.entity_encoder.ENTITY_TOKEN_DIM — token dimensionality.
:class:~models.entity_encoder.SpatialPositionalEncoding — 2D Fourier encoding.
:class:~models.entity_encoder.EntityEncoder — multi-head transformer encoder.
:class:~models.entity_encoder.EntityActorCriticPolicy — entity-based actor-critic.
Recurrent policy (LSTM)
:class:~models.recurrent_policy.LSTMHiddenState — LSTM hidden-state container.
:class:~models.recurrent_policy.RecurrentEntityEncoder — entity encoder + LSTM.
:class:~models.recurrent_policy.RecurrentActorCriticPolicy — LSTM actor-critic.
:class:~models.recurrent_policy.RecurrentRolloutBuffer — BPTT rollout buffer.
WFM-1 foundation model
:class:~models.wfm1.WFM1Policy — multi-echelon foundation model.
:class:~models.wfm1.ScenarioCard — scenario descriptor for WFM-1.
:class:~models.wfm1.EchelonEncoder — per-echelon entity encoder.
:class:~models.wfm1.CrossEchelonTransformer — cross-echelon transformer.
BattalionMlpPolicy
¶
Bases: ActorCriticPolicy
MLP actor-critic policy for BattalionEnv.
A configurable fully-connected network with separate actor and critic
heads, designed for the flat observation space of
:class:~envs.battalion_env.BattalionEnv.
Architecture (default)::

    obs → Linear(128) → Tanh → Linear(128) → Tanh
            ├── actor head  → action(3) + log_std
            └── critic head → value(1)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| observation_space | Space | Gymnasium observation space. | required |
| action_space | Space | Gymnasium action space. | required |
| lr_schedule | Schedule | Learning-rate schedule passed in by the SB3 algorithm. | required |
| net_arch | Optional[List[int]] | Hidden-layer sizes shared by actor and critic. Defaults to [128, 128]. | None |
| activation_fn | Type[Module] | Activation function class applied after each hidden layer. Defaults to :class:torch.nn.Tanh. | Tanh |
| **kwargs | Any | Forwarded to :class:ActorCriticPolicy (the SB3 base class). | {} |
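The default trunk implies the following layer shapes. A quick shape walk in plain Python (illustrative; the [128, 128] default is taken from the architecture diagram above):

```python
def mlp_shapes(obs_dim: int, net_arch: list, action_dim: int) -> list:
    """(in, out) shapes of each trunk Linear layer plus the actor head."""
    dims = [obs_dim, *net_arch]
    layers = list(zip(dims[:-1], dims[1:]))  # shared trunk layers
    layers.append((dims[-1], action_dim))    # actor head → action mean
    return layers

# 17-dim base observation, default [128, 128] trunk, 3-dim action:
shapes = mlp_shapes(17, [128, 128], 3)
```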
MAPPOPolicy
¶
Bases: Module
MAPPO policy: shared actor(s) plus a centralized critic.
Supports two parameter-sharing modes for the actor:
- share_parameters=True (default) — all n_agents agents use the same :class:MAPPOActor weights. Memory scales as O(actor_params + critic_params) instead of O(n_agents * actor_params).
- share_parameters=False — each agent gets its own :class:MAPPOActor; useful for ablation studies.
In both cases there is a single :class:MAPPOCritic that all
agents share.
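The memory claim can be sanity-checked with simple arithmetic (illustrative counts, not the real layer sizes):

```python
def mappo_param_total(actor_params: int, critic_params: int,
                      n_agents: int, share_parameters: bool) -> int:
    """Total parameters: one shared critic always; actors shared or per-agent."""
    actors = actor_params if share_parameters else n_agents * actor_params
    return actors + critic_params

# e.g. 8 agents, a 50k-parameter actor, an 80k-parameter critic:
shared = mappo_param_total(50_000, 80_000, 8, True)        # one actor
independent = mappo_param_total(50_000, 80_000, 8, False)  # eight actors
```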
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| obs_dim | int | Per-agent local observation dimensionality. | required |
| action_dim | int | Per-agent action dimensionality (continuous). | required |
| state_dim | int | Global state dimensionality. | required |
| n_agents | int | Number of controlled agents (used only when share_parameters=False). | 1 |
| share_parameters | bool | Whether all agents share one actor. Defaults to True. | True |
| actor_hidden_sizes | Tuple[int, ...] | Hidden layer sizes for the actor trunk. | (128, 64) |
| critic_hidden_sizes | Tuple[int, ...] | Hidden layer sizes for the critic trunk. | (128, 64) |
act(obs, agent_idx=0, deterministic=False)
¶
Sample actions for a batch of observations.
A batch dimension is added internally when obs is 1-D so that
the return shapes are always (batch, action_dim) and
(batch,) regardless of whether the caller passes a single
observation or a batch.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| obs | Tensor | Local observations of shape (obs_dim,) or (batch, obs_dim). | required |
| agent_idx | int | Index of the agent whose actor should be used. Ignored when share_parameters=True. | 0 |
| deterministic | bool | When True, return the distribution mean instead of sampling. | False |

Returns:

| Name | Type | Description |
|---|---|---|
| actions | torch.Tensor | Shape (batch, action_dim). |
| log_probs | torch.Tensor | Shape (batch,). |
evaluate_actions(obs, actions, state, agent_idx=0)
¶
Evaluate actions under the current policy for the PPO update.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| obs | Tensor | Local observations of shape (batch, obs_dim). | required |
| actions | Tensor | Actions of shape (batch, action_dim). | required |
| state | Tensor | Global states of shape (batch, state_dim). | required |
| agent_idx | int | Actor index (ignored when sharing parameters). | 0 |

Returns:

| Name | Type | Description |
|---|---|---|
| log_probs | torch.Tensor | Shape (batch,). |
| entropy | torch.Tensor | Shape (batch,). |
| values | torch.Tensor | Shape (batch,). |
get_actor(agent_idx=0)
¶
Return the actor for agent_idx.
When share_parameters=True the same actor is returned
regardless of agent_idx.
get_value(state)
¶
Return value estimate(s) for the given global state(s).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | Tensor | Global state tensor of shape (state_dim,) or (batch, state_dim). | required |

Returns:

| Name | Type | Description |
|---|---|---|
| values | torch.Tensor | Scalar or shape (batch,). |
parameter_count()
¶
Return a dict with actor and critic parameter counts.
EntityEncoder
¶
Bases: Module
Multi-head self-attention encoder over a variable-length entity sequence.
Architecture::
entity tokens (B, N, token_dim)
│
token_embed: Linear(token_dim → d_model)
│ + SpatialPositionalEncoding(x, y) [optional]
│
TransformerEncoder (n_layers × TransformerEncoderLayer)
└── MultiheadAttention (n_heads, d_model)
└── FFN (dim_feedforward = 4 * d_model)
└── LayerNorm + residual
│
mean-pool over non-padded entities → (B, d_model)
│
output projection: Linear(d_model → d_model) [identity-initialised]
Padding is handled via src_key_padding_mask: a boolean tensor of shape
(B, N) where True marks padded (ignored) positions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| token_dim | int | Dimensionality of each input entity token. Defaults to :data:ENTITY_TOKEN_DIM. | ENTITY_TOKEN_DIM |
| d_model | int | Internal transformer dimension. | 64 |
| n_heads | int | Number of attention heads. Must evenly divide d_model. | 4 |
| n_layers | int | Number of transformer encoder layers. | 2 |
| dim_feedforward | Optional[int] | Feed-forward sublayer width. Defaults to 4 * d_model. | None |
| dropout | float | Dropout probability inside the transformer. | 0.0 |
| use_spatial_pe | bool | When True, adds the 2-D Fourier positional encoding of entity (x, y) positions. | True |
| n_freq_bands | int | Number of Fourier frequency bands used by :class:SpatialPositionalEncoding. | 8 |
output_dim
property
¶
Dimensionality of the pooled output vector.
forward(tokens, pad_mask=None, return_attention=False)
¶
Encode a batch of entity sequences.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tokens | Tensor | Entity token tensor of shape (B, N, token_dim). | required |
| pad_mask | Optional[Tensor] | Boolean mask of shape (B, N); True marks padded positions. | None |
| return_attention | bool | When True, also return last-layer attention weights. | False |

Returns:

| Name | Type | Description |
|---|---|---|
| encoding | torch.Tensor | Shape (B, d_model). Mean-pooled sequence encoding. |
| attn_weights | torch.Tensor | Shape (B, N, N). Averaged attention weights from the last layer. Only returned when return_attention=True. |
make_padding_mask(n_valid, max_n)
staticmethod
¶
Create a padding mask from per-sample entity counts.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| n_valid | Tensor | Integer tensor of shape (B,) giving the number of valid entities per sample. | required |
| max_n | int | Total number of positions in the padded sequence. | required |

Returns:

| Name | Type | Description |
|---|---|---|
| pad_mask | torch.BoolTensor | Shape (B, max_n); True marks padded positions. |
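The mask semantics (True = padded, ignored) can be mirrored in plain Python (illustrative; the real method returns a torch.BoolTensor):

```python
def make_padding_mask(n_valid: list, max_n: int) -> list:
    """True marks padded positions, matching src_key_padding_mask semantics."""
    return [[j >= n for j in range(max_n)] for n in n_valid]

mask = make_padding_mask([2, 3], max_n=4)
# row 0: entities 0-1 valid, positions 2-3 padded
# row 1: entities 0-2 valid, position 3 padded
```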
RecurrentActorCriticPolicy
¶
Bases: Module
Actor-critic policy with LSTM temporal memory.
Both actor and critic run through the same :class:RecurrentEntityEncoder
(weight-sharing optional). Hidden states are passed in/out explicitly so
the caller controls episode boundaries and checkpointing.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| token_dim | int | Entity token dimensionality. | ENTITY_TOKEN_DIM |
| action_dim | int | Continuous action space dimensionality. | 3 |
| d_model | int | Transformer internal dimension. | 64 |
| n_heads | int | Attention heads in the entity encoder. | 4 |
| n_layers | int | Transformer encoder layers. | 2 |
| lstm_hidden_size | int | LSTM hidden state dimensionality. | 128 |
| lstm_num_layers | int | Number of stacked LSTM layers. | 1 |
| actor_hidden_sizes | Tuple[int, ...] | MLP hidden sizes on top of the LSTM output for the actor head. | (128, 64) |
| critic_hidden_sizes | Tuple[int, ...] | MLP hidden sizes on top of the LSTM output for the critic head. | (128, 64) |
| shared_encoder | bool | When True, actor and critic share one :class:RecurrentEntityEncoder. | True |
| dropout | float | Dropout probability. | 0.0 |
| use_spatial_pe | bool | Enable 2-D Fourier positional encoding. | True |
act(tokens, hx, pad_mask=None, deterministic=False)
¶
Sample (or select deterministically) an action for a single step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tokens | Tensor | Shape (B, N, token_dim). | required |
| hx | LSTMHiddenState | Current LSTM hidden state. | required |
| pad_mask | Optional[Tensor] | Padding mask. | None |
| deterministic | bool | Return the distribution mean instead of sampling. | False |

Returns:

| Name | Type | Description |
|---|---|---|
| actions | torch.Tensor | Shape (B, action_dim). |
| log_probs | torch.Tensor | Shape (B,). |
| new_hx | :class:LSTMHiddenState | Updated hidden state. |
evaluate_actions(tokens_seq, hx, actions_seq, pad_mask_seq=None)
¶
Evaluate log-probs, entropy and values for a sequence of actions.
Used during PPO update to compute the ratio π_new / π_old.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tokens_seq | Tensor | Entity token sequences of shape (B, T, N, token_dim). | required |
| hx | LSTMHiddenState | Initial LSTM hidden state at the start of the sequence. | required |
| actions_seq | Tensor | Actions to evaluate, shape (B, T, action_dim). | required |
| pad_mask_seq | Optional[Tensor] | Padding mask of shape (B, T, N). | None |

Returns:

| Name | Type | Description |
|---|---|---|
| log_probs | torch.Tensor | Shape (B, T). |
| entropy | torch.Tensor | Shape (B, T). |
| values | torch.Tensor | Shape (B, T). |
get_distribution(tokens, hx, pad_mask=None)
¶
Compute action distribution for a single timestep.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tokens | Tensor | Shape (B, N, token_dim). | required |
| hx | LSTMHiddenState | Current LSTM hidden state. | required |
| pad_mask | Optional[Tensor] | Boolean padding mask of shape (B, N). | None |

Returns:

| Name | Type | Description |
|---|---|---|
| dist | :class:~torch.distributions.Normal | Action distribution. |
| new_hx | :class:LSTMHiddenState | Updated hidden state. |
get_value(tokens, hx, pad_mask=None)
¶
Compute value estimates for a single timestep.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tokens | Tensor | Shape (B, N, token_dim). | required |
| hx | LSTMHiddenState | Current LSTM hidden state. | required |
| pad_mask | Optional[Tensor] | Padding mask. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| values | torch.Tensor | Shape (B,). |
| new_hx | :class:LSTMHiddenState | Updated hidden state. |
initial_state(batch_size=1, device=None)
¶
Return a zero-initialised LSTM hidden state.
Call at the start of each episode to reset temporal memory.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch_size | int | Number of parallel environments / samples. | 1 |
| device | Optional[device] | Target device. | None |
load_checkpoint(path, device=None, **kwargs)
classmethod
¶
Restore a policy from a checkpoint created by :meth:save_checkpoint.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| path | str \| Path | Path to the checkpoint file. | required |
| device | Optional[device] | Device to load weights onto. | None |
| **kwargs | Any | Constructor arguments — must match those used when the checkpoint was saved. | {} |

Returns:

| Name | Type | Description |
|---|---|---|
| policy | :class:RecurrentActorCriticPolicy | Restored policy instance. |
parameter_count()
¶
Return a dict with actor and critic parameter counts.
save_checkpoint(path)
¶
Persist model weights to path (torch.save format).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| path | str \| Path | Destination file path. Parent directories are created if needed. | required |
WFM1Policy
¶
Bases: Module
WFM-1 hierarchical transformer policy.
A single policy that operates across battalion, brigade, division, and corps echelons. Multi-echelon information is fused via cross-echelon attention. A lightweight ScenarioCard FiLM adapter conditions the policy on scenario metadata, enabling efficient fine-tuning.
Architecture::
For each active echelon e ∈ {battalion, brigade, division, corps}:
tokens_e (B, N_e, token_dim)
│
EchelonEncoder(echelon=e) → enc_e (B, d_model)
Stack: echelon_encs = [enc_e₁, enc_e₂, …] shape (B, E, d_model)
│
CrossEchelonTransformer → fused (B, d_model)
│
FiLM(ScenarioCard) : fused ← fused × γ + β
│
actor head → Gaussian action distribution
critic head → scalar value
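The FiLM step in the diagram is an elementwise affine modulation of the fused encoding. A minimal sketch with plain Python lists (in the model, γ and β come from an MLP over the ScenarioCard):

```python
def film(fused: list, gamma: list, beta: list) -> list:
    """Feature-wise linear modulation: out[i] = fused[i] * gamma[i] + beta[i]."""
    assert len(fused) == len(gamma) == len(beta)
    return [f * g + b for f, g, b in zip(fused, gamma, beta)]

# gamma=1, beta=0 leaves the encoding unchanged (identity conditioning)
assert film([0.5, -2.0], [1.0, 1.0], [0.0, 0.0]) == [0.5, -2.0]
```

Because conditioning is a cheap elementwise transform, fine-tuning only the γ/β generator leaves the base transformer frozen, which is what enables the adapter-only updates described below for adapter_parameters().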
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| token_dim | int | Entity token dimensionality. Defaults to :data:ENTITY_TOKEN_DIM. | ENTITY_TOKEN_DIM |
| action_dim | int | Continuous action space dimensionality. | 3 |
| d_model | int | Transformer hidden dimension. | 128 |
| n_heads | int | Attention heads for both echelon encoder and cross-echelon transformer. | 8 |
| n_echelon_layers | int | Transformer depth for each :class:EchelonEncoder. | 4 |
| n_cross_layers | int | Transformer depth for :class:CrossEchelonTransformer. | 2 |
| actor_hidden_sizes | Tuple[int, ...] | MLP hidden sizes for the actor head. | (256, 128) |
| critic_hidden_sizes | Tuple[int, ...] | MLP hidden sizes for the critic head. | (256, 128) |
| dropout | float | Dropout probability (transformer only). | 0.0 |
| use_spatial_pe | bool | Enable 2-D Fourier positional encoding in the echelon encoders. | True |
| share_echelon_encoders | bool | When True, all echelons share one :class:EchelonEncoder. | True |
| card_hidden_size | int | Width of the FiLM adapter MLP. | 64 |
act(tokens, pad_mask=None, echelon=ECHELON_BATTALION, card=None, card_vec=None, tokens_per_echelon=None, pad_masks=None, deterministic=False)
¶
Sample actions (no gradient).
Can be called with a single-echelon tokens tensor or with a
full tokens_per_echelon dict for multi-echelon inputs.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tokens | Optional[Tensor] | Single-echelon entity tokens of shape (B, N, token_dim). | required |
| pad_mask | Optional[Tensor] | Padding mask for tokens. | None |
| echelon | int | Active echelon level for single-echelon mode. | ECHELON_BATTALION |
| card | Optional[ScenarioCard] | Scenario conditioning card. | None |
| card_vec | Optional[Tensor] | Pre-computed scenario card tensor (alternative to card). | None |
| tokens_per_echelon | Optional[Dict[int, Tensor]] | Multi-echelon input dict; overrides tokens when given. | None |
| pad_masks | Optional[Dict[int, Tensor]] | Padding masks for multi-echelon mode. | None |
| deterministic | bool | When True, return the distribution mean instead of sampling. | False |

Returns:

| Name | Type | Description |
|---|---|---|
| actions | torch.Tensor | Shape (B, action_dim). |
| log_probs | torch.Tensor | Shape (B,). |
adapter_parameters()
¶
Return only the FiLM adapter parameters.
Call this to get a parameter group for lightweight scenario-specific fine-tuning (adapter-only gradient updates leave the base transformer frozen).
Returns:

| Type | Description |
|---|---|
| list of :class:torch.nn.Parameter | The FiLM adapter parameter group. |
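Adapter-only fine-tuning needs the parameters split into two groups. A name-prefix sketch of that partition (the "film_adapter" prefix is hypothetical; in practice use adapter_parameters() and base_parameters() directly):

```python
def split_params(named_params, adapter_prefix: str = "film_adapter"):
    """Partition (name, param) pairs into (adapter, base) groups by name prefix.

    The prefix is an assumption for illustration only; WFM1Policy exposes
    adapter_parameters() / base_parameters() for exactly this purpose.
    """
    adapter, base = [], []
    for name, p in named_params:
        (adapter if name.startswith(adapter_prefix) else base).append(p)
    return adapter, base

# stand-in (name, param) pairs:
params = [("film_adapter.0.weight", "gA"), ("cross.0.weight", "gB")]
adapter, base = split_params(params)
```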
base_parameters()
¶
Return all non-adapter (base model) parameters.
evaluate_actions(tokens, actions, pad_mask=None, echelon=ECHELON_BATTALION, card=None, card_vec=None)
¶
Evaluate log-probs, entropy, and values for given actions.
Used during PPO update.
Returns:

| Name | Type | Description |
|---|---|---|
| log_probs | torch.Tensor | Shape (B,). |
| entropy | torch.Tensor | Shape (B,). |
| values | torch.Tensor | Shape (B,). |
finetune_loss(batch)
¶
Compute a supervised fine-tuning loss from a demonstration batch.
The batch dict must contain:
* "tokens" — shape (B, N, token_dim)
* "actions" — shape (B, action_dim)
Optional keys:
* "pad_mask" — shape (B, N)
* "echelon" — scalar int (default: ECHELON_BATTALION)
* "card_vec" — shape (B, card_raw_dim) or (card_raw_dim,)
Returns:

| Name | Type | Description |
|---|---|---|
| loss | torch.Tensor | Scalar behaviour-cloning MSE loss. |
freeze_base()
¶
Freeze all base-model parameters; only the adapter remains trainable.
get_value(tokens, pad_mask=None, echelon=ECHELON_BATTALION, card=None, card_vec=None, tokens_per_echelon=None, pad_masks=None)
¶
Compute scalar value estimates.
Returns:

| Name | Type | Description |
|---|---|---|
| values | torch.Tensor | Shape (B,). |
load_checkpoint(path, map_location=None, **kwargs)
classmethod
¶
Load a WFM-1 checkpoint produced by :meth:save_checkpoint.
Extra keyword arguments override the saved configuration.
save_checkpoint(path)
¶
Save model state dict to path (.pt).
unfreeze_base()
¶
Unfreeze all base-model parameters.
Top-level training symbols¶
training
¶
Wargames Training — training public API.
Stable interfaces for training runners, evaluation utilities, callbacks, self-play, curriculum management, policy registry, artifacts, Elo ratings, and benchmarks. Import from this module to remain insulated from internal restructuring.
Training runners
:func:train — train a single-agent policy with PPO.
:class:TrainingConfig — configuration dataclass for :func:train.
Evaluation
:func:evaluate — quick win-rate evaluation.
:func:evaluate_detailed — detailed win/draw/loss statistics.
:func:run_episodes_with_model — low-level episode runner.
:class:EvaluationResult — structured evaluation result.
Callbacks
:class:WandbCallback — log rollout statistics to W&B.
:class:RewardBreakdownCallback — per-component reward logging.
:class:EloEvalCallback — Elo evaluation callback.
:class:ManifestCheckpointCallback — manifest-aware checkpoint saving.
:class:ManifestEvalCallback — manifest-aware best-model saving.
Self-play
:class:OpponentPool — pool of frozen single-agent policy snapshots.
:class:SelfPlayCallback — snapshot the current policy into the pool.
:class:WinRateVsPoolCallback — evaluate win-rate vs. the pool.
:func:evaluate_vs_pool — standalone pool win-rate helper.
:class:TeamOpponentPool — pool for MAPPO team self-play.
:func:evaluate_team_vs_pool — evaluate a MAPPO policy vs. pool.
:func:nash_exploitability_proxy — exploitability proxy metric.
Curriculum
:class:CurriculumScheduler — win-rate-based curriculum progression.
:class:CurriculumStage — curriculum stage enum.
:func:load_v1_weights_into_mappo — warm-start MAPPO from a v1 checkpoint.
Policy registry
:class:PolicyRegistry — versioned multi-echelon checkpoint registry.
:class:Echelon — echelon enum (battalion / brigade / division).
:class:PolicyEntry — registry entry named-tuple.
Artifacts
:class:CheckpointManifest — append-only JSONL checkpoint index.
:func:checkpoint_name_prefix — canonical periodic-checkpoint prefix.
:func:checkpoint_final_stem — canonical final-checkpoint stem.
:func:checkpoint_best_filename — canonical best-checkpoint filename.
:func:parse_step_from_checkpoint_name — extract step number from filename.
Elo
:class:EloRegistry — Elo rating registry.
:class:TeamEloRegistry — multi-agent team Elo registry.
:data:DEFAULT_RATING — default rating for unseen agents.
:data:BASELINE_RATINGS — fixed ratings for scripted baseline opponents.
Benchmarks
:class:WFM1Benchmark — WFM-1 zero-shot evaluation.
:class:WFM1BenchmarkConfig — WFM-1 benchmark configuration.
:class:WFM1BenchmarkResult — WFM-1 per-scenario result.
:class:WFM1BenchmarkSummary — WFM-1 aggregate results.
:class:TransferBenchmark — GIS terrain transfer benchmark.
:class:TransferEvalConfig — transfer benchmark configuration.
:class:TransferResult — transfer benchmark per-condition result.
:class:TransferSummary — transfer benchmark aggregate results.
:class:HistoricalBenchmark — historical battle fidelity benchmark.
:class:BenchmarkEntry — per-battle benchmark entry.
:class:BenchmarkSummary — historical benchmark aggregate summary.
TrainingConfig
dataclass
¶
Configuration for a single PPO training run on :class:~envs.battalion_env.BattalionEnv.
All fields are optional; the defaults match configs/default.yaml.
Instances can be passed directly to :func:train or individual fields
can be overridden via **kwargs at the call site.
Examples:
Minimal run with defaults::
from training import train, TrainingConfig
model = train(TrainingConfig(total_timesteps=500_000))
Override specific fields at call time without constructing a config::
model = train(total_timesteps=200_000, n_envs=4, enable_wandb=False)
OpponentPool
¶
Fixed-size pool of frozen PPO policy snapshots.
Snapshots are stored as Stable-Baselines3 .zip files under
pool_dir. The pool keeps at most max_size snapshots; when full,
the oldest snapshot is evicted to make room for the newest one.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `pool_dir` | `str \| Path` | Directory where snapshot `.zip` files are stored. | required |
| `max_size` | `int` | Maximum number of snapshots to retain (default 10). | `10` |
Attributes:

| Name | Type | Description |
|---|---|---|
| `pool_dir` | `Path` | Resolved path of the snapshot directory. |
| `max_size` | `int` | Maximum number of snapshots retained in the pool. |
size
property
¶
Current number of snapshots in the pool.
snapshot_paths
property
¶
Ordered list of snapshot paths (oldest first, read-only copy).
add(model, version)
¶
Save model as a new snapshot and add it to the pool.
If the pool already contains max_size snapshots, the oldest is removed from disk before the new one is saved.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `PPO` | The current PPO model to snapshot. | required |
| `version` | `int` | Monotonically increasing version number embedded in the snapshot file name for traceability. | required |
Returns:

| Type | Description |
|---|---|
| `Path` | Path of the newly saved snapshot file (including the `.zip` extension). |
sample(rng=None)
¶
Load and return a uniformly sampled snapshot as a PPO model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `rng` | `Optional[Generator]` | Optional NumPy random generator for reproducible sampling. When `None`, a default generator is used. | `None` |
Returns:

| Type | Description |
|---|---|
| `PPO or None` | A loaded PPO model, or `None` if the pool is empty. |
sample_latest()
¶
Load and return the most recently added snapshot.
Returns:

| Type | Description |
|---|---|
| `PPO or None` | The latest PPO model, or `None` if the pool is empty. |
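The pool's fixed-size, oldest-first eviction behaviour can be illustrated with a self-contained sketch. This is not the real `OpponentPool` (which saves Stable-Baselines3 `.zip` files to disk); the class name, the `snapshot_v{n}.zip` naming scheme, and the in-memory bookkeeping below are illustrative assumptions.

```python
from collections import deque
from pathlib import Path


class TinySnapshotPool:
    """Sketch of OpponentPool's eviction policy: keep at most
    max_size snapshot paths, dropping the oldest when full."""

    def __init__(self, pool_dir, max_size=10):
        self.pool_dir = Path(pool_dir)
        self.max_size = max_size
        self._paths = deque()  # ordered oldest -> newest

    def add(self, version):
        # When full, evict the oldest entry (the real pool also
        # deletes that snapshot file from disk before saving).
        if len(self._paths) >= self.max_size:
            self._paths.popleft()
        path = self.pool_dir / f"snapshot_v{version}.zip"  # hypothetical naming
        self._paths.append(path)
        return path

    @property
    def snapshot_paths(self):
        return list(self._paths)  # read-only copy, oldest first


pool = TinySnapshotPool("pool", max_size=3)
for v in range(5):
    pool.add(v)
print([p.name for p in pool.snapshot_paths])
# → ['snapshot_v2.zip', 'snapshot_v3.zip', 'snapshot_v4.zip']
```

The deque gives O(1) eviction at the old end and append at the new end, which matches the documented "oldest snapshot is evicted to make room for the newest" contract.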
EloRegistry
¶
Persistent Elo rating registry backed by a JSON file.
Stores per-agent ratings and game counts. Scripted baseline opponents
("scripted_l1" … "scripted_l5", "random") have fixed seed
ratings defined in :data:BASELINE_RATINGS and are never modified.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `Union[str, Path, None]` | Path to the JSON file used for persistence. The parent directory is created automatically on :meth:`save`. | `'checkpoints/elo_registry.json'` |
can_save
property
¶
True when the registry has a backing file and can be persisted.
all_ratings()
¶
Return a copy of all stored ratings (excludes pure baselines).
get_game_count(name)
¶
Return the total number of rated games played by name.
get_rating(name)
¶
Return the current Elo rating for name.
Falls back to :data:BASELINE_RATINGS for scripted opponents, then
to :data:DEFAULT_RATING for completely unknown agents.
save()
¶
Persist current ratings and game counts to the JSON file.
Raises:

| Type | Description |
|---|---|
| `ValueError` | If the registry was created without a file path (`path=None`). |
update(agent, opponent, outcome, n_games=1)
¶
Update the Elo rating of agent after a batch of n_games.
The outcome is the average score per game:
- `1.0` — all wins
- `0.5` — all draws
- `0.0` — all losses
Scripted baseline opponents' ratings are never modified.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `agent` | `str` | Identifier of the agent whose rating is updated. | required |
| `opponent` | `str` | Identifier of the opponent played against. | required |
| `outcome` | `float` | Average per-game score in `[0.0, 1.0]`. | required |
| `n_games` | `int` | Number of games in this batch (used to advance the game-count counter; the K-factor is evaluated at the pre-update count). | `1` |
Returns:

| Type | Description |
|---|---|
| `float` | Elo rating delta for agent (positive = rating increased). |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If `outcome` is outside `[0.0, 1.0]`. |
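For orientation, here is the textbook Elo update that a registry like this typically applies. This is a sketch under standard assumptions (a logistic expected-score model and a fixed `k`); the function name `elo_delta` is hypothetical, and the registry's actual K-factor schedule — which the docs say depends on the pre-update game count — and its batch handling may differ.

```python
def elo_delta(rating_a, rating_b, outcome, n_games=1, k=32.0):
    """Textbook Elo delta for a batch of n_games with average score `outcome`.

    Illustrative only: EloRegistry.update's K-factor varies with the
    agent's pre-update game count, which is not modelled here.
    """
    if not 0.0 <= outcome <= 1.0:
        raise ValueError("outcome must be in [0.0, 1.0]")
    # Expected per-game score of A vs. B under the logistic Elo model.
    expected = 1.0 / (1.0 + 10.0 ** ((rating_b - rating_a) / 400.0))
    return k * n_games * (outcome - expected)


# Equal ratings, one game, all wins: 32 * (1.0 - 0.5)
print(elo_delta(1200.0, 1200.0, 1.0))  # → 16.0
```

Note how the delta scales with surprise: an underdog beating a 400-points-stronger opponent gains far more than 16 points, while drawing an equally rated opponent changes nothing.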