Skip to content

Models API

The models package exposes all neural network policy architectures.

Quick-start

from models import BattalionMlpPolicy, MAPPOPolicy, WFM1Policy, ScenarioCard
from stable_baselines3 import PPO
from envs import BattalionEnv

# Standard MLP policy with SB3
env = BattalionEnv()
model = PPO(BattalionMlpPolicy, env)

# Multi-agent MAPPO (obs_dim/state_dim taken from MultiBattalionEnv 2v2 defaults)
from models import MAPPOPolicy
policy = MAPPOPolicy(obs_dim=22, action_dim=3, state_dim=25, n_agents=2)

# WFM-1 foundation model
from models import WFM1Policy, ScenarioCard, ECHELON_BATTALION, TERRAIN_PROCEDURAL
card = ScenarioCard(echelon=ECHELON_BATTALION, terrain_type=TERRAIN_PROCEDURAL)
wfm1 = WFM1Policy()

MLP policy (SB3-compatible)

models.mlp_policy.BattalionMlpPolicy

Bases: ActorCriticPolicy

MLP actor-critic policy for BattalionEnv.

A configurable fully-connected network with separate actor and critic heads, designed for the 12-dimensional observation space of :class:~envs.battalion_env.BattalionEnv.

Architecture (default)::

obs(12) → Linear(128) → Tanh → Linear(128) → Tanh
        ↓                                         ↓
actor head → action(3) + log_std        critic head → value(1)

Parameters:

Name Type Description Default
observation_space Space

Gymnasium observation space.

required
action_space Space

Gymnasium action space.

required
lr_schedule Schedule

Learning-rate schedule passed in by the SB3 algorithm.

required
net_arch Optional[List[int]]

Hidden-layer sizes shared by actor and critic. Defaults to [128, 128].

None
activation_fn Type[Module]

Activation function class applied after each hidden layer. Defaults to :class:torch.nn.Tanh.

Tanh
**kwargs Any

Forwarded to :class:~stable_baselines3.common.policies.ActorCriticPolicy.

{}
Source code in models/mlp_policy.py
class BattalionMlpPolicy(ActorCriticPolicy):
    """Fully-connected actor-critic policy for BattalionEnv.

    Thin wrapper around SB3's
    :class:`~stable_baselines3.common.policies.ActorCriticPolicy` with
    defaults tuned for the 12-dimensional observation space of
    :class:`~envs.battalion_env.BattalionEnv`.

    Architecture (default)::

        obs(12) → Linear(128) → Tanh → Linear(128) → Tanh
                ↓                                         ↓
        actor head → action(3) + log_std        critic head → value(1)

    Parameters
    ----------
    observation_space:
        Gymnasium observation space.
    action_space:
        Gymnasium action space.
    lr_schedule:
        Learning-rate schedule supplied by the SB3 algorithm.
    net_arch:
        Hidden-layer widths shared by actor and critic.  ``None`` selects
        the default ``[128, 128]``.
    activation_fn:
        Activation module class applied after every hidden layer.
        Defaults to :class:`torch.nn.Tanh`.
    **kwargs:
        Forwarded unchanged to the SB3 base policy.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[int]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        **kwargs: Any,
    ) -> None:
        # Resolve the documented default here rather than using a mutable
        # default argument in the signature.
        hidden_layers = [128, 128] if net_arch is None else net_arch
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch=hidden_layers,
            activation_fn=activation_fn,
            **kwargs,
        )

MAPPO multi-agent policy

models.mappo_policy.MAPPOActor

Bases: Module

Shared actor network for homogeneous MAPPO agents.

Takes a local observation vector and produces a diagonal Gaussian action distribution. The log_std parameters are learned but shared across the batch (not conditioned on the observation).

Parameters:

Name Type Description Default
obs_dim int

Dimensionality of the per-agent local observation.

required
action_dim int

Dimensionality of the continuous action space.

required
hidden_sizes Tuple[int, ...]

Sizes of the hidden layers in the shared trunk MLP.

(128, 64)
Source code in models/mappo_policy.py
class MAPPOActor(nn.Module):
    """Shared actor network for homogeneous MAPPO agents.

    Maps a **local observation** vector to a diagonal Gaussian action
    distribution.  The ``log_std`` vector is a free learnable parameter
    shared across the batch; it is *not* conditioned on the observation.

    Parameters
    ----------
    obs_dim:
        Dimensionality of the per-agent local observation.
    action_dim:
        Dimensionality of the continuous action space.
    hidden_sizes:
        Widths of the hidden layers in the shared trunk MLP.
    """

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        hidden_sizes: Tuple[int, ...] = (128, 64),
    ) -> None:
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim

        self.trunk, trunk_out = _build_mlp(obs_dim, hidden_sizes)
        self.action_mean = nn.Linear(trunk_out, action_dim)
        # State-independent learnable log standard deviation.
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(action_mean, action_std)`` tensors.

        Parameters
        ----------
        obs:
            Local observation of shape ``(..., obs_dim)``.

        Returns
        -------
        mean : torch.Tensor — shape ``(..., action_dim)``
        std  : torch.Tensor — shape ``(..., action_dim)``
        """
        mean = self.action_mean(self.trunk(obs))
        # Broadcast the shared std vector to the batch shape of the mean.
        return mean, self.log_std.exp().expand_as(mean)

    def get_distribution(self, obs: torch.Tensor) -> Normal:
        """Return a :class:`~torch.distributions.Normal` over actions."""
        return Normal(*self.forward(obs))

    def evaluate_actions(
        self, obs: torch.Tensor, actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate log-probabilities and entropy for given (obs, action) pairs.

        Parameters
        ----------
        obs:
            Local observations of shape ``(batch, obs_dim)``.
        actions:
            Actions of shape ``(batch, action_dim)``.

        Returns
        -------
        log_probs : torch.Tensor — shape ``(batch,)``
        entropy   : torch.Tensor — shape ``(batch,)``
        """
        # Diagonal Gaussian: dimensions are independent, so per-sample
        # log-probability and entropy are sums over the action dimension.
        dist = self.get_distribution(obs)
        return dist.log_prob(actions).sum(dim=-1), dist.entropy().sum(dim=-1)

evaluate_actions(obs, actions)

Evaluate log-probabilities and entropy for given (obs, action) pairs.

Parameters:

Name Type Description Default
obs Tensor

Local observations of shape (batch, obs_dim).

required
actions Tensor

Actions of shape (batch, action_dim).

required

Returns:

Name Type Description
log_probs torch.Tensor — shape ``(batch,)``
entropy torch.Tensor — shape ``(batch,)``
Source code in models/mappo_policy.py
def evaluate_actions(
    self, obs: torch.Tensor, actions: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Evaluate log-probabilities and entropy for given (obs, action) pairs.

    Parameters
    ----------
    obs:
        Local observations of shape ``(batch, obs_dim)``.
    actions:
        Actions of shape ``(batch, action_dim)``.

    Returns
    -------
    log_probs : torch.Tensor — shape ``(batch,)``
    entropy   : torch.Tensor — shape ``(batch,)``
    """
    # Diagonal Gaussian: sum the independent per-dimension terms.
    dist = self.get_distribution(obs)
    return dist.log_prob(actions).sum(dim=-1), dist.entropy().sum(dim=-1)

forward(obs)

Return (action_mean, action_std) tensors.

Parameters:

Name Type Description Default
obs Tensor

Local observation of shape (..., obs_dim).

required

Returns:

Name Type Description
mean torch.Tensor — shape ``(..., action_dim)``
std torch.Tensor — shape ``(..., action_dim)``
Source code in models/mappo_policy.py
def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return ``(action_mean, action_std)`` tensors.

    Parameters
    ----------
    obs:
        Local observation of shape ``(..., obs_dim)``.

    Returns
    -------
    mean : torch.Tensor — shape ``(..., action_dim)``
    std  : torch.Tensor — shape ``(..., action_dim)``
    """
    mean = self.action_mean(self.trunk(obs))
    # The learned log-std is state-independent; broadcast it to the mean's shape.
    return mean, self.log_std.exp().expand_as(mean)

get_distribution(obs)

Return a :class:~torch.distributions.Normal over actions.

Source code in models/mappo_policy.py
def get_distribution(self, obs: torch.Tensor) -> Normal:
    """Build a :class:`~torch.distributions.Normal` from the actor's (mean, std)."""
    return Normal(*self.forward(obs))

models.mappo_policy.MAPPOCritic

Bases: Module

Centralized critic conditioned on the global state tensor.

Receives the global state (all agents' positions, headings, strengths and morale) and produces a scalar value estimate used for advantage computation in MAPPO.

Parameters:

Name Type Description Default
state_dim int

Dimensionality of the global state vector (output of :meth:~envs.multi_battalion_env.MultiBattalionEnv.state).

required
hidden_sizes Tuple[int, ...]

Sizes of the hidden layers in the critic MLP.

(128, 64)
Source code in models/mappo_policy.py
class MAPPOCritic(nn.Module):
    """Centralized critic conditioned on the global state tensor.

    Consumes the **global state** (all agents' positions, headings,
    strengths and morale) and emits a scalar value estimate used for
    advantage computation in MAPPO.

    Parameters
    ----------
    state_dim:
        Dimensionality of the global state vector (output of
        :meth:`~envs.multi_battalion_env.MultiBattalionEnv.state`).
    hidden_sizes:
        Widths of the hidden layers in the critic MLP.
    """

    def __init__(
        self,
        state_dim: int,
        hidden_sizes: Tuple[int, ...] = (128, 64),
    ) -> None:
        super().__init__()
        self.state_dim = state_dim

        self.trunk, trunk_out = _build_mlp(state_dim, hidden_sizes)
        self.value_head = nn.Linear(trunk_out, 1)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return value estimates.

        Parameters
        ----------
        state:
            Global state tensor of shape ``(..., state_dim)``.

        Returns
        -------
        values : torch.Tensor — shape ``(...)``
        """
        # Squeeze away the trailing singleton produced by the value head.
        features = self.trunk(state)
        return self.value_head(features).squeeze(-1)

forward(state)

Return value estimates.

Parameters:

Name Type Description Default
state Tensor

Global state tensor of shape (..., state_dim).

required

Returns:

Name Type Description
values torch.Tensor — shape ``(...)``
Source code in models/mappo_policy.py
def forward(self, state: torch.Tensor) -> torch.Tensor:
    """Return value estimates.

    Parameters
    ----------
    state:
        Global state tensor of shape ``(..., state_dim)``.

    Returns
    -------
    values : torch.Tensor — shape ``(...)``
    """
    # Drop the trailing singleton dimension produced by the value head.
    features = self.trunk(state)
    return self.value_head(features).squeeze(-1)

models.mappo_policy.MAPPOPolicy

Bases: Module

MAPPO policy: shared actor(s) plus a centralized critic.

Supports two parameter-sharing modes for the actor:

  • share_parameters=True (default) — all n_agents agents use the same :class:MAPPOActor weights. Memory scales as O(actor_params + critic_params) instead of O(n_agents * actor_params).
  • share_parameters=False — each agent gets its own :class:MAPPOActor; useful for ablation studies.

In both cases there is a single :class:MAPPOCritic that all agents share.

Parameters:

Name Type Description Default
obs_dim int

Per-agent local observation dimensionality.

required
action_dim int

Per-agent action dimensionality (continuous).

required
state_dim int

Global state dimensionality.

required
n_agents int

Number of controlled agents (used only when share_parameters=False to build separate actor heads).

1
share_parameters bool

Whether all agents share one actor. Defaults to True.

True
actor_hidden_sizes Tuple[int, ...]

Hidden layer sizes for the actor trunk.

(128, 64)
critic_hidden_sizes Tuple[int, ...]

Hidden layer sizes for the critic trunk.

(128, 64)
Source code in models/mappo_policy.py
class MAPPOPolicy(nn.Module):
    """MAPPO policy: shared actor(s) plus a centralized critic.

    Two parameter-sharing modes are supported on the actor side:

    * ``share_parameters=True`` (default) — every one of the *n_agents*
      agents uses the **same** :class:`MAPPOActor` weights, so memory
      scales as O(actor_params + critic_params) rather than
      O(n_agents * actor_params).
    * ``share_parameters=False`` — one independent :class:`MAPPOActor`
      per agent; useful for ablation studies.

    Either way there is a **single** :class:`MAPPOCritic` shared by all
    agents.

    Parameters
    ----------
    obs_dim:
        Per-agent local observation dimensionality.
    action_dim:
        Per-agent action dimensionality (continuous).
    state_dim:
        Global state dimensionality.
    n_agents:
        Number of controlled agents (only used to build separate actor
        heads when ``share_parameters=False``).
    share_parameters:
        Whether all agents share one actor.  Defaults to ``True``.
    actor_hidden_sizes:
        Hidden layer sizes for the actor trunk.
    critic_hidden_sizes:
        Hidden layer sizes for the critic trunk.
    """

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        state_dim: int,
        n_agents: int = 1,
        share_parameters: bool = True,
        actor_hidden_sizes: Tuple[int, ...] = (128, 64),
        critic_hidden_sizes: Tuple[int, ...] = (128, 64),
    ) -> None:
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.state_dim = state_dim
        self.n_agents = n_agents
        self.share_parameters = share_parameters
        self.actor_hidden_sizes: Tuple[int, ...] = tuple(actor_hidden_sizes)
        self.critic_hidden_sizes: Tuple[int, ...] = tuple(critic_hidden_sizes)

        if share_parameters:
            # One actor instance reused for every agent.
            self.actor: nn.Module = MAPPOActor(obs_dim, action_dim, actor_hidden_sizes)
        else:
            # Independent per-agent actors, registered via ModuleList.
            self.actors: nn.ModuleList = nn.ModuleList(
                [MAPPOActor(obs_dim, action_dim, actor_hidden_sizes) for _ in range(n_agents)]
            )

        self.critic = MAPPOCritic(state_dim, critic_hidden_sizes)

    # ------------------------------------------------------------------
    # Actor helpers
    # ------------------------------------------------------------------

    def get_actor(self, agent_idx: int = 0) -> MAPPOActor:
        """Return the actor responsible for *agent_idx*.

        In shared-parameter mode the single shared actor is returned no
        matter which index is requested.
        """
        if self.share_parameters:
            return self.actor  # type: ignore[return-value]
        return self.actors[agent_idx]  # type: ignore[index]

    @torch.no_grad()
    def act(
        self,
        obs: torch.Tensor,
        agent_idx: int = 0,
        deterministic: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample actions for a batch of observations.

        A 1-D *obs* is promoted to a batch of one internally, so the
        return shapes are **always** ``(batch, action_dim)`` and
        ``(batch,)`` whether the caller passes a single observation or a
        batch.

        Parameters
        ----------
        obs:
            Local observations of shape ``(batch, obs_dim)`` or
            ``(obs_dim,)`` for a single observation.
        agent_idx:
            Index of the agent whose actor should be used.  Ignored when
            ``share_parameters=True``.
        deterministic:
            When ``True``, return the distribution mean instead of
            sampling.

        Returns
        -------
        actions  : torch.Tensor — shape ``(batch, action_dim)``
        log_probs: torch.Tensor — shape ``(batch,)``
        """
        batched = obs.unsqueeze(0) if obs.dim() == 1 else obs
        dist = self.get_actor(agent_idx).get_distribution(batched)
        chosen = dist.mean if deterministic else dist.rsample()
        return chosen, dist.log_prob(chosen).sum(dim=-1)

    # ------------------------------------------------------------------
    # Critic helpers
    # ------------------------------------------------------------------

    @torch.no_grad()
    def get_value(self, state: torch.Tensor) -> torch.Tensor:
        """Return value estimate(s) for the given global state(s).

        Parameters
        ----------
        state:
            Global state tensor of shape ``(state_dim,)`` or
            ``(batch, state_dim)``.

        Returns
        -------
        values : torch.Tensor — scalar or shape ``(batch,)``
        """
        return self.critic(state)

    # ------------------------------------------------------------------
    # Evaluation (with gradients, used in the update step)
    # ------------------------------------------------------------------

    def evaluate_actions(
        self,
        obs: torch.Tensor,
        actions: torch.Tensor,
        state: torch.Tensor,
        agent_idx: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate actions under the current policy for the PPO update.

        Parameters
        ----------
        obs:
            Local observations of shape ``(batch, obs_dim)``.
        actions:
            Actions of shape ``(batch, action_dim)``.
        state:
            Global states of shape ``(batch, state_dim)``.
        agent_idx:
            Actor index (ignored when sharing parameters).

        Returns
        -------
        log_probs : torch.Tensor — shape ``(batch,)``
        entropy   : torch.Tensor — shape ``(batch,)``
        values    : torch.Tensor — shape ``(batch,)``
        """
        log_probs, entropy = self.get_actor(agent_idx).evaluate_actions(obs, actions)
        return log_probs, entropy, self.critic(state)

    # ------------------------------------------------------------------
    # Utility
    # ------------------------------------------------------------------

    def parameter_count(self) -> dict[str, int]:
        """Return a dict with actor and critic parameter counts."""
        def _count(module: nn.Module) -> int:
            return sum(p.numel() for p in module.parameters())

        if self.share_parameters:
            actor_total = _count(self.actor)  # type: ignore[arg-type]
        else:
            actor_total = sum(_count(a) for a in self.actors)  # type: ignore[attr-defined]
        critic_total = _count(self.critic)
        return {
            "actor": actor_total,
            "critic": critic_total,
            "total": actor_total + critic_total,
        }

act(obs, agent_idx=0, deterministic=False)

Sample actions for a batch of observations.

A batch dimension is added internally when obs is 1-D so that the return shapes are always (batch, action_dim) and (batch,) regardless of whether the caller passes a single observation or a batch.

Parameters:

Name Type Description Default
obs Tensor

Local observations of shape (batch, obs_dim) or (obs_dim,) for a single observation.

required
agent_idx int

Index of the agent whose actor should be used. Ignored when share_parameters=True.

0
deterministic bool

When True returns the distribution mean instead of sampling.

False

Returns:

Name Type Description
actions torch.Tensor — shape ``(batch, action_dim)``
log_probs torch.Tensor — shape ``(batch,)``
Source code in models/mappo_policy.py
@torch.no_grad()
def act(
    self,
    obs: torch.Tensor,
    agent_idx: int = 0,
    deterministic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample actions for a batch of observations.

    A 1-D *obs* is promoted to a batch of one internally, so the return
    shapes are **always** ``(batch, action_dim)`` and ``(batch,)``
    whether the caller passes a single observation or a batch.

    Parameters
    ----------
    obs:
        Local observations of shape ``(batch, obs_dim)`` or
        ``(obs_dim,)`` for a single observation.
    agent_idx:
        Index of the agent whose actor should be used.  Ignored when
        ``share_parameters=True``.
    deterministic:
        When ``True``, return the distribution mean instead of sampling.

    Returns
    -------
    actions  : torch.Tensor — shape ``(batch, action_dim)``
    log_probs: torch.Tensor — shape ``(batch,)``
    """
    batched = obs.unsqueeze(0) if obs.dim() == 1 else obs
    dist = self.get_actor(agent_idx).get_distribution(batched)
    chosen = dist.mean if deterministic else dist.rsample()
    return chosen, dist.log_prob(chosen).sum(dim=-1)

evaluate_actions(obs, actions, state, agent_idx=0)

Evaluate actions under the current policy for the PPO update.

Parameters:

Name Type Description Default
obs Tensor

Local observations of shape (batch, obs_dim).

required
actions Tensor

Actions of shape (batch, action_dim).

required
state Tensor

Global states of shape (batch, state_dim).

required
agent_idx int

Actor index (ignored when sharing parameters).

0

Returns:

Name Type Description
log_probs torch.Tensor — shape ``(batch,)``
entropy torch.Tensor — shape ``(batch,)``
values torch.Tensor — shape ``(batch,)``
Source code in models/mappo_policy.py
def evaluate_actions(
    self,
    obs: torch.Tensor,
    actions: torch.Tensor,
    state: torch.Tensor,
    agent_idx: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Evaluate actions under the current policy for the PPO update.

    Parameters
    ----------
    obs:
        Local observations of shape ``(batch, obs_dim)``.
    actions:
        Actions of shape ``(batch, action_dim)``.
    state:
        Global states of shape ``(batch, state_dim)``.
    agent_idx:
        Actor index (ignored when sharing parameters).

    Returns
    -------
    log_probs : torch.Tensor — shape ``(batch,)``
    entropy   : torch.Tensor — shape ``(batch,)``
    values    : torch.Tensor — shape ``(batch,)``
    """
    # Per-agent actor terms plus the shared centralized value estimate.
    log_probs, entropy = self.get_actor(agent_idx).evaluate_actions(obs, actions)
    return log_probs, entropy, self.critic(state)

get_actor(agent_idx=0)

Return the actor for agent_idx.

When share_parameters=True the same actor is returned regardless of agent_idx.

Source code in models/mappo_policy.py
def get_actor(self, agent_idx: int = 0) -> MAPPOActor:
    """Return the actor responsible for *agent_idx*.

    In shared-parameter mode the single shared actor is returned no
    matter which index is requested.
    """
    if not self.share_parameters:
        return self.actors[agent_idx]  # type: ignore[index]
    return self.actor  # type: ignore[return-value]

get_value(state)

Return value estimate(s) for the given global state(s).

Parameters:

Name Type Description Default
state Tensor

Global state tensor of shape (state_dim,) or (batch, state_dim).

required

Returns:

Name Type Description
values torch.Tensor — scalar or shape ``(batch,)``
Source code in models/mappo_policy.py
@torch.no_grad()
def get_value(self, state: torch.Tensor) -> torch.Tensor:
    """Return value estimate(s) for the given global state(s).

    Parameters
    ----------
    state:
        Global state tensor of shape ``(state_dim,)`` or
        ``(batch, state_dim)``.

    Returns
    -------
    values : torch.Tensor — scalar or shape ``(batch,)``
    """
    # Pure delegation; no_grad keeps rollout-time value queries gradient-free.
    return self.critic(state)

parameter_count()

Return a dict with actor and critic parameter counts.

Source code in models/mappo_policy.py
def parameter_count(self) -> dict[str, int]:
    """Return a dict with ``actor``, ``critic`` and ``total`` parameter counts."""
    def _count(params) -> int:
        # Total element count across a parameter iterable.
        return sum(p.numel() for p in params)

    if self.share_parameters:
        actor_total = _count(self.actor.parameters())  # type: ignore[attr-defined]
    else:
        actor_total = sum(_count(a.parameters()) for a in self.actors)  # type: ignore[attr-defined]
    critic_total = _count(self.critic.parameters())
    return {
        "actor": actor_total,
        "critic": critic_total,
        "total": actor_total + critic_total,
    }

Entity encoder (transformer)

models.entity_encoder.EntityEncoder

Bases: Module

Multi-head self-attention encoder over a variable-length entity sequence.

Architecture::

entity tokens (B, N, token_dim)
      ↓
token_embed: Linear(token_dim → d_model)
      ↓ + SpatialPositionalEncoding(x, y) [optional]
      ↓
TransformerEncoder (n_layers × TransformerEncoderLayer)
   └── MultiheadAttention (n_heads, d_model)
   └── FFN (dim_feedforward = 4 * d_model)
   └── LayerNorm + residual
      ↓
mean-pool over non-padded entities  →  (B, d_model)
      ↓
output projection: Linear(d_model → d_model)  [identity-initialised]

Padding is handled via src_key_padding_mask: a boolean tensor of shape (B, N) where True marks padded (ignored) positions.

Parameters:

Name Type Description Default
token_dim int

Dimensionality of each input entity token. Defaults to ENTITY_TOKEN_DIM (16).

ENTITY_TOKEN_DIM
d_model int

Internal transformer dimension.

64
n_heads int

Number of attention heads. Must evenly divide d_model.

4
n_layers int

Number of transformer encoder layers.

2
dim_feedforward Optional[int]

Feed-forward sublayer width. Defaults to 4 * d_model.

None
dropout float

Dropout probability inside the transformer.

0.0
use_spatial_pe bool

When True, add a 2-D Fourier positional encoding derived from the position field (x, y) of each token.

True
n_freq_bands int

Number of Fourier frequency bands used by :class:SpatialPositionalEncoding when use_spatial_pe=True.

8
Source code in models/entity_encoder.py
class EntityEncoder(nn.Module):
    """Multi-head self-attention encoder over a variable-length entity sequence.

    Architecture::

        entity tokens (B, N, token_dim)

        token_embed: Linear(token_dim → d_model)
              │ + SpatialPositionalEncoding(x, y) [optional]

        TransformerEncoder (n_layers × TransformerEncoderLayer)
           └── MultiheadAttention (n_heads, d_model)
           └── FFN (dim_feedforward = 4 * d_model)
           └── LayerNorm + residual

        mean-pool over non-padded entities  →  (B, d_model)

        output projection: Linear(d_model → d_model)  [identity-initialised]

    Padding is handled via ``src_key_padding_mask``: a boolean tensor of shape
    ``(B, N)`` where ``True`` marks *padded* (ignored) positions.

    Parameters
    ----------
    token_dim:
        Dimensionality of each input entity token.  Defaults to
        ``ENTITY_TOKEN_DIM`` (16).
    d_model:
        Internal transformer dimension.
    n_heads:
        Number of attention heads.  Must evenly divide ``d_model``.
    n_layers:
        Number of transformer encoder layers.
    dim_feedforward:
        Feed-forward sublayer width.  Defaults to ``4 * d_model``.
    dropout:
        Dropout probability inside the transformer.
    use_spatial_pe:
        When ``True``, add a 2-D Fourier positional encoding derived from
        the position field ``(x, y)`` of each token.
    n_freq_bands:
        Number of Fourier frequency bands used by
        :class:`SpatialPositionalEncoding` when ``use_spatial_pe=True``.
    """

    def __init__(
        self,
        token_dim: int = ENTITY_TOKEN_DIM,
        d_model: int = 64,
        n_heads: int = 4,
        n_layers: int = 2,
        dim_feedforward: Optional[int] = None,
        dropout: float = 0.0,
        use_spatial_pe: bool = True,
        n_freq_bands: int = 8,
    ) -> None:
        super().__init__()
        self.token_dim = token_dim
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.use_spatial_pe = use_spatial_pe

        # Conventional transformer sizing: FFN width defaults to a
        # module-level multiple (_DEFAULT_FFN_SCALE) of d_model.
        if dim_feedforward is None:
            dim_feedforward = _DEFAULT_FFN_SCALE * d_model

        # Project raw token features → d_model
        self.token_embed = nn.Linear(token_dim, d_model)

        # Optional 2-D Fourier positional encoding
        # (only registered as a submodule when enabled).
        if use_spatial_pe:
            self.spatial_pe = SpatialPositionalEncoding(d_model, n_freqs=n_freq_bands)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,  # (B, N, d_model) convention
            norm_first=True,   # pre-norm for training stability
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=n_layers,
            enable_nested_tensor=False,  # nested-tensor fast path is disabled with norm_first
        )

        # Output projection — initialised as identity to preserve activations
        # at the start of training (weight = I, bias = 0).
        self.out_proj = nn.Linear(d_model, d_model)
        with torch.no_grad():
            self.out_proj.weight.copy_(torch.eye(d_model))
            if self.out_proj.bias is not None:
                self.out_proj.bias.zero_()

    @property
    def output_dim(self) -> int:
        """Dimensionality of the pooled output vector."""
        return self.d_model

    def forward(
        self,
        tokens: torch.Tensor,
        pad_mask: Optional[torch.Tensor] = None,
        return_attention: bool = False,
    ) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of entity sequences.

        Parameters
        ----------
        tokens:
            Entity token tensor of shape ``(B, N, token_dim)``.
        pad_mask:
            Boolean mask of shape ``(B, N)``.  ``True`` marks padded
            (ignored) positions.  If ``None``, all positions are attended.
        return_attention:
            When ``True``, also return the averaged attention weights from
            the *last* transformer layer as a second return value of shape
            ``(B, N, N)``.

        Returns
        -------
        encoding : torch.Tensor — shape ``(B, d_model)``
            Mean-pooled sequence encoding.
        attn_weights : torch.Tensor — shape ``(B, N, N)``
            Averaged attention weights from the last layer.
            Only returned when ``return_attention=True``.
        """
        B, N, _ = tokens.shape

        # Token embedding
        x = self.token_embed(tokens)  # (B, N, d_model)

        # Optional 2-D Fourier positional encoding from (x, y) position fields
        # (_SLICE_POSITION is defined at module level; per the class docstring
        # it selects the token's (x, y) position fields).
        if self.use_spatial_pe:
            xy = tokens[..., _SLICE_POSITION]  # (B, N, 2)
            x = x + self.spatial_pe(xy)

        if return_attention:
            if len(self.transformer.layers) == 0:
                # No transformer layers: fall back to the standard path and
                # return a zero attention matrix as a sensible default.
                x = self.transformer(x, src_key_padding_mask=pad_mask)
                attn_weights = torch.zeros(B, N, N, device=x.device, dtype=x.dtype)
            else:
                # Run layers manually to extract attention from the last layer.
                # nn.TransformerEncoder does not expose attention weights, so
                # the last layer's computation is re-implemented inline below;
                # it must stay in lockstep with TransformerEncoderLayer.forward.
                for i, layer in enumerate(self.transformer.layers):
                    if i < len(self.transformer.layers) - 1:
                        x = layer(x, src_key_padding_mask=pad_mask)
                    else:
                        # Last layer: extract attention weights
                        # Pre-norm layers normalise *before* attention;
                        # post-norm layers attend on the raw input.
                        x_norm = layer.norm1(x) if layer.norm_first else x
                        # need_weights=True with average_attn_weights=True
                        # yields head-averaged weights of shape (B, N, N).
                        attn_out, attn_weights = layer.self_attn(
                            x_norm, x_norm, x_norm,
                            key_padding_mask=pad_mask,
                            need_weights=True,
                            average_attn_weights=True,
                        )
                        # Complete the residual path (attention residual, then
                        # the feed-forward sublayer with its own residual).
                        if layer.norm_first:
                            x = x + layer.dropout1(attn_out)
                            x = x + layer.dropout2(
                                layer.linear2(
                                    layer.dropout(layer.activation(layer.linear1(layer.norm2(x))))
                                )
                            )
                        else:
                            x = layer.norm1(x + layer.dropout1(attn_out))
                            x = layer.norm2(
                                x + layer.dropout2(
                                    layer.linear2(layer.dropout(layer.activation(layer.linear1(x))))
                                )
                            )
        else:
            x = self.transformer(x, src_key_padding_mask=pad_mask)
            attn_weights = None  # only meaningful on the return_attention path

        # Mean-pool over non-padded positions
        if pad_mask is not None:
            # Invert mask: True = keep
            keep = ~pad_mask  # (B, N)
            # Avoid division by zero if all positions are padded
            n_valid = keep.float().sum(dim=1, keepdim=True).clamp(min=1.0)  # (B, 1); clamp avoids ÷0 if all padded
            x = (x * keep.unsqueeze(-1).float()).sum(dim=1) / n_valid  # (B, d_model)
        else:
            x = x.mean(dim=1)  # (B, d_model)

        encoding = self.out_proj(x)  # (B, d_model)

        if return_attention:
            return encoding, attn_weights  # type: ignore[return-value]
        return encoding

    @staticmethod
    def make_padding_mask(n_valid: torch.Tensor, max_n: int) -> torch.Tensor:
        """Create a padding mask from per-sample entity counts.

        Parameters
        ----------
        n_valid:
            Integer tensor of shape ``(B,)`` with the number of real
            (non-padded) entities in each sample.
        max_n:
            Total number of positions in the padded sequence.

        Returns
        -------
        pad_mask : torch.BoolTensor — shape ``(B, max_n)``
            ``True`` where the position is padded (ignored).
        """
        # Broadcast compare: position index >= count  ⇒  padded.
        B = n_valid.size(0)
        idx = torch.arange(max_n, device=n_valid.device).unsqueeze(0)  # (1, max_n)
        return idx >= n_valid.unsqueeze(1)  # (B, max_n)

output_dim property

Dimensionality of the pooled output vector.

forward(tokens, pad_mask=None, return_attention=False)

Encode a batch of entity sequences.

Parameters:

Name Type Description Default
tokens Tensor

Entity token tensor of shape (B, N, token_dim).

required
pad_mask Optional[Tensor]

Boolean mask of shape (B, N). True marks padded (ignored) positions. If None, all positions are attended.

None
return_attention bool

When True, also return the averaged attention weights from the last transformer layer as a second return value of shape (B, N, N).

False

Returns:

Name Type Description
encoding torch.Tensor — shape ``(B, d_model)``

Mean-pooled sequence encoding.

attn_weights torch.Tensor — shape ``(B, N, N)``

Averaged attention weights from the last layer. Only returned when return_attention=True.

Source code in models/entity_encoder.py
def forward(
    self,
    tokens: torch.Tensor,
    pad_mask: Optional[torch.Tensor] = None,
    return_attention: bool = False,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
    """Encode a batch of entity sequences.

    Parameters
    ----------
    tokens:
        Entity token tensor of shape ``(B, N, token_dim)``.
    pad_mask:
        Boolean mask of shape ``(B, N)``.  ``True`` marks padded
        (ignored) positions.  If ``None``, all positions are attended.
    return_attention:
        When ``True``, also return the averaged attention weights from
        the *last* transformer layer as a second return value of shape
        ``(B, N, N)``.

    Returns
    -------
    encoding : torch.Tensor — shape ``(B, d_model)``
        Mean-pooled sequence encoding.
    attn_weights : torch.Tensor — shape ``(B, N, N)``
        Averaged attention weights from the last layer.
        Only returned when ``return_attention=True``.
    """
    B, N, _ = tokens.shape

    # Token embedding
    x = self.token_embed(tokens)  # (B, N, d_model)

    # Optional 2-D Fourier positional encoding from (x, y) position fields
    if self.use_spatial_pe:
        xy = tokens[..., _SLICE_POSITION]  # (B, N, 2)
        x = x + self.spatial_pe(xy)

    if return_attention:
        if len(self.transformer.layers) == 0:
            # No transformer layers: fall back to the standard path and
            # return a zero attention matrix as a sensible default.
            x = self.transformer(x, src_key_padding_mask=pad_mask)
            attn_weights = torch.zeros(B, N, N, device=x.device, dtype=x.dtype)
        else:
            # Run layers manually to extract attention from the last layer
            # NOTE(review): this unroll re-implements
            # nn.TransformerEncoderLayer's internal ordering for both the
            # pre-norm (norm_first=True) and post-norm variants.  It must
            # stay in sync with the installed PyTorch version — re-verify
            # on upgrade.
            for i, layer in enumerate(self.transformer.layers):
                if i < len(self.transformer.layers) - 1:
                    x = layer(x, src_key_padding_mask=pad_mask)
                else:
                    # Last layer: extract attention weights
                    x_norm = layer.norm1(x) if layer.norm_first else x
                    # need_weights=True disables the fused attention fast
                    # path; average_attn_weights=True averages over heads,
                    # yielding (B, N, N) instead of (B, n_heads, N, N).
                    attn_out, attn_weights = layer.self_attn(
                        x_norm, x_norm, x_norm,
                        key_padding_mask=pad_mask,
                        need_weights=True,
                        average_attn_weights=True,
                    )
                    # Complete the residual path
                    if layer.norm_first:
                        # Pre-norm: norm applied before each sublayer,
                        # residual adds the raw sublayer output.
                        x = x + layer.dropout1(attn_out)
                        x = x + layer.dropout2(
                            layer.linear2(
                                layer.dropout(layer.activation(layer.linear1(layer.norm2(x))))
                            )
                        )
                    else:
                        # Post-norm: normalize after each residual add.
                        x = layer.norm1(x + layer.dropout1(attn_out))
                        x = layer.norm2(
                            x + layer.dropout2(
                                layer.linear2(layer.dropout(layer.activation(layer.linear1(x))))
                            )
                        )
    else:
        x = self.transformer(x, src_key_padding_mask=pad_mask)
        attn_weights = None

    # Mean-pool over non-padded positions
    if pad_mask is not None:
        # Invert mask: True = keep
        keep = ~pad_mask  # (B, N)
        # Avoid division by zero if all positions are padded
        n_valid = keep.float().sum(dim=1, keepdim=True).clamp(min=1.0)  # (B, 1); clamp avoids ÷0 if all padded
        x = (x * keep.unsqueeze(-1).float()).sum(dim=1) / n_valid  # (B, d_model)
    else:
        x = x.mean(dim=1)  # (B, d_model)

    encoding = self.out_proj(x)  # (B, d_model)

    if return_attention:
        return encoding, attn_weights  # type: ignore[return-value]
    return encoding

make_padding_mask(n_valid, max_n) staticmethod

Create a padding mask from per-sample entity counts.

Parameters:

Name Type Description Default
n_valid Tensor

Integer tensor of shape (B,) with the number of real (non-padded) entities in each sample.

required
max_n int

Total number of positions in the padded sequence.

required

Returns:

Name Type Description
pad_mask torch.BoolTensor — shape ``(B, max_n)``

True where the position is padded (ignored).

Source code in models/entity_encoder.py
@staticmethod
def make_padding_mask(n_valid: torch.Tensor, max_n: int) -> torch.Tensor:
    """Create a padding mask from per-sample entity counts.

    Parameters
    ----------
    n_valid:
        Integer tensor of shape ``(B,)`` with the number of real
        (non-padded) entities in each sample.
    max_n:
        Total number of positions in the padded sequence.

    Returns
    -------
    pad_mask : torch.BoolTensor — shape ``(B, max_n)``
        ``True`` where the position is padded (ignored).
    """
    B = n_valid.size(0)
    idx = torch.arange(max_n, device=n_valid.device).unsqueeze(0)  # (1, max_n)
    return idx >= n_valid.unsqueeze(1)  # (B, max_n)

models.entity_encoder.EntityActorCriticPolicy

Bases: Module

Actor-critic policy that uses an :class:EntityEncoder as the backbone.

Both the actor and the centralized critic share a single entity encoder (weight-sharing is optional and controlled by shared_encoder).

Actor head

Takes the pooled entity encoding (B, d_model) and produces a diagonal Gaussian action distribution.

Critic head

Takes the pooled encoding of the global entity sequence (all units from both teams) and produces a scalar value estimate.

Parameters:

Name Type Description Default
token_dim int

Entity token dimensionality. Defaults to ENTITY_TOKEN_DIM.

ENTITY_TOKEN_DIM
action_dim int

Continuous action space dimensionality.

3
d_model int

Transformer internal dimension.

64
n_heads int

Number of attention heads.

4
n_layers int

Number of transformer encoder layers.

2
actor_hidden_sizes Tuple[int, ...]

MLP hidden sizes applied on top of the entity encoding for the actor.

(128, 64)
critic_hidden_sizes Tuple[int, ...]

MLP hidden sizes applied on top of the entity encoding for the critic.

(128, 64)
shared_encoder bool

When True (default), actor and critic share the same :class:EntityEncoder weights. When False they each have an independent encoder.

True
dropout float

Dropout probability in the transformer layers.

0.0
use_spatial_pe bool

Enable 2-D Fourier positional encoding.

True
Source code in models/entity_encoder.py
class EntityActorCriticPolicy(nn.Module):
    """Actor-critic policy whose backbone is an :class:`EntityEncoder`.

    The actor consumes the pooled encoding of an entity token sequence and
    produces a diagonal Gaussian action distribution; the centralized critic
    consumes the pooled encoding of the **global** entity sequence (all units
    from both teams) and produces a scalar value.  With
    ``shared_encoder=True`` (the default) a single encoder instance serves
    both heads; otherwise each head owns an independent encoder.

    Parameters
    ----------
    token_dim:
        Entity token dimensionality.  Defaults to ``ENTITY_TOKEN_DIM``.
    action_dim:
        Continuous action space dimensionality.
    d_model:
        Transformer internal dimension.
    n_heads:
        Number of attention heads.
    n_layers:
        Number of transformer encoder layers.
    actor_hidden_sizes:
        MLP hidden sizes applied on top of the entity encoding for the actor.
    critic_hidden_sizes:
        MLP hidden sizes applied on top of the entity encoding for the critic.
    shared_encoder:
        When ``True`` (default), actor and critic share the same
        :class:`EntityEncoder` weights; otherwise each has its own.
    dropout:
        Dropout probability in the transformer layers.
    use_spatial_pe:
        Enable 2-D Fourier positional encoding.
    """

    def __init__(
        self,
        token_dim: int = ENTITY_TOKEN_DIM,
        action_dim: int = 3,
        d_model: int = 64,
        n_heads: int = 4,
        n_layers: int = 2,
        actor_hidden_sizes: Tuple[int, ...] = (128, 64),
        critic_hidden_sizes: Tuple[int, ...] = (128, 64),
        shared_encoder: bool = True,
        dropout: float = 0.0,
        use_spatial_pe: bool = True,
    ) -> None:
        super().__init__()
        self.token_dim = token_dim
        self.action_dim = action_dim
        self.d_model = d_model
        self.shared_encoder = shared_encoder

        encoder_kwargs = dict(
            token_dim=token_dim,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            dropout=dropout,
            use_spatial_pe=use_spatial_pe,
        )
        self.actor_encoder = EntityEncoder(**encoder_kwargs)
        # Weight sharing: the critic aliases the actor's encoder module.
        self.critic_encoder = (
            self.actor_encoder
            if shared_encoder
            else EntityEncoder(**encoder_kwargs)
        )

        # Heads on top of the pooled (B, d_model) encoding.
        self.actor_head = _build_mlp(d_model, actor_hidden_sizes, action_dim)
        self.log_std = nn.Parameter(torch.zeros(action_dim))  # state-independent
        self.critic_head = _build_mlp(d_model, critic_hidden_sizes, 1)

    # ------------------------------------------------------------------
    # Actor helpers
    # ------------------------------------------------------------------

    def get_distribution(
        self,
        tokens: torch.Tensor,
        pad_mask: Optional[torch.Tensor] = None,
    ) -> Normal:
        """Return the diagonal Gaussian action distribution for *tokens*.

        Parameters
        ----------
        tokens:
            Shape ``(B, N, token_dim)``.
        pad_mask:
            Boolean padding mask of shape ``(B, N)``.

        Returns
        -------
        dist : :class:`~torch.distributions.Normal`
        """
        pooled = self.actor_encoder(tokens, pad_mask)  # (B, d_model)
        action_mean = self.actor_head(pooled)          # (B, action_dim)
        action_std = torch.exp(self.log_std).expand_as(action_mean)
        return Normal(action_mean, action_std)

    @torch.no_grad()
    def act(
        self,
        tokens: torch.Tensor,
        pad_mask: Optional[torch.Tensor] = None,
        deterministic: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample actions (or take the distribution mean when *deterministic*).

        Accepts a ``(B, N, token_dim)`` batch or a single ``(N, token_dim)``
        sample; a batch axis of size one is added in the latter case.

        Returns
        -------
        actions   : torch.Tensor — shape ``(B, action_dim)``
        log_probs : torch.Tensor — shape ``(B,)``
        """
        if tokens.dim() == 2:
            # Unbatched input: prepend a batch axis.
            tokens = tokens[None]
            pad_mask = None if pad_mask is None else pad_mask[None]
        dist = self.get_distribution(tokens, pad_mask)
        chosen = dist.mean if deterministic else dist.rsample()
        return chosen, dist.log_prob(chosen).sum(dim=-1)

    def evaluate_actions(
        self,
        tokens: torch.Tensor,
        actions: torch.Tensor,
        pad_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(log_probs, entropy)``, each of shape ``(B,)``.

        Action dimensions are independent, so the joint log-probability and
        entropy are sums over the last axis.
        """
        dist = self.get_distribution(tokens, pad_mask)
        return (
            dist.log_prob(actions).sum(dim=-1),
            dist.entropy().sum(dim=-1),
        )

    # ------------------------------------------------------------------
    # Critic helpers
    # ------------------------------------------------------------------

    def get_value(
        self,
        tokens: torch.Tensor,
        pad_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return value estimates ``(B,)`` for global entity sequences."""
        pooled = self.critic_encoder(tokens, pad_mask)  # (B, d_model)
        return self.critic_head(pooled).squeeze(-1)     # (B,)

    # ------------------------------------------------------------------
    # Utility
    # ------------------------------------------------------------------

    def parameter_count(self) -> dict[str, int]:
        """Return a dict with actor, critic, and total parameter counts."""
        def numel_of(module: nn.Module) -> int:
            return sum(p.numel() for p in module.parameters())

        actor = (
            numel_of(self.actor_encoder)
            + numel_of(self.actor_head)
            + self.log_std.numel()
        )
        critic = numel_of(self.critic_head)
        if not self.shared_encoder:
            # A private critic encoder counts toward the critic only.
            critic += numel_of(self.critic_encoder)
        return {"actor": actor, "critic": critic, "total": actor + critic}

act(tokens, pad_mask=None, deterministic=False)

Sample (or select deterministically) actions for a batch.

Parameters:

Name Type Description Default
tokens Tensor

Shape (B, N, token_dim) or (N, token_dim) for a single sample (batch dimension will be added automatically).

required
pad_mask Optional[Tensor]

Padding mask of shape (B, N) or (N,) for single sample.

None
deterministic bool

Return the distribution mean instead of sampling.

False

Returns:

Name Type Description
actions torch.Tensor — shape ``(B, action_dim)``
log_probs torch.Tensor — shape ``(B,)``
Source code in models/entity_encoder.py
@torch.no_grad()
def act(
    self,
    tokens: torch.Tensor,
    pad_mask: Optional[torch.Tensor] = None,
    deterministic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample actions, or take the distribution mean when *deterministic*.

    Parameters
    ----------
    tokens:
        Shape ``(B, N, token_dim)``, or ``(N, token_dim)`` for a single
        sample (a batch axis of size one is added automatically).
    pad_mask:
        Padding mask of shape ``(B, N)``, or ``(N,)`` for a single sample.
    deterministic:
        Return the distribution mean instead of sampling.

    Returns
    -------
    actions   : torch.Tensor — shape ``(B, action_dim)``
    log_probs : torch.Tensor — shape ``(B,)``
    """
    if tokens.dim() == 2:
        # Unbatched input: prepend a batch axis.
        tokens = tokens[None]
        pad_mask = None if pad_mask is None else pad_mask[None]
    dist = self.get_distribution(tokens, pad_mask)
    if deterministic:
        chosen = dist.mean
    else:
        chosen = dist.rsample()
    return chosen, dist.log_prob(chosen).sum(dim=-1)

evaluate_actions(tokens, actions, pad_mask=None)

Evaluate log-probs and entropy for given token sequences and actions.

Parameters:

Name Type Description Default
tokens Tensor

Shape (B, N, token_dim).

required
actions Tensor

Shape (B, action_dim).

required
pad_mask Optional[Tensor]

Padding mask of shape (B, N).

None

Returns:

Name Type Description
log_probs torch.Tensor — shape ``(B,)``
entropy torch.Tensor — shape ``(B,)``
Source code in models/entity_encoder.py
def evaluate_actions(
    self,
    tokens: torch.Tensor,
    actions: torch.Tensor,
    pad_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Evaluate log-probabilities and entropy of *actions* under the policy.

    Parameters
    ----------
    tokens:
        Shape ``(B, N, token_dim)``.
    actions:
        Shape ``(B, action_dim)``.
    pad_mask:
        Padding mask of shape ``(B, N)``.

    Returns
    -------
    log_probs : torch.Tensor — shape ``(B,)``
    entropy   : torch.Tensor — shape ``(B,)``
    """
    dist = self.get_distribution(tokens, pad_mask)
    # Independent action dims: joint quantities are sums over the last axis.
    joint_log_prob = dist.log_prob(actions).sum(dim=-1)
    joint_entropy = dist.entropy().sum(dim=-1)
    return joint_log_prob, joint_entropy

get_distribution(tokens, pad_mask=None)

Compute the action distribution for a batch of entity sequences.

Parameters:

Name Type Description Default
tokens Tensor

Shape (B, N, token_dim).

required
pad_mask Optional[Tensor]

Boolean padding mask of shape (B, N).

None

Returns:

Name Type Description
dist :class:`~torch.distributions.Normal`
Source code in models/entity_encoder.py
def get_distribution(
    self,
    tokens: torch.Tensor,
    pad_mask: Optional[torch.Tensor] = None,
) -> Normal:
    """Build the diagonal Gaussian action distribution for *tokens*.

    Parameters
    ----------
    tokens:
        Shape ``(B, N, token_dim)``.
    pad_mask:
        Boolean padding mask of shape ``(B, N)``.

    Returns
    -------
    dist : :class:`~torch.distributions.Normal`
    """
    pooled = self.actor_encoder(tokens, pad_mask)  # (B, d_model)
    action_mean = self.actor_head(pooled)          # (B, action_dim)
    # State-independent log-std parameter, broadcast to the batch shape.
    action_std = torch.exp(self.log_std).expand_as(action_mean)
    return Normal(action_mean, action_std)

get_value(tokens, pad_mask=None)

Return value estimates for a batch of global entity sequences.

Parameters:

Name Type Description Default
tokens Tensor

Global entity sequences of shape (B, N, token_dim).

required
pad_mask Optional[Tensor]

Padding mask of shape (B, N).

None

Returns:

Name Type Description
values torch.Tensor — shape ``(B,)``
Source code in models/entity_encoder.py
def get_value(
    self,
    tokens: torch.Tensor,
    pad_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Estimate state values for a batch of global entity sequences.

    Parameters
    ----------
    tokens:
        Global entity sequences of shape ``(B, N, token_dim)``.
    pad_mask:
        Padding mask of shape ``(B, N)``.

    Returns
    -------
    values : torch.Tensor — shape ``(B,)``
    """
    pooled = self.critic_encoder(tokens, pad_mask)  # (B, d_model)
    values = self.critic_head(pooled)               # (B, 1)
    return values.squeeze(-1)

parameter_count()

Return a dict with actor and critic parameter counts.

Source code in models/entity_encoder.py
def parameter_count(self) -> dict[str, int]:
    """Return a dict with actor, critic, and total parameter counts."""
    def numel_of(module) -> int:
        # Total element count over all parameters of a module.
        return sum(p.numel() for p in module.parameters())

    actor = (
        numel_of(self.actor_encoder)
        + numel_of(self.actor_head)
        + self.log_std.numel()
    )
    critic = numel_of(self.critic_head)
    if not self.shared_encoder:
        # A private critic encoder counts toward the critic only; a shared
        # one is already attributed to the actor.
        critic += numel_of(self.critic_encoder)
    return {"actor": actor, "critic": critic, "total": actor + critic}

models.entity_encoder.SpatialPositionalEncoding

Bases: Module

Additive 2-D Fourier positional encoding for entity (x, y) positions.

Computes:

.. code-block:: text

PE(x, y) = concat([sin(2π k x), cos(2π k x),
                   sin(2π k y), cos(2π k y)]  for k in 1..n_freqs)

and projects the resulting 4 * n_freqs-dimensional vector to d_model via a single linear layer.

Parameters:

Name Type Description Default
d_model int

Output dimension (must match the transformer d_model).

required
n_freqs int

Number of frequency bands per axis. Defaults to 8.

8
Source code in models/entity_encoder.py
class SpatialPositionalEncoding(nn.Module):
    """Additive 2-D Fourier positional encoding for entity (x, y) positions.

    For each frequency ``k`` in ``1..n_freqs`` the features
    ``sin(2π k x), cos(2π k x), sin(2π k y), cos(2π k y)`` are computed and
    the concatenated ``4 * n_freqs``-dimensional vector is projected to
    ``d_model`` through a single linear layer.

    Parameters
    ----------
    d_model:
        Output dimension (must match the transformer d_model).
    n_freqs:
        Number of frequency bands per axis.  Defaults to ``8``.
    """

    def __init__(self, d_model: int, n_freqs: int = 8) -> None:
        super().__init__()
        self.n_freqs = n_freqs
        # Integer frequency ladder 1..n_freqs, registered as a buffer so it
        # moves with the module across devices.
        self.register_buffer(
            "freqs", torch.arange(1, n_freqs + 1, dtype=torch.float32)
        )
        self.proj = nn.Linear(4 * n_freqs, d_model)

    def forward(self, xy: torch.Tensor) -> torch.Tensor:
        """Compute positional embeddings for a batch of (x, y) pairs.

        Parameters
        ----------
        xy:
            Positions of shape ``(..., 2)`` in ``[0, 1]``.

        Returns
        -------
        pe : torch.Tensor — shape ``(..., d_model)``
        """
        two_pi = 2.0 * math.pi
        angles_x = two_pi * xy[..., 0:1] * self.freqs  # (..., n_freqs)
        angles_y = two_pi * xy[..., 1:2] * self.freqs  # (..., n_freqs)
        # Order matters: [sin x | cos x | sin y | cos y] → (..., 4*n_freqs)
        fourier = torch.cat(
            (angles_x.sin(), angles_x.cos(), angles_y.sin(), angles_y.cos()),
            dim=-1,
        )
        return self.proj(fourier)

forward(xy)

Compute positional embeddings for a batch of (x, y) pairs.

Parameters:

Name Type Description Default
xy Tensor

Positions of shape (..., 2) in [0, 1].

required

Returns:

Name Type Description
pe torch.Tensor — shape ``(..., d_model)``
Source code in models/entity_encoder.py
def forward(self, xy: torch.Tensor) -> torch.Tensor:
    """Compute positional embeddings for a batch of (x, y) pairs.

    Parameters
    ----------
    xy:
        Positions of shape ``(..., 2)`` in ``[0, 1]``.

    Returns
    -------
    pe : torch.Tensor — shape ``(..., d_model)``
    """
    two_pi = 2.0 * math.pi
    angles_x = two_pi * xy[..., 0:1] * self.freqs  # (..., n_freqs)
    angles_y = two_pi * xy[..., 1:2] * self.freqs  # (..., n_freqs)
    # Order matters: [sin x | cos x | sin y | cos y] → (..., 4*n_freqs)
    fourier = torch.cat(
        (angles_x.sin(), angles_x.cos(), angles_y.sin(), angles_y.cos()),
        dim=-1,
    )
    return self.proj(fourier)

Recurrent policy (LSTM)

models.recurrent_policy.RecurrentActorCriticPolicy

Bases: Module

Actor-critic policy with LSTM temporal memory.

Both actor and critic run through the same :class:RecurrentEntityEncoder (weight-sharing optional). Hidden states are passed in/out explicitly so the caller controls episode boundaries and checkpointing.

Parameters:

Name Type Description Default
token_dim int

Entity token dimensionality.

ENTITY_TOKEN_DIM
action_dim int

Continuous action space dimensionality.

3
d_model int

Transformer internal dimension.

64
n_heads int

Attention heads in the entity encoder.

4
n_layers int

Transformer encoder layers.

2
lstm_hidden_size int

LSTM hidden state dimensionality.

128
lstm_num_layers int

Number of stacked LSTM layers.

1
actor_hidden_sizes Tuple[int, ...]

MLP hidden sizes on top of the LSTM output for the actor head.

(128, 64)
critic_hidden_sizes Tuple[int, ...]

MLP hidden sizes on top of the LSTM output for the critic head.

(128, 64)
shared_encoder bool

When True (default), actor and critic share the same :class:RecurrentEntityEncoder weights.

True
dropout float

Dropout probability.

0.0
use_spatial_pe bool

Enable 2-D Fourier positional encoding.

True
Source code in models/recurrent_policy.py
class RecurrentActorCriticPolicy(nn.Module):
    """Actor-critic policy with LSTM temporal memory.

    Both actor and critic run through the same :class:`RecurrentEntityEncoder`
    (weight-sharing optional).  Hidden states are passed in/out explicitly so
    the caller controls episode boundaries and checkpointing.

    Parameters
    ----------
    token_dim:
        Entity token dimensionality.
    action_dim:
        Continuous action space dimensionality.
    d_model:
        Transformer internal dimension.
    n_heads:
        Attention heads in the entity encoder.
    n_layers:
        Transformer encoder layers.
    lstm_hidden_size:
        LSTM hidden state dimensionality.
    lstm_num_layers:
        Number of stacked LSTM layers.
    actor_hidden_sizes:
        MLP hidden sizes on top of the LSTM output for the actor head.
    critic_hidden_sizes:
        MLP hidden sizes on top of the LSTM output for the critic head.
    shared_encoder:
        When ``True`` (default), actor and critic share the same
        :class:`RecurrentEntityEncoder` weights.
    dropout:
        Dropout probability.
    use_spatial_pe:
        Enable 2-D Fourier positional encoding.
    """

    def __init__(
        self,
        token_dim: int = ENTITY_TOKEN_DIM,
        action_dim: int = 3,
        d_model: int = 64,
        n_heads: int = 4,
        n_layers: int = 2,
        lstm_hidden_size: int = 128,
        lstm_num_layers: int = 1,
        actor_hidden_sizes: Tuple[int, ...] = (128, 64),
        critic_hidden_sizes: Tuple[int, ...] = (128, 64),
        shared_encoder: bool = True,
        dropout: float = 0.0,
        use_spatial_pe: bool = True,
    ) -> None:
        super().__init__()
        self.token_dim = token_dim
        self.action_dim = action_dim
        self.d_model = d_model
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_num_layers = lstm_num_layers
        self.shared_encoder = shared_encoder

        encoder_kwargs = dict(
            token_dim=token_dim,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            lstm_hidden_size=lstm_hidden_size,
            lstm_num_layers=lstm_num_layers,
            dropout=dropout,
            use_spatial_pe=use_spatial_pe,
        )
        # One recurrent entity encoder for the actor; the critic either
        # aliases it (weight sharing) or receives its own independent copy.
        self.actor_encoder = RecurrentEntityEncoder(**encoder_kwargs)
        self.critic_encoder = (
            self.actor_encoder
            if shared_encoder
            else RecurrentEntityEncoder(**encoder_kwargs)
        )

        # Heads consume the LSTM output of shape (B, lstm_hidden_size).
        self.actor_head = _build_mlp(lstm_hidden_size, actor_hidden_sizes, action_dim)
        self.log_std = nn.Parameter(torch.zeros(action_dim))  # state-independent
        self.critic_head = _build_mlp(lstm_hidden_size, critic_hidden_sizes, 1)

    # ------------------------------------------------------------------
    # Hidden-state utilities
    # ------------------------------------------------------------------

    def initial_state(
        self,
        batch_size: int = 1,
        device: Optional[torch.device] = None,
    ) -> LSTMHiddenState:
        """Return an all-zero LSTM hidden state.

        Call this at every episode boundary so temporal memory never
        leaks across episodes.

        Parameters
        ----------
        batch_size:
            Number of parallel environments / samples.
        device:
            Target device.
        """
        return LSTMHiddenState.zeros(
            self.lstm_num_layers,
            self.lstm_hidden_size,
            batch_size,
            device,
        )

    # ------------------------------------------------------------------
    # Single-step inference
    # ------------------------------------------------------------------

    def get_distribution(
        self,
        tokens: torch.Tensor,
        hx: LSTMHiddenState,
        pad_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[Normal, LSTMHiddenState]:
        """Compute the action distribution for a single timestep.

        Parameters
        ----------
        tokens:
            Shape ``(B, N, token_dim)``, or ``(N, token_dim)`` unbatched
            (a batch axis of size one is added automatically).
        hx:
            Current LSTM hidden state.
        pad_mask:
            Boolean padding mask ``(B, N)``, or ``(N,)`` unbatched.

        Returns
        -------
        dist : :class:`~torch.distributions.Normal`
        new_hx : :class:`LSTMHiddenState`
        """
        if tokens.dim() == 2:
            # Unbatched call: prepend a batch axis.
            tokens = tokens[None]
            if pad_mask is not None:
                pad_mask = pad_mask[None]

        lstm_out, next_hx = self.actor_encoder(tokens, hx, pad_mask)  # (B, hidden)
        action_mean = self.actor_head(lstm_out)  # (B, action_dim)
        action_std = torch.exp(self.log_std).expand_as(action_mean)
        return Normal(action_mean, action_std), next_hx

    @torch.no_grad()
    def act(
        self,
        tokens: torch.Tensor,
        hx: LSTMHiddenState,
        pad_mask: Optional[torch.Tensor] = None,
        deterministic: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, LSTMHiddenState]:
        """Sample (or select deterministically) an action for one step.

        Parameters
        ----------
        tokens:
            Shape ``(B, N, token_dim)``, or ``(N, token_dim)`` for a
            single sample.
        hx:
            Current LSTM hidden state.
        pad_mask:
            Padding mask.
        deterministic:
            Return the distribution mean instead of sampling.

        Returns
        -------
        actions   : torch.Tensor — shape ``(B, action_dim)``
        log_probs : torch.Tensor — shape ``(B,)``
        new_hx    : :class:`LSTMHiddenState` — updated hidden state
        """
        dist, next_hx = self.get_distribution(tokens, hx, pad_mask)
        if deterministic:
            chosen = dist.mean
        else:
            chosen = dist.rsample()
        return chosen, dist.log_prob(chosen).sum(dim=-1), next_hx

    @torch.no_grad()
    def get_value(
        self,
        tokens: torch.Tensor,
        hx: LSTMHiddenState,
        pad_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, LSTMHiddenState]:
        """Compute value estimates for a single timestep.

        Parameters
        ----------
        tokens:
            Shape ``(B, N, token_dim)``, or ``(N, token_dim)`` for a
            single sample (a batch axis is added automatically).
        hx:
            Current LSTM hidden state.
        pad_mask:
            Padding mask.

        Returns
        -------
        values  : torch.Tensor — shape ``(B,)``
        new_hx  : :class:`LSTMHiddenState`
        """
        if tokens.dim() == 2:
            # Unbatched call: prepend a batch axis.
            tokens = tokens[None]
            if pad_mask is not None:
                pad_mask = pad_mask[None]

        features, next_hx = self.critic_encoder(tokens, hx, pad_mask)  # (B, hidden)
        return self.critic_head(features).squeeze(-1), next_hx

    # ------------------------------------------------------------------
    # Sequence-level training (BPTT)
    # ------------------------------------------------------------------

    def evaluate_actions(
        self,
        tokens_seq: torch.Tensor,
        hx: LSTMHiddenState,
        actions_seq: torch.Tensor,
        pad_mask_seq: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate log-probs, entropy and values for an action sequence.

        Used during the PPO update to compute the ratio ``π_new / π_old``
        with back-propagation through time.

        Parameters
        ----------
        tokens_seq:
            Entity token sequences of shape ``(B, T, N, token_dim)``.
        hx:
            Initial LSTM hidden state at the start of the sequence.
        actions_seq:
            Actions to evaluate, shape ``(B, T, action_dim)``.
        pad_mask_seq:
            Padding mask of shape ``(B, T, N)`` or ``None``.

        Returns
        -------
        log_probs : torch.Tensor — shape ``(B, T)``
        entropy   : torch.Tensor — shape ``(B, T)``
        values    : torch.Tensor — shape ``(B, T)``
        """
        B, T = tokens_seq.shape[0], tokens_seq.shape[1]

        # Actor pass over the whole sequence
        actor_out, _ = self.actor_encoder.forward_sequence(
            tokens_seq, hx, pad_mask_seq
        )  # (B, T, hidden)
        dist = Normal(
            self.actor_head(actor_out),                       # (B, T, action_dim)
            self.log_std.exp().view(1, 1, -1).expand(B, T, -1),
        )
        log_probs = dist.log_prob(actions_seq).sum(dim=-1)  # (B, T)
        entropy = dist.entropy().sum(dim=-1)                # (B, T)

        # Critic: reuse the actor features when encoders are shared,
        # otherwise run an independent pass.
        if self.shared_encoder:
            critic_out = actor_out
        else:
            critic_out, _ = self.critic_encoder.forward_sequence(
                tokens_seq, hx, pad_mask_seq
            )  # (B, T, hidden)
        values = self.critic_head(critic_out).squeeze(-1)  # (B, T)

        return log_probs, entropy, values

    # ------------------------------------------------------------------
    # Checkpointing
    # ------------------------------------------------------------------

    def save_checkpoint(self, path: str | Path) -> None:
        """Write the model's ``state_dict`` to *path* via ``torch.save``.

        Parameters
        ----------
        path:
            Destination file.  Missing parent directories are created.
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.state_dict(), target)

    @classmethod
    def load_checkpoint(
        cls,
        path: str | Path,
        device: Optional[torch.device] = None,
        **kwargs,  # forwarded to __init__
    ) -> "RecurrentActorCriticPolicy":
        """Rebuild a policy from a file written by :meth:`save_checkpoint`.

        Parameters
        ----------
        path:
            Checkpoint file location.
        device:
            Optional device to map the weights to and move the policy onto.
        **kwargs:
            Constructor arguments — must match those used when the checkpoint
            was saved, otherwise ``load_state_dict`` will fail.

        Returns
        -------
        policy : :class:`RecurrentActorCriticPolicy`
        """
        policy = cls(**kwargs)
        # ``weights_only=True`` restricts unpickling to tensor payloads.
        state = torch.load(path, map_location=device, weights_only=True)
        policy.load_state_dict(state)
        # nn.Module.to returns self, so the ternary preserves behaviour.
        return policy if device is None else policy.to(device)

    # ------------------------------------------------------------------
    # Utility
    # ------------------------------------------------------------------

    def parameter_count(self) -> Dict[str, int]:
        """Return actor / critic / total trainable-parameter counts."""

        def _numel(module) -> int:
            # Total element count across all of a module's parameters.
            return sum(p.numel() for p in module.parameters())

        actor = (
            _numel(self.actor_encoder)
            + _numel(self.actor_head)
            + self.log_std.numel()
        )

        critic = _numel(self.critic_head)
        if not self.shared_encoder:
            # Only a separate critic encoder contributes its own weights.
            critic += _numel(self.critic_encoder)

        return {"actor": actor, "critic": critic, "total": actor + critic}

act(tokens, hx, pad_mask=None, deterministic=False)

Sample (or select deterministically) an action for a single step.

Parameters:

Name Type Description Default
tokens Tensor

Shape (B, N, token_dim) or (N, token_dim) for single sample.

required
hx LSTMHiddenState

Current LSTM hidden state.

required
pad_mask Optional[Tensor]

Padding mask.

None
deterministic bool

Return the distribution mean instead of sampling.

False

Returns:

Name Type Description
actions torch.Tensor — shape ``(B, action_dim)``
log_probs torch.Tensor — shape ``(B,)``
new_hx :class:`LSTMHiddenState` — updated hidden state
Source code in models/recurrent_policy.py
@torch.no_grad()
def act(
    self,
    tokens: torch.Tensor,
    hx: LSTMHiddenState,
    pad_mask: Optional[torch.Tensor] = None,
    deterministic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, LSTMHiddenState]:
    """Select an action for one environment step.

    Parameters
    ----------
    tokens:
        Shape ``(B, N, token_dim)`` or ``(N, token_dim)`` for single sample.
    hx:
        Current LSTM hidden state.
    pad_mask:
        Padding mask.
    deterministic:
        If ``True``, return the distribution mean instead of sampling.

    Returns
    -------
    actions   : torch.Tensor — shape ``(B, action_dim)``
    log_probs : torch.Tensor — shape ``(B,)``
    new_hx    : :class:`LSTMHiddenState` — updated hidden state
    """
    dist, new_hx = self.get_distribution(tokens, hx, pad_mask)
    if deterministic:
        actions = dist.mean
    else:
        actions = dist.rsample()
    return actions, dist.log_prob(actions).sum(dim=-1), new_hx

evaluate_actions(tokens_seq, hx, actions_seq, pad_mask_seq=None)

Evaluate log-probs, entropy and values for a sequence of actions.

Used during PPO update to compute the ratio π_new / π_old.

Parameters:

Name Type Description Default
tokens_seq Tensor

Entity token sequences of shape (B, T, N, token_dim).

required
hx LSTMHiddenState

Initial LSTM hidden state at the start of the sequence.

required
actions_seq Tensor

Actions to evaluate, shape (B, T, action_dim).

required
pad_mask_seq Optional[Tensor]

Padding mask of shape (B, T, N) or None.

None

Returns:

Name Type Description
log_probs torch.Tensor — shape ``(B, T)``
entropy torch.Tensor — shape ``(B, T)``
values torch.Tensor — shape ``(B, T)``
Source code in models/recurrent_policy.py
def evaluate_actions(
    self,
    tokens_seq: torch.Tensor,
    hx: LSTMHiddenState,
    actions_seq: torch.Tensor,
    pad_mask_seq: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Score a stored action sequence under the current policy.

    Supplies the per-step log-probabilities, entropies and values that the
    PPO update needs to form the ratio ``π_new / π_old``.

    Parameters
    ----------
    tokens_seq:
        Entity token sequences of shape ``(B, T, N, token_dim)``.
    hx:
        Initial LSTM hidden state at the start of the sequence.
    actions_seq:
        Actions to evaluate, shape ``(B, T, action_dim)``.
    pad_mask_seq:
        Padding mask of shape ``(B, T, N)`` or ``None``.

    Returns
    -------
    log_probs : torch.Tensor — shape ``(B, T)``
    entropy   : torch.Tensor — shape ``(B, T)``
    values    : torch.Tensor — shape ``(B, T)``
    """
    B, T = tokens_seq.shape[0], tokens_seq.shape[1]

    # Actor LSTM over the whole sequence (truncated BPTT).
    actor_out, _ = self.actor_encoder.forward_sequence(
        tokens_seq, hx, pad_mask_seq
    )  # (B, T, hidden)
    mean = self.actor_head(actor_out)  # (B, T, action_dim)
    std = self.log_std.exp().unsqueeze(0).unsqueeze(0).expand(B, T, -1)
    dist = Normal(mean, std)
    log_probs = dist.log_prob(actions_seq).sum(dim=-1)  # (B, T)
    entropy = dist.entropy().sum(dim=-1)  # (B, T)

    # With a shared encoder the critic reuses the actor features; otherwise
    # the critic encoder makes its own pass.
    if not self.shared_encoder:
        critic_out, _ = self.critic_encoder.forward_sequence(
            tokens_seq, hx, pad_mask_seq
        )  # (B, T, hidden)
    else:
        critic_out = actor_out
    values = self.critic_head(critic_out).squeeze(-1)  # (B, T)

    return log_probs, entropy, values

get_distribution(tokens, hx, pad_mask=None)

Compute action distribution for a single timestep.

Parameters:

Name Type Description Default
tokens Tensor

Shape (B, N, token_dim) or (N, token_dim) for unbatched.

required
hx LSTMHiddenState

Current LSTM hidden state.

required
pad_mask Optional[Tensor]

Boolean padding mask (B, N) or (N,) for unbatched.

None

Returns:

Name Type Description
dist :class:`~torch.distributions.Normal`
new_hx :class:`LSTMHiddenState`
Source code in models/recurrent_policy.py
def get_distribution(
    self,
    tokens: torch.Tensor,
    hx: LSTMHiddenState,
    pad_mask: Optional[torch.Tensor] = None,
) -> Tuple[Normal, LSTMHiddenState]:
    """Build the Gaussian action distribution for one timestep.

    Parameters
    ----------
    tokens:
        Shape ``(B, N, token_dim)`` or ``(N, token_dim)`` for unbatched.
    hx:
        Current LSTM hidden state.
    pad_mask:
        Boolean padding mask ``(B, N)`` or ``(N,)`` for unbatched.

    Returns
    -------
    dist : :class:`~torch.distributions.Normal`
    new_hx : :class:`LSTMHiddenState`
    """
    # Promote an unbatched sample to a batch of one.
    if tokens.dim() == 2:
        tokens = tokens.unsqueeze(0)
        pad_mask = None if pad_mask is None else pad_mask.unsqueeze(0)

    encoded, new_hx = self.actor_encoder(tokens, hx, pad_mask)  # (B, hidden)
    mean = self.actor_head(encoded)  # (B, action_dim)
    # State-independent diagonal std, broadcast to the mean's shape.
    return Normal(mean, self.log_std.exp().expand_as(mean)), new_hx

get_value(tokens, hx, pad_mask=None)

Compute value estimates for a single timestep.

Parameters:

Name Type Description Default
tokens Tensor

Shape (B, N, token_dim) or (N, token_dim) for single sample.

required
hx LSTMHiddenState

Current LSTM hidden state.

required
pad_mask Optional[Tensor]

Padding mask.

None

Returns:

Name Type Description
values torch.Tensor — shape ``(B,)``
new_hx :class:`LSTMHiddenState`
Source code in models/recurrent_policy.py
@torch.no_grad()
def get_value(
    self,
    tokens: torch.Tensor,
    hx: LSTMHiddenState,
    pad_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, LSTMHiddenState]:
    """Estimate the state value for one timestep.

    Parameters
    ----------
    tokens:
        Shape ``(B, N, token_dim)`` or ``(N, token_dim)`` for single sample.
    hx:
        Current LSTM hidden state.
    pad_mask:
        Padding mask.

    Returns
    -------
    values  : torch.Tensor — shape ``(B,)``
    new_hx  : :class:`LSTMHiddenState`
    """
    # Promote an unbatched sample to a batch of one.
    if tokens.dim() == 2:
        tokens = tokens.unsqueeze(0)
        pad_mask = None if pad_mask is None else pad_mask.unsqueeze(0)

    encoded, new_hx = self.critic_encoder(tokens, hx, pad_mask)  # (B, hidden)
    return self.critic_head(encoded).squeeze(-1), new_hx

initial_state(batch_size=1, device=None)

Return a zero-initialised LSTM hidden state.

Call at the start of each episode to reset temporal memory.

Parameters:

Name Type Description Default
batch_size int

Number of parallel environments / samples.

1
device Optional[device]

Target device.

None
Source code in models/recurrent_policy.py
def initial_state(
    self,
    batch_size: int = 1,
    device: Optional[torch.device] = None,
) -> LSTMHiddenState:
    """Create a fresh, all-zero LSTM hidden state.

    Call at the start of each episode to reset temporal memory.

    Parameters
    ----------
    batch_size:
        Number of parallel environments / samples.
    device:
        Target device.
    """
    return LSTMHiddenState.zeros(
        self.lstm_num_layers,
        self.lstm_hidden_size,
        batch_size,
        device,
    )

load_checkpoint(path, device=None, **kwargs) classmethod

Restore a policy from a checkpoint created by :meth:save_checkpoint.

Parameters:

Name Type Description Default
path str | Path

Path to the checkpoint file.

required
device Optional[device]

Device to load weights onto.

None
**kwargs

Constructor arguments — must match those used when the checkpoint was saved.

{}

Returns:

Name Type Description
policy :class:`RecurrentActorCriticPolicy`
Source code in models/recurrent_policy.py
@classmethod
def load_checkpoint(
    cls,
    path: str | Path,
    device: Optional[torch.device] = None,
    **kwargs,  # forwarded to __init__
) -> "RecurrentActorCriticPolicy":
    """Rebuild a policy from a file written by :meth:`save_checkpoint`.

    Parameters
    ----------
    path:
        Checkpoint file location.
    device:
        Optional device to map the weights to and move the policy onto.
    **kwargs:
        Constructor arguments — must match those used when the checkpoint
        was saved, otherwise ``load_state_dict`` will fail.

    Returns
    -------
    policy : :class:`RecurrentActorCriticPolicy`
    """
    policy = cls(**kwargs)
    # ``weights_only=True`` restricts unpickling to tensor payloads.
    state = torch.load(path, map_location=device, weights_only=True)
    policy.load_state_dict(state)
    # nn.Module.to returns self, so the ternary preserves behaviour.
    return policy if device is None else policy.to(device)

parameter_count()

Return a dict with actor and critic parameter counts.

Source code in models/recurrent_policy.py
def parameter_count(self) -> Dict[str, int]:
    """Return actor / critic / total trainable-parameter counts."""

    def _numel(module) -> int:
        # Total element count across all of a module's parameters.
        return sum(p.numel() for p in module.parameters())

    actor = (
        _numel(self.actor_encoder)
        + _numel(self.actor_head)
        + self.log_std.numel()
    )

    critic = _numel(self.critic_head)
    if not self.shared_encoder:
        # Only a separate critic encoder contributes its own weights.
        critic += _numel(self.critic_encoder)

    return {"actor": actor, "critic": critic, "total": actor + critic}

save_checkpoint(path)

Persist model weights to path (torch.save format).

Parameters:

Name Type Description Default
path str | Path

Destination file path. Parent directories are created if needed.

required
Source code in models/recurrent_policy.py
def save_checkpoint(self, path: str | Path) -> None:
    """Write the model's ``state_dict`` to *path* via ``torch.save``.

    Parameters
    ----------
    path:
        Destination file.  Missing parent directories are created.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    torch.save(self.state_dict(), target)

models.recurrent_policy.RecurrentEntityEncoder

Bases: Module

Entity encoder followed by a multi-layer LSTM for temporal memory.

The entity encoder reduces a variable-length set of entity tokens to a fixed-size pooled vector. The LSTM then integrates this encoding across timesteps, maintaining an internal model of unobserved unit positions.

Architecture::

tokens (B, N, token_dim)
      │
EntityEncoder  →  enc (B, d_model)
      │
nn.LSTM(d_model → hidden_size, num_layers)
      │  ← (h_t, c_t)  in / out
lstm_out (B, hidden_size)

Parameters:

Name Type Description Default
token_dim int

Entity token dimensionality.

ENTITY_TOKEN_DIM
d_model int

Transformer internal dimension (EntityEncoder output).

64
n_heads int

Attention heads for the entity encoder.

4
n_layers int

Transformer encoder layers.

2
lstm_hidden_size int

LSTM hidden state dimensionality.

128
lstm_num_layers int

Number of stacked LSTM layers.

1
dropout float

Dropout probability applied inside transformer and LSTM.

0.0
use_spatial_pe bool

Enable 2-D Fourier positional encoding on entity tokens.

True
Source code in models/recurrent_policy.py
class RecurrentEntityEncoder(nn.Module):
    """Entity-set encoder with an LSTM on top for temporal memory.

    The transformer-based entity encoder pools a variable-length set of
    entity tokens into one fixed-size vector; the LSTM then carries that
    encoding across timesteps, maintaining an internal model of unobserved
    unit positions.

    Architecture::

        tokens (B, N, token_dim)

        EntityEncoder  →  enc (B, d_model)

        nn.LSTM(d_model → hidden_size, num_layers)
              │  ← (h_t, c_t)  in / out
        lstm_out (B, hidden_size)

    Parameters
    ----------
    token_dim:
        Entity token dimensionality.
    d_model:
        Transformer internal dimension (EntityEncoder output).
    n_heads:
        Attention heads for the entity encoder.
    n_layers:
        Transformer encoder layers.
    lstm_hidden_size:
        LSTM hidden state dimensionality.
    lstm_num_layers:
        Number of stacked LSTM layers.
    dropout:
        Dropout probability applied inside transformer and LSTM.
    use_spatial_pe:
        Enable 2-D Fourier positional encoding on entity tokens.
    """

    def __init__(
        self,
        token_dim: int = ENTITY_TOKEN_DIM,
        d_model: int = 64,
        n_heads: int = 4,
        n_layers: int = 2,
        lstm_hidden_size: int = 128,
        lstm_num_layers: int = 1,
        dropout: float = 0.0,
        use_spatial_pe: bool = True,
    ) -> None:
        super().__init__()
        self.token_dim = token_dim
        self.d_model = d_model
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_num_layers = lstm_num_layers

        self.entity_encoder = EntityEncoder(
            token_dim=token_dim,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            dropout=dropout,
            use_spatial_pe=use_spatial_pe,
        )

        # nn.LSTM only applies dropout between stacked layers, so a single
        # layer gets 0.0 (avoids the PyTorch warning).
        self.lstm = nn.LSTM(
            input_size=d_model,
            hidden_size=lstm_hidden_size,
            num_layers=lstm_num_layers,
            batch_first=True,
            dropout=dropout if lstm_num_layers > 1 else 0.0,
        )

    @property
    def output_dim(self) -> int:
        """Dimensionality of the LSTM output vector."""
        return self.lstm_hidden_size

    def initial_state(
        self,
        batch_size: int = 1,
        device: Optional[torch.device] = None,
    ) -> LSTMHiddenState:
        """Return a zero-initialised hidden state for *batch_size* samples."""
        return LSTMHiddenState.zeros(
            self.lstm_num_layers, self.lstm_hidden_size, batch_size, device
        )

    def forward(
        self,
        tokens: torch.Tensor,
        hx: LSTMHiddenState,
        pad_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, LSTMHiddenState]:
        """Encode one timestep of entity tokens.

        Parameters
        ----------
        tokens:
            Shape ``(B, N, token_dim)`` — entity tokens for one timestep.
        hx:
            Current LSTM hidden state.
        pad_mask:
            Boolean padding mask of shape ``(B, N)``.

        Returns
        -------
        out : torch.Tensor — shape ``(B, lstm_hidden_size)``
            LSTM output for this timestep.
        new_hx : LSTMHiddenState
            Updated hidden and cell states.
        """
        pooled = self.entity_encoder(tokens, pad_mask)  # (B, d_model)
        # Feed the pooled vector as a length-1 sequence through the LSTM.
        step_out, (h, c) = self.lstm(pooled.unsqueeze(1), hx.as_tuple())
        return step_out.squeeze(1), LSTMHiddenState(h=h, c=c)

    def forward_sequence(
        self,
        tokens_seq: torch.Tensor,
        hx: LSTMHiddenState,
        pad_mask_seq: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, LSTMHiddenState]:
        """Encode a whole *sequence* of timesteps (BPTT during training).

        Parameters
        ----------
        tokens_seq:
            Shape ``(B, T, N, token_dim)`` — entity token sequences.
        hx:
            Initial hidden state at the start of the sequence.
        pad_mask_seq:
            Padding mask of shape ``(B, T, N)`` or ``None``.

        Returns
        -------
        out_seq : torch.Tensor — shape ``(B, T, lstm_hidden_size)``
            LSTM outputs for every timestep.
        new_hx : LSTMHiddenState
            Hidden state after the last timestep.
        """
        B, T, N, _ = tokens_seq.shape

        # The entity encoder is stateless across time, so all B*T steps can
        # be batched through it at once.
        flat_tokens = tokens_seq.reshape(B * T, N, -1)
        flat_mask = (
            None if pad_mask_seq is None else pad_mask_seq.reshape(B * T, N)
        )
        enc_seq = self.entity_encoder(flat_tokens, flat_mask).reshape(
            B, T, self.d_model
        )

        seq_out, (h, c) = self.lstm(enc_seq, hx.as_tuple())
        return seq_out, LSTMHiddenState(h=h, c=c)

output_dim property

Dimensionality of the LSTM output vector.

forward(tokens, hx, pad_mask=None)

Encode a single timestep of entity tokens.

Parameters:

Name Type Description Default
tokens Tensor

Shape (B, N, token_dim) — entity tokens for one timestep.

required
hx LSTMHiddenState

Current LSTM hidden state.

required
pad_mask Optional[Tensor]

Boolean padding mask of shape (B, N).

None

Returns:

Name Type Description
out torch.Tensor — shape ``(B, lstm_hidden_size)``

LSTM output for this timestep.

new_hx LSTMHiddenState

Updated hidden and cell states.

Source code in models/recurrent_policy.py
def forward(
    self,
    tokens: torch.Tensor,
    hx: LSTMHiddenState,
    pad_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, LSTMHiddenState]:
    """Encode one timestep of entity tokens.

    Parameters
    ----------
    tokens:
        Shape ``(B, N, token_dim)`` — entity tokens for one timestep.
    hx:
        Current LSTM hidden state.
    pad_mask:
        Boolean padding mask of shape ``(B, N)``.

    Returns
    -------
    out : torch.Tensor — shape ``(B, lstm_hidden_size)``
        LSTM output for this timestep.
    new_hx : LSTMHiddenState
        Updated hidden and cell states.
    """
    pooled = self.entity_encoder(tokens, pad_mask)  # (B, d_model)
    # Feed the pooled vector as a length-1 sequence through the LSTM.
    step_out, (h, c) = self.lstm(pooled.unsqueeze(1), hx.as_tuple())
    return step_out.squeeze(1), LSTMHiddenState(h=h, c=c)

forward_sequence(tokens_seq, hx, pad_mask_seq=None)

Encode a sequence of timesteps for BPTT during training.

Parameters:

Name Type Description Default
tokens_seq Tensor

Shape (B, T, N, token_dim) — entity token sequences.

required
hx LSTMHiddenState

Initial hidden state at the start of the sequence.

required
pad_mask_seq Optional[Tensor]

Padding mask of shape (B, T, N) or None.

None

Returns:

Name Type Description
out_seq torch.Tensor — shape ``(B, T, lstm_hidden_size)``

LSTM outputs for every timestep.

new_hx LSTMHiddenState

Hidden state after the last timestep.

Source code in models/recurrent_policy.py
def forward_sequence(
    self,
    tokens_seq: torch.Tensor,
    hx: LSTMHiddenState,
    pad_mask_seq: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, LSTMHiddenState]:
    """Encode a whole *sequence* of timesteps (BPTT during training).

    Parameters
    ----------
    tokens_seq:
        Shape ``(B, T, N, token_dim)`` — entity token sequences.
    hx:
        Initial hidden state at the start of the sequence.
    pad_mask_seq:
        Padding mask of shape ``(B, T, N)`` or ``None``.

    Returns
    -------
    out_seq : torch.Tensor — shape ``(B, T, lstm_hidden_size)``
        LSTM outputs for every timestep.
    new_hx : LSTMHiddenState
        Hidden state after the last timestep.
    """
    B, T, N, _ = tokens_seq.shape

    # The entity encoder is stateless across time, so all B*T steps can be
    # batched through it in one call.
    flat_tokens = tokens_seq.reshape(B * T, N, -1)
    flat_mask = None if pad_mask_seq is None else pad_mask_seq.reshape(B * T, N)
    enc_seq = self.entity_encoder(flat_tokens, flat_mask).reshape(
        B, T, self.d_model
    )

    seq_out, (h, c) = self.lstm(enc_seq, hx.as_tuple())
    return seq_out, LSTMHiddenState(h=h, c=c)

initial_state(batch_size=1, device=None)

Return a zero-initialised hidden state for batch_size samples.

Source code in models/recurrent_policy.py
def initial_state(
    self,
    batch_size: int = 1,
    device: Optional[torch.device] = None,
) -> LSTMHiddenState:
    """Build an all-zero hidden state sized for *batch_size* samples."""
    return LSTMHiddenState.zeros(
        self.lstm_num_layers,
        self.lstm_hidden_size,
        batch_size,
        device,
    )

models.recurrent_policy.LSTMHiddenState dataclass

Container for LSTM (h, c) states with episode-reset utilities.

Both tensors have shape (num_layers, batch, hidden_size).

Parameters:

Name Type Description Default
h Tensor

Hidden state tensor.

required
c Tensor

Cell state tensor.

required
Source code in models/recurrent_policy.py
@dataclass
class LSTMHiddenState:
    """Bundle of LSTM ``(h, c)`` tensors plus episode-reset helpers.

    Both tensors have shape ``(num_layers, batch, hidden_size)``.

    Parameters
    ----------
    h:
        Hidden state tensor.
    c:
        Cell state tensor.
    """

    h: torch.Tensor  # (num_layers, batch, hidden_size)
    c: torch.Tensor  # (num_layers, batch, hidden_size)

    def detach(self) -> "LSTMHiddenState":
        """Return a copy whose tensors are detached from the autograd graph."""
        return LSTMHiddenState(h=self.h.detach(), c=self.c.detach())

    def to(self, device: torch.device) -> "LSTMHiddenState":
        """Return a copy with both tensors moved to *device*."""
        return LSTMHiddenState(h=self.h.to(device), c=self.c.to(device))

    def reset_at(self, done_mask: torch.Tensor) -> "LSTMHiddenState":
        """Zero out hidden states for episodes that have ended.

        Parameters
        ----------
        done_mask:
            Boolean tensor of shape ``(batch,)`` — ``True`` where the episode
            ended and the hidden state should be cleared.

        Returns
        -------
        LSTMHiddenState with states zeroed at ``done_mask`` positions.
        """
        # Shape the (batch,) mask to broadcast over (num_layers, batch, hidden).
        broadcast = done_mask.to(self.h.device).unsqueeze(0).unsqueeze(-1)
        return LSTMHiddenState(
            h=self.h.masked_fill(broadcast, 0.0),
            c=self.c.masked_fill(broadcast, 0.0),
        )

    @classmethod
    def zeros(
        cls,
        num_layers: int,
        hidden_size: int,
        batch_size: int = 1,
        device: Optional[torch.device] = None,
    ) -> "LSTMHiddenState":
        """Create a zero-initialised hidden state.

        Parameters
        ----------
        num_layers:
            Number of LSTM layers.
        hidden_size:
            LSTM hidden dimension.
        batch_size:
            Batch dimension.
        device:
            Target device (defaults to CPU).
        """
        dev = torch.device("cpu") if device is None else device
        zero = torch.zeros(num_layers, batch_size, hidden_size, device=dev)
        # Clone so h and c are independent tensors.
        return cls(h=zero, c=zero.clone())

    def as_tuple(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(h, c)`` pair in the form ``nn.LSTM`` expects."""
        return self.h, self.c

as_tuple()

Return (h, c) tuple accepted by nn.LSTM.

Source code in models/recurrent_policy.py
def as_tuple(self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the ``(h, c)`` pair in the form ``nn.LSTM`` expects."""
    return self.h, self.c

detach()

Return a new :class:LSTMHiddenState with detached tensors.

Source code in models/recurrent_policy.py
def detach(self) -> "LSTMHiddenState":
    """Return a copy whose tensors are detached from the autograd graph."""
    return LSTMHiddenState(h=self.h.detach(), c=self.c.detach())

reset_at(done_mask)

Zero out hidden states for episodes that have ended.

Parameters:

Name Type Description Default
done_mask Tensor

Boolean tensor of shape (batch,) — True where the episode ended and the hidden state should be cleared.

required

Returns:

Type Description
LSTMHiddenState with states zeroed at ``done_mask`` positions.
Source code in models/recurrent_policy.py
def reset_at(self, done_mask: torch.Tensor) -> "LSTMHiddenState":
    """Zero out hidden states for episodes that have ended.

    Parameters
    ----------
    done_mask:
        Boolean tensor of shape ``(batch,)`` — ``True`` where the episode
        ended and the hidden state should be cleared.

    Returns
    -------
    LSTMHiddenState with states zeroed at ``done_mask`` positions.
    """
    # Shape the (batch,) mask to broadcast over (num_layers, batch, hidden).
    broadcast = done_mask.to(self.h.device).unsqueeze(0).unsqueeze(-1)
    return LSTMHiddenState(
        h=self.h.masked_fill(broadcast, 0.0),
        c=self.c.masked_fill(broadcast, 0.0),
    )

to(device)

Move states to device.

Source code in models/recurrent_policy.py
def to(self, device: torch.device) -> "LSTMHiddenState":
    """Return a copy with both tensors moved to *device*."""
    return LSTMHiddenState(h=self.h.to(device), c=self.c.to(device))

zeros(num_layers, hidden_size, batch_size=1, device=None) classmethod

Create a zero-initialised hidden state.

Parameters:

Name Type Description Default
num_layers int

Number of LSTM layers.

required
hidden_size int

LSTM hidden dimension.

required
batch_size int

Batch dimension.

1
device Optional[device]

Target device (defaults to CPU).

None
Source code in models/recurrent_policy.py
@classmethod
def zeros(
    cls,
    num_layers: int,
    hidden_size: int,
    batch_size: int = 1,
    device: Optional[torch.device] = None,
) -> "LSTMHiddenState":
    """Create a zero-initialised hidden state.

    Parameters
    ----------
    num_layers:
        Number of LSTM layers.
    hidden_size:
        LSTM hidden dimension.
    batch_size:
        Batch dimension.
    device:
        Target device (defaults to CPU).
    """
    dev = device if device is not None else torch.device("cpu")

    def _zero() -> torch.Tensor:
        # Fresh tensor each call so h and c stay independent.
        return torch.zeros(num_layers, batch_size, hidden_size, device=dev)

    return cls(h=_zero(), c=_zero())

models.recurrent_policy.RecurrentRolloutBuffer

On-policy rollout buffer that stores LSTM hidden states for BPTT.

Stores per-step hidden states alongside the usual rollout data so that PPO updates can re-run the LSTM through consecutive sequences (truncated BPTT) starting from the exact hidden state present during collection.

Note

This buffer does not automatically reset LSTM hidden states at episode boundaries. When done=True at step t, the caller is responsible for providing an appropriately reset (typically zeroed) hidden state for step t+1 via :meth:~models.recurrent_policy.RecurrentActorCriticPolicy.initial_state.

Parameters:

Name Type Description Default
n_steps int

Number of environment steps per rollout.

required
max_entities int

Maximum number of entity tokens per step (pad shorter observations).

required
token_dim int

Entity token dimensionality.

ENTITY_TOKEN_DIM
action_dim int

Action space dimensionality.

3
lstm_hidden_size int

LSTM hidden state dimensionality.

128
lstm_num_layers int

Number of LSTM layers.

1
gamma float

Discount factor for GAE.

0.99
gae_lambda float

GAE smoothing parameter λ.

0.95
Source code in models/recurrent_policy.py
class RecurrentRolloutBuffer:
    """On-policy rollout buffer that stores LSTM hidden states for BPTT.

    Stores per-step hidden states alongside the usual rollout data so that PPO
    updates can re-run the LSTM through consecutive sequences (truncated BPTT)
    starting from the exact hidden state present during collection.

    Note
    ----
    This buffer does not automatically reset LSTM hidden states at episode
    boundaries.  When ``done=True`` at step *t*, the caller is responsible for
    providing an appropriately reset (typically zeroed) hidden state for
    step *t+1* via :meth:`~models.recurrent_policy.RecurrentActorCriticPolicy.initial_state`.

    Parameters
    ----------
    n_steps:
        Number of environment steps per rollout.
    max_entities:
        Maximum number of entity tokens per step (pad shorter observations).
    token_dim:
        Entity token dimensionality.
    action_dim:
        Action space dimensionality.
    lstm_hidden_size:
        LSTM hidden state dimensionality.
    lstm_num_layers:
        Number of LSTM layers.
    gamma:
        Discount factor for GAE.
    gae_lambda:
        GAE smoothing parameter λ.
    """

    def __init__(
        self,
        n_steps: int,
        max_entities: int,
        token_dim: int = ENTITY_TOKEN_DIM,
        action_dim: int = 3,
        lstm_hidden_size: int = 128,
        lstm_num_layers: int = 1,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
    ) -> None:
        self.n_steps = n_steps
        self.max_entities = max_entities
        self.token_dim = token_dim
        self.action_dim = action_dim
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_num_layers = lstm_num_layers
        self.gamma = gamma
        self.gae_lambda = gae_lambda

        # Allocate all storage up-front; _reset() defines every array attribute.
        self._reset()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _reset(self) -> None:
        """Zero all storage arrays and reset the write pointer."""
        T = self.n_steps
        N = self.max_entities
        A = self.action_dim
        H = self.lstm_hidden_size
        L = self.lstm_num_layers

        # Main transition data
        self.tokens = np.zeros((T, N, self.token_dim), dtype=np.float32)
        self.pad_masks = np.zeros((T, N), dtype=bool)
        self.actions = np.zeros((T, A), dtype=np.float32)
        self.log_probs = np.zeros(T, dtype=np.float32)
        self.rewards = np.zeros(T, dtype=np.float32)
        self.dones = np.zeros(T, dtype=np.float32)
        self.values = np.zeros(T, dtype=np.float32)

        # LSTM hidden states at the *start* of each step (before processing)
        self.hx_h = np.zeros((T, L, H), dtype=np.float32)
        self.hx_c = np.zeros((T, L, H), dtype=np.float32)

        # Filled by compute_returns_and_advantages
        self.advantages = np.zeros(T, dtype=np.float32)
        self.returns = np.zeros(T, dtype=np.float32)

        # Write pointer and full-flag; add() raises once _full is set.
        self._ptr = 0
        self._full = False

    def reset(self) -> None:
        """Public reset — call before each new rollout collection."""
        self._reset()

    # ------------------------------------------------------------------
    # Data collection
    # ------------------------------------------------------------------

    def add(
        self,
        tokens: np.ndarray,
        hx: LSTMHiddenState,
        action: np.ndarray,
        log_prob: float,
        reward: float,
        done: bool,
        value: float,
        pad_mask: Optional[np.ndarray] = None,
    ) -> None:
        """Store one environment transition.

        Parameters
        ----------
        tokens:
            Entity tokens of shape ``(N_obs, token_dim)`` — the observation
            for this step.  Automatically zero-padded to ``max_entities``.
        hx:
            LSTM hidden state **at the start of this step** (before feeding
            *tokens* through the encoder).
        action:
            Action taken, shape ``(action_dim,)``.
        log_prob:
            Log-probability of *action* under the collection policy.
        reward:
            Scalar reward received.
        done:
            Whether the episode ended after this step.
        value:
            Critic value estimate for this step.
        pad_mask:
            Boolean padding mask of shape ``(N_obs,)``.  When omitted, valid
            positions (``[:n_obs]``) are left unmasked (``False``) and any
            remaining padded positions (``[n_obs:]``) are set to ``True``
            (ignored by attention).

        Raises
        ------
        RuntimeError
            If the buffer already holds ``n_steps`` transitions.
        ValueError
            If the hidden state is not 3-D or has a batch dimension != 1.
        """
        if self._full:
            raise RuntimeError(
                "RecurrentRolloutBuffer is full — call reset() before adding."
            )
        t = self._ptr

        # Store tokens (with zero-padding to max_entities)
        n_obs = min(tokens.shape[0], self.max_entities)
        self.tokens[t, :n_obs] = tokens[:n_obs]
        if pad_mask is not None:
            self.pad_masks[t, :n_obs] = pad_mask[:n_obs]
            self.pad_masks[t, n_obs:] = True  # pad the rest
        else:
            self.pad_masks[t, n_obs:] = True

        # LSTM hidden state at start of this step (squeeze batch dim = 1).
        # NOTE: RecurrentRolloutBuffer currently assumes a single environment
        # (batch size = 1) for the LSTM hidden state. If a batched hidden state
        # is passed in (e.g., from multiple parallel envs), we raise explicitly
        # to avoid silently dropping all but the first environment.
        if hx.h.dim() != 3 or hx.c.dim() != 3:
            raise ValueError(
                f"Expected hx.h and hx.c to have 3 dimensions "
                f"(num_layers, batch_size, hidden_size); "
                f"got hx.h.dim()={hx.h.dim()}, hx.c.dim()={hx.c.dim()}."
            )
        if hx.h.size(1) != 1 or hx.c.size(1) != 1:
            raise ValueError(
                "RecurrentRolloutBuffer.add currently supports only a single "
                "environment (batch_size=1) for LSTM hidden state. "
                f"Got hx.h.shape={tuple(hx.h.shape)}, "
                f"hx.c.shape={tuple(hx.c.shape)}."
            )
        self.hx_h[t] = hx.h[:, 0, :].detach().cpu().numpy()
        self.hx_c[t] = hx.c[:, 0, :].detach().cpu().numpy()

        self.actions[t] = action
        self.log_probs[t] = log_prob
        self.rewards[t] = reward
        self.dones[t] = float(done)
        self.values[t] = value

        self._ptr += 1
        if self._ptr == self.n_steps:
            self._full = True

    # ------------------------------------------------------------------
    # GAE return computation
    # ------------------------------------------------------------------

    def compute_returns_and_advantages(
        self,
        last_value: float,
        last_done: bool,
    ) -> None:
        """Compute GAE advantages and discounted returns.

        Must be called once the buffer is full (all *n_steps* transitions
        have been added).

        Parameters
        ----------
        last_value:
            Critic value estimate for the state **after** the last stored step
            (bootstrap value).  Set to ``0.0`` when the last step was terminal.
        last_done:
            Whether the last step was terminal.  Currently unused: the body
            never references it — termination of the final step is read from
            the stored ``dones`` flag, so ``dones[-1]`` already gates the
            bootstrap term.  Kept for API symmetry with common PPO buffers.
        """
        if not self._full and self._ptr != self.n_steps:
            raise RuntimeError(
                "Buffer is not yet full — add all transitions before computing returns."
            )
        gae = 0.0
        next_value = last_value

        for t in reversed(range(self.n_steps)):
            # Use the stored done flag for the *current* step to decide whether
            # to bootstrap from the next value.  Using dones[t] directly avoids
            # the off-by-one error that would arise from carrying next_done
            # across iterations.
            not_terminal = 1.0 - self.dones[t]
            delta = (
                self.rewards[t]
                + self.gamma * next_value * not_terminal
                - self.values[t]
            )
            gae = delta + self.gamma * self.gae_lambda * not_terminal * gae
            self.advantages[t] = gae
            self.returns[t] = gae + self.values[t]

            next_value = self.values[t]

    # ------------------------------------------------------------------
    # Sequence batching for BPTT
    # ------------------------------------------------------------------

    def get_sequences(
        self,
        seq_len: int,
        device: torch.device,
        normalize_advantages: bool = True,
    ) -> List[Dict[str, torch.Tensor]]:
        """Split the buffer into non-overlapping sequences for BPTT.

        Returns a list of ``n_seqs = n_steps // seq_len`` batches.  Each batch
        is a dict covering ONE sequence, with a leading batch dimension of 1:

        * ``"tokens"``      — shape ``(1, seq_len, max_entities, token_dim)``
        * ``"pad_masks"``   — shape ``(1, seq_len, max_entities)``
        * ``"hx_h"``        — shape ``(lstm_num_layers, 1, lstm_hidden_size)``
        * ``"hx_c"``        — shape ``(lstm_num_layers, 1, lstm_hidden_size)``
        * ``"actions"``     — shape ``(1, seq_len, action_dim)``
        * ``"log_probs"``   — shape ``(1, seq_len)``
        * ``"advantages"``  — shape ``(1, seq_len)``
        * ``"returns"``     — shape ``(1, seq_len)``
        * ``"values"``      — shape ``(1, seq_len)``

        The initial hidden states ``hx_h`` / ``hx_c`` are taken from the
        *first* step of each sequence.

        Parameters
        ----------
        seq_len:
            Length of each sub-sequence.  Must evenly divide ``n_steps``.
        device:
            Destination device for returned tensors.
        normalize_advantages:
            Normalise advantages to zero mean, unit variance (recommended).

        Returns
        -------
        batches : list of dicts (each dict is one sequence batch)

        Raises
        ------
        ValueError
            If ``seq_len`` does not evenly divide ``n_steps``.
        """
        if self.n_steps % seq_len != 0:
            raise ValueError(
                f"seq_len={seq_len} must evenly divide n_steps={self.n_steps}."
            )

        # Normalise a copy so repeated calls never compound the normalisation.
        adv = self.advantages.copy()
        if normalize_advantages:
            adv_std = adv.std() + _NORMALIZATION_EPSILON
            adv = (adv - adv.mean()) / adv_std

        n_seqs = self.n_steps // seq_len
        batches: List[Dict[str, torch.Tensor]] = []

        for i in range(n_seqs):
            start = i * seq_len
            end = start + seq_len
            sl = slice(start, end)

            # Hidden states from the first step of this sequence
            hx_h = self.hx_h[start]  # (L, H)
            hx_c = self.hx_c[start]  # (L, H)

            batch = {
                "tokens": torch.as_tensor(
                    self.tokens[sl], dtype=torch.float32, device=device
                ).unsqueeze(0),  # (1, seq_len, N, token_dim)
                "pad_masks": torch.as_tensor(
                    self.pad_masks[sl], dtype=torch.bool, device=device
                ).unsqueeze(0),  # (1, seq_len, N)
                "hx_h": torch.as_tensor(
                    hx_h, dtype=torch.float32, device=device
                ).unsqueeze(1),  # (L, 1, H)
                "hx_c": torch.as_tensor(
                    hx_c, dtype=torch.float32, device=device
                ).unsqueeze(1),  # (L, 1, H)
                "actions": torch.as_tensor(
                    self.actions[sl], dtype=torch.float32, device=device
                ).unsqueeze(0),  # (1, seq_len, action_dim)
                "log_probs": torch.as_tensor(
                    self.log_probs[sl], dtype=torch.float32, device=device
                ).unsqueeze(0),  # (1, seq_len)
                "advantages": torch.as_tensor(
                    adv[sl], dtype=torch.float32, device=device
                ).unsqueeze(0),  # (1, seq_len)
                "returns": torch.as_tensor(
                    self.returns[sl], dtype=torch.float32, device=device
                ).unsqueeze(0),  # (1, seq_len)
                "values": torch.as_tensor(
                    self.values[sl], dtype=torch.float32, device=device
                ).unsqueeze(0),  # (1, seq_len)
            }
            batches.append(batch)

        return batches

    # ------------------------------------------------------------------
    # Diagnostics
    # ------------------------------------------------------------------

    def memory_bytes(self) -> int:
        """Approximate memory usage of stored arrays in bytes."""
        arrays = [
            self.tokens, self.pad_masks, self.actions,
            self.log_probs, self.rewards, self.dones, self.values,
            self.hx_h, self.hx_c, self.advantages, self.returns,
        ]
        return sum(a.nbytes for a in arrays)

    def __len__(self) -> int:
        # Number of transitions stored so far (== n_steps once full).
        return self._ptr

add(tokens, hx, action, log_prob, reward, done, value, pad_mask=None)

Store one environment transition.

Parameters:

Name Type Description Default
tokens ndarray

Entity tokens of shape (N_obs, token_dim) — the observation for this step. Automatically zero-padded to max_entities.

required
hx LSTMHiddenState

LSTM hidden state at the start of this step (before feeding tokens through the encoder).

required
action ndarray

Action taken, shape (action_dim,).

required
log_prob float

Log-probability of action under the collection policy.

required
reward float

Scalar reward received.

required
done bool

Whether the episode ended after this step.

required
value float

Critic value estimate for this step.

required
pad_mask Optional[ndarray]

Boolean padding mask of shape (N_obs,). When omitted, valid positions ([:n_obs]) are left unmasked (False) and any remaining padded positions ([n_obs:]) are set to True (ignored by attention).

None
Source code in models/recurrent_policy.py
def add(
    self,
    tokens: np.ndarray,
    hx: LSTMHiddenState,
    action: np.ndarray,
    log_prob: float,
    reward: float,
    done: bool,
    value: float,
    pad_mask: Optional[np.ndarray] = None,
) -> None:
    """Record a single transition at the current write position.

    Parameters
    ----------
    tokens:
        Entity tokens of shape ``(N_obs, token_dim)``; anything beyond
        ``max_entities`` is dropped, anything shorter stays zero-padded.
    hx:
        LSTM hidden state captured *before* this step's observation was
        fed through the encoder.
    action:
        Action taken, shape ``(action_dim,)``.
    log_prob:
        Log-probability of *action* under the collection policy.
    reward:
        Scalar reward received.
    done:
        Whether the episode ended after this step.
    value:
        Critic value estimate for this step.
    pad_mask:
        Boolean padding mask of shape ``(N_obs,)``.  When omitted, the
        first ``n_obs`` positions stay unmasked (``False``); padded tail
        positions are always set to ``True`` (ignored by attention).

    Raises
    ------
    RuntimeError
        If the buffer already holds ``n_steps`` transitions.
    ValueError
        If the hidden state is not 3-D, or its batch dimension is not 1.
    """
    if self._full:
        raise RuntimeError(
            "RecurrentRolloutBuffer is full — call reset() before adding."
        )
    step = self._ptr

    # Copy at most max_entities observation tokens; the remainder of the
    # row keeps the zero padding laid down by _reset().
    n_valid = min(tokens.shape[0], self.max_entities)
    self.tokens[step, :n_valid] = tokens[:n_valid]
    if pad_mask is not None:
        self.pad_masks[step, :n_valid] = pad_mask[:n_valid]
    self.pad_masks[step, n_valid:] = True  # padded tail is always masked

    # Validate the hidden-state layout before persisting it.
    # NOTE: this buffer currently assumes a single environment (batch
    # size = 1) for the LSTM hidden state; a batched state raises
    # explicitly rather than silently dropping environments.
    if hx.h.dim() != 3 or hx.c.dim() != 3:
        raise ValueError(
            f"Expected hx.h and hx.c to have 3 dimensions "
            f"(num_layers, batch_size, hidden_size); "
            f"got hx.h.dim()={hx.h.dim()}, hx.c.dim()={hx.c.dim()}."
        )
    if hx.h.size(1) != 1 or hx.c.size(1) != 1:
        raise ValueError(
            "RecurrentRolloutBuffer.add currently supports only a single "
            "environment (batch_size=1) for LSTM hidden state. "
            f"Got hx.h.shape={tuple(hx.h.shape)}, "
            f"hx.c.shape={tuple(hx.c.shape)}."
        )
    # Squeeze away the batch dimension (validated == 1) and store on CPU.
    self.hx_h[step] = hx.h[:, 0, :].detach().cpu().numpy()
    self.hx_c[step] = hx.c[:, 0, :].detach().cpu().numpy()

    self.actions[step] = action
    self.log_probs[step] = log_prob
    self.rewards[step] = reward
    self.dones[step] = float(done)
    self.values[step] = value

    self._ptr = step + 1
    self._full = self._ptr == self.n_steps

compute_returns_and_advantages(last_value, last_done)

Compute GAE advantages and discounted returns.

Must be called once the buffer is full (all n_steps transitions have been added).

Parameters:

Name Type Description Default
last_value float

Critic value estimate for the state after the last stored step (bootstrap value). Set to 0.0 when the last step was terminal.

required
last_done bool

Whether the last step was terminal.

required
Source code in models/recurrent_policy.py
def compute_returns_and_advantages(
    self,
    last_value: float,
    last_done: bool,
) -> None:
    """Fill ``advantages`` and ``returns`` via Generalised Advantage Estimation.

    Call exactly once after all *n_steps* transitions have been added.

    Parameters
    ----------
    last_value:
        Bootstrap value estimate for the state following the final stored
        step.  Pass ``0.0`` when that step was terminal.
    last_done:
        Whether the final stored step was terminal.

    Raises
    ------
    RuntimeError
        If fewer than ``n_steps`` transitions have been stored.
    """
    if not self._full and self._ptr != self.n_steps:
        raise RuntimeError(
            "Buffer is not yet full — add all transitions before computing returns."
        )
    running_gae = 0.0
    value_ahead = last_value

    for step in range(self.n_steps - 1, -1, -1):
        # dones[step] flags termination *after* this step, so it gates
        # both the bootstrap term and the GAE carry-over directly —
        # no shifted next_done bookkeeping is needed.
        mask = 1.0 - self.dones[step]
        td_error = (
            self.rewards[step]
            + self.gamma * value_ahead * mask
            - self.values[step]
        )
        running_gae = td_error + self.gamma * self.gae_lambda * mask * running_gae
        self.advantages[step] = running_gae
        self.returns[step] = running_gae + self.values[step]
        value_ahead = self.values[step]

get_sequences(seq_len, device, normalize_advantages=True)

Split the buffer into non-overlapping sequences for BPTT.

The method returns a list of n_seqs = n_steps // seq_len batches. Each batch is a dict covering one sequence, with a leading batch dimension of 1:

  • "tokens" — shape (1, seq_len, max_entities, token_dim)
  • "pad_masks" — shape (1, seq_len, max_entities)
  • "hx_h" — shape (lstm_num_layers, 1, lstm_hidden_size)
  • "hx_c" — shape (lstm_num_layers, 1, lstm_hidden_size)
  • "actions" — shape (1, seq_len, action_dim)
  • "log_probs" — shape (1, seq_len)
  • "advantages" — shape (1, seq_len)
  • "returns" — shape (1, seq_len)
  • "values" — shape (1, seq_len)

The initial hidden states hx_h / hx_c are taken from the first step of each sequence.

Parameters:

Name Type Description Default
seq_len int

Length of each sub-sequence. Must evenly divide n_steps.

required
device device

Destination device for returned tensors.

required
normalize_advantages bool

Normalise advantages to zero mean, unit variance (recommended).

True

Returns:

Name Type Description
batches list of dicts (each dict is one sequence batch)
Source code in models/recurrent_policy.py
def get_sequences(
    self,
    seq_len: int,
    device: torch.device,
    normalize_advantages: bool = True,
) -> List[Dict[str, torch.Tensor]]:
    """Split the buffer into non-overlapping sequences for BPTT.

    Returns a list of ``n_seqs = n_steps // seq_len`` batches.  Each batch
    is a dict covering ONE sequence, with a leading batch dimension of 1
    (the docstring previously advertised ``(n_seqs, seq_len, …)`` shapes,
    which did not match the tensors actually produced):

    * ``"tokens"``      — shape ``(1, seq_len, max_entities, token_dim)``
    * ``"pad_masks"``   — shape ``(1, seq_len, max_entities)``
    * ``"hx_h"``        — shape ``(lstm_num_layers, 1, lstm_hidden_size)``
    * ``"hx_c"``        — shape ``(lstm_num_layers, 1, lstm_hidden_size)``
    * ``"actions"``     — shape ``(1, seq_len, action_dim)``
    * ``"log_probs"``   — shape ``(1, seq_len)``
    * ``"advantages"``  — shape ``(1, seq_len)``
    * ``"returns"``     — shape ``(1, seq_len)``
    * ``"values"``      — shape ``(1, seq_len)``

    The initial hidden states ``hx_h`` / ``hx_c`` are taken from the
    *first* step of each sequence.

    Parameters
    ----------
    seq_len:
        Length of each sub-sequence.  Must evenly divide ``n_steps``.
    device:
        Destination device for returned tensors.
    normalize_advantages:
        Normalise advantages to zero mean, unit variance (recommended).

    Returns
    -------
    batches : list of dicts (each dict is one sequence batch)

    Raises
    ------
    ValueError
        If ``seq_len`` does not evenly divide ``n_steps``.
    """
    if self.n_steps % seq_len != 0:
        raise ValueError(
            f"seq_len={seq_len} must evenly divide n_steps={self.n_steps}."
        )

    # Work on a copy so repeated calls never compound the normalisation.
    adv = self.advantages.copy()
    if normalize_advantages:
        adv_std = adv.std() + _NORMALIZATION_EPSILON
        adv = (adv - adv.mean()) / adv_std

    n_seqs = self.n_steps // seq_len
    batches: List[Dict[str, torch.Tensor]] = []

    for i in range(n_seqs):
        start = i * seq_len
        end = start + seq_len
        sl = slice(start, end)

        # Hidden states from the first step of this sequence
        hx_h = self.hx_h[start]  # (L, H)
        hx_c = self.hx_c[start]  # (L, H)

        batch = {
            "tokens": torch.as_tensor(
                self.tokens[sl], dtype=torch.float32, device=device
            ).unsqueeze(0),  # (1, seq_len, N, token_dim)
            "pad_masks": torch.as_tensor(
                self.pad_masks[sl], dtype=torch.bool, device=device
            ).unsqueeze(0),  # (1, seq_len, N)
            "hx_h": torch.as_tensor(
                hx_h, dtype=torch.float32, device=device
            ).unsqueeze(1),  # (L, 1, H)
            "hx_c": torch.as_tensor(
                hx_c, dtype=torch.float32, device=device
            ).unsqueeze(1),  # (L, 1, H)
            "actions": torch.as_tensor(
                self.actions[sl], dtype=torch.float32, device=device
            ).unsqueeze(0),  # (1, seq_len, action_dim)
            "log_probs": torch.as_tensor(
                self.log_probs[sl], dtype=torch.float32, device=device
            ).unsqueeze(0),  # (1, seq_len)
            "advantages": torch.as_tensor(
                adv[sl], dtype=torch.float32, device=device
            ).unsqueeze(0),  # (1, seq_len)
            "returns": torch.as_tensor(
                self.returns[sl], dtype=torch.float32, device=device
            ).unsqueeze(0),  # (1, seq_len)
            "values": torch.as_tensor(
                self.values[sl], dtype=torch.float32, device=device
            ).unsqueeze(0),  # (1, seq_len)
        }
        batches.append(batch)

    return batches

memory_bytes()

Approximate memory usage of stored arrays in bytes.

Source code in models/recurrent_policy.py
def memory_bytes(self) -> int:
    """Rough total size, in bytes, of every backing numpy array."""
    storage = (
        self.tokens,
        self.pad_masks,
        self.actions,
        self.log_probs,
        self.rewards,
        self.dones,
        self.values,
        self.hx_h,
        self.hx_c,
        self.advantages,
        self.returns,
    )
    total = 0
    for array in storage:
        total += array.nbytes
    return total

reset()

Public reset — call before each new rollout collection.

Source code in models/recurrent_policy.py
def reset(self) -> None:
    """Clear all stored transitions; invoke before collecting a new rollout."""
    # Delegates to the private helper that zeroes storage and the pointer.
    self._reset()

WFM-1 foundation model

models.wfm1.WFM1Policy

Bases: Module

WFM-1 hierarchical transformer policy.

A single policy that operates across battalion, brigade, division, and corps echelons. Multi-echelon information is fused via cross-echelon attention. A lightweight ScenarioCard FiLM adapter conditions the policy on scenario metadata, enabling efficient fine-tuning.

Architecture::

For each active echelon e ∈ {battalion, brigade, division, corps}:
  tokens_e (B, N_e, token_dim)
        ↓
  EchelonEncoder(echelon=e)  →  enc_e (B, d_model)

Stack: echelon_encs = [enc_e₁, enc_e₂, …]  shape (B, E, d_model)
        ↓
CrossEchelonTransformer  →  fused (B, d_model)
        ↓
FiLM(ScenarioCard) :  fused ← fused × γ + β
        ↓
actor head  →  Gaussian action distribution
critic head →  scalar value

Parameters:

Name Type Description Default
token_dim int

Entity token dimensionality. Defaults to :data:~models.entity_encoder.ENTITY_TOKEN_DIM.

ENTITY_TOKEN_DIM
action_dim int

Continuous action space dimensionality.

3
d_model int

Transformer hidden dimension.

128
n_heads int

Attention heads for both echelon encoder and cross-echelon transformer.

8
n_echelon_layers int

Transformer depth for each :class:EchelonEncoder.

4
n_cross_layers int

Transformer depth for :class:CrossEchelonTransformer.

2
actor_hidden_sizes Tuple[int, ...]

MLP hidden sizes for the actor head.

(256, 128)
critic_hidden_sizes Tuple[int, ...]

MLP hidden sizes for the critic head.

(256, 128)
dropout float

Dropout probability (transformer only).

0.0
use_spatial_pe bool

Enable 2-D Fourier positional encoding in the echelon encoders.

True
share_echelon_encoders bool

When True (default), all four echelon levels share a single :class:EchelonEncoder instance. When False each echelon has independent weights.

True
card_hidden_size int

Width of the FiLM adapter MLP.

64
Source code in models/wfm1.py
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
class WFM1Policy(nn.Module):
    """WFM-1 hierarchical transformer policy.

    A single policy that operates across battalion, brigade, division, and
    corps echelons.  Multi-echelon information is fused via cross-echelon
    attention.  A lightweight ScenarioCard FiLM adapter conditions the policy
    on scenario metadata, enabling efficient fine-tuning.

    Architecture::

        For each active echelon e ∈ {battalion, brigade, division, corps}:
          tokens_e (B, N_e, token_dim)

          EchelonEncoder(echelon=e)  →  enc_e (B, d_model)

        Stack: echelon_encs = [enc_e₁, enc_e₂, …]  shape (B, E, d_model)

        CrossEchelonTransformer  →  fused (B, d_model)

        FiLM(ScenarioCard) :  fused ← fused × γ + β

        actor head  →  Gaussian action distribution
        critic head →  scalar value

    Parameters
    ----------
    token_dim:
        Entity token dimensionality.  Defaults to :data:`~models.entity_encoder.ENTITY_TOKEN_DIM`.
    action_dim:
        Continuous action space dimensionality.
    d_model:
        Transformer hidden dimension.
    n_heads:
        Attention heads for both echelon encoder and cross-echelon transformer.
    n_echelon_layers:
        Transformer depth for each :class:`EchelonEncoder`.
    n_cross_layers:
        Transformer depth for :class:`CrossEchelonTransformer`.
    actor_hidden_sizes:
        MLP hidden sizes for the actor head.
    critic_hidden_sizes:
        MLP hidden sizes for the critic head.
    dropout:
        Dropout probability (transformer only).
    use_spatial_pe:
        Enable 2-D Fourier positional encoding in the echelon encoders.
    share_echelon_encoders:
        When ``True`` (default), all four echelon levels share a single
        :class:`EchelonEncoder` instance.  When ``False`` each echelon has
        independent weights.
    card_hidden_size:
        Width of the FiLM adapter MLP.
    """

    def __init__(
        self,
        token_dim: int = ENTITY_TOKEN_DIM,
        action_dim: int = 3,
        d_model: int = 128,
        n_heads: int = 8,
        n_echelon_layers: int = 4,
        n_cross_layers: int = 2,
        actor_hidden_sizes: Tuple[int, ...] = (256, 128),
        critic_hidden_sizes: Tuple[int, ...] = (256, 128),
        dropout: float = 0.0,
        use_spatial_pe: bool = True,
        share_echelon_encoders: bool = True,
        card_hidden_size: int = 64,
    ) -> None:
        super().__init__()
        self.token_dim = token_dim
        self.action_dim = action_dim
        self.d_model = d_model
        self.share_echelon_encoders = share_echelon_encoders
        self.dropout = dropout
        # Record the remaining constructor arguments verbatim so that
        # :meth:`save_checkpoint` can persist the exact configuration instead
        # of reverse-engineering it from submodule internals.  Previously
        # ``use_spatial_pe`` was not saved at all, so a model built with
        # ``use_spatial_pe=False`` could not be faithfully restored by
        # :meth:`load_checkpoint`.
        self.n_heads = n_heads
        self.n_echelon_layers = n_echelon_layers
        self.n_cross_layers = n_cross_layers
        self.actor_hidden_sizes = tuple(actor_hidden_sizes)
        self.critic_hidden_sizes = tuple(critic_hidden_sizes)
        self.use_spatial_pe = use_spatial_pe
        self.card_hidden_size = card_hidden_size

        # --- Echelon encoders -------------------------------------------------
        if share_echelon_encoders:
            # A single encoder instance registered once per echelon slot;
            # nn.Module deduplicates shared parameters by identity.
            _shared = EchelonEncoder(
                token_dim=token_dim,
                d_model=d_model,
                n_heads=n_heads,
                n_layers=n_echelon_layers,
                dropout=dropout,
                use_spatial_pe=use_spatial_pe,
                use_echelon_embedding=True,
            )
            self.echelon_encoders = nn.ModuleList([_shared] * _N_ECHELONS)
        else:
            self.echelon_encoders = nn.ModuleList([
                EchelonEncoder(
                    token_dim=token_dim,
                    d_model=d_model,
                    n_heads=n_heads,
                    n_layers=n_echelon_layers,
                    dropout=dropout,
                    use_spatial_pe=use_spatial_pe,
                    use_echelon_embedding=True,
                )
                for _ in range(_N_ECHELONS)
            ])

        # --- Cross-echelon transformer ----------------------------------------
        self.cross_echelon = CrossEchelonTransformer(
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_cross_layers,
            dropout=dropout,
        )

        # --- Scenario card FiLM adapter ---------------------------------------
        self.card_encoder = _ScenarioCardEncoder(
            card_raw_dim=_SCENARIO_CARD_RAW_DIM,
            d_model=d_model,
            hidden_size=card_hidden_size,
        )

        # --- Actor / critic heads ---------------------------------------------
        self.actor_head = _build_mlp(d_model, actor_hidden_sizes, action_dim)
        # State-independent log-std for the Gaussian action distribution.
        self.log_std = nn.Parameter(torch.zeros(action_dim))
        self.critic_head = _build_mlp(d_model, critic_hidden_sizes, 1)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _encode(
        self,
        tokens_per_echelon: Dict[int, torch.Tensor],
        pad_masks: Optional[Dict[int, torch.Tensor]] = None,
        card: Optional[ScenarioCard] = None,
        card_vec: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Produce the FiLM-modulated fused encoding.

        Parameters
        ----------
        tokens_per_echelon:
            Mapping from echelon id → entity token tensor ``(B, N_e, token_dim)``.
            At least one echelon must be provided.
        pad_masks:
            Optional mapping from echelon id → padding mask ``(B, N_e)``.
        card:
            :class:`ScenarioCard` to condition on.  Mutually exclusive with
            ``card_vec``; if both are ``None`` the FiLM adapter produces
            identity transforms (γ=1, β=0).
        card_vec:
            Pre-computed scenario card tensor ``(B, card_raw_dim)`` or
            ``(card_raw_dim,)``.  Takes priority over ``card``.

        Returns
        -------
        fused : torch.Tensor — shape ``(B, d_model)``
        """
        if not tokens_per_echelon:
            raise ValueError("tokens_per_echelon must contain at least one echelon.")

        pad_masks = pad_masks or {}
        # Sorted order keeps the echelon axis deterministic batch-to-batch.
        echelon_ids_list: List[int] = sorted(tokens_per_echelon.keys())

        # Encode each echelon
        echelon_encs: List[torch.Tensor] = []
        for eid in echelon_ids_list:
            enc = self.echelon_encoders[eid](
                tokens_per_echelon[eid],
                echelon=eid,
                pad_mask=pad_masks.get(eid),
            )  # (B, d_model)
            echelon_encs.append(enc)

        # Stack → (B, E, d_model) for cross-echelon transformer
        stacked = torch.stack(echelon_encs, dim=1)  # (B, E, d_model)
        echelon_ids_t = torch.tensor(
            echelon_ids_list,
            device=stacked.device,
            dtype=torch.long,
        )  # (E,)
        fused = self.cross_echelon(stacked, echelon_ids_t)  # (B, d_model)

        # FiLM modulation from scenario card
        if card_vec is not None:
            if card_vec.device != fused.device:
                card_vec = card_vec.to(fused.device)
            gamma, beta = self.card_encoder(card_vec)
        elif card is not None:
            cv = card.to_tensor(device=fused.device)
            gamma, beta = self.card_encoder(cv)
        else:
            # Identity FiLM: γ=1, β=0 (no modulation)
            gamma = torch.ones(fused.shape[0], self.d_model, device=fused.device)
            beta = torch.zeros(fused.shape[0], self.d_model, device=fused.device)

        fused = fused * gamma + beta  # (B, d_model)
        return fused

    # ------------------------------------------------------------------
    # Single-echelon convenience shortcut
    # ------------------------------------------------------------------

    def _encode_single(
        self,
        tokens: torch.Tensor,
        echelon: int = ECHELON_BATTALION,
        pad_mask: Optional[torch.Tensor] = None,
        card: Optional[ScenarioCard] = None,
        card_vec: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode tokens from a single echelon (most common use-case)."""
        return self._encode(
            tokens_per_echelon={echelon: tokens},
            pad_masks={echelon: pad_mask} if pad_mask is not None else None,
            card=card,
            card_vec=card_vec,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @torch.no_grad()
    def act(
        self,
        tokens: Optional[torch.Tensor],
        pad_mask: Optional[torch.Tensor] = None,
        echelon: int = ECHELON_BATTALION,
        card: Optional[ScenarioCard] = None,
        card_vec: Optional[torch.Tensor] = None,
        tokens_per_echelon: Optional[Dict[int, torch.Tensor]] = None,
        pad_masks: Optional[Dict[int, torch.Tensor]] = None,
        deterministic: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample actions (no gradient).

        Can be called with a single-echelon ``tokens`` tensor or with a
        full ``tokens_per_echelon`` dict for multi-echelon inputs.

        Parameters
        ----------
        tokens:
            Single-echelon entity tokens ``(B, N, token_dim)``.
            Ignored when ``tokens_per_echelon`` is provided.
        pad_mask:
            Padding mask for ``tokens``.
        echelon:
            Active echelon level for single-echelon mode.
        card:
            Scenario conditioning card.
        card_vec:
            Pre-computed scenario card tensor (alternative to ``card``).
        tokens_per_echelon:
            Multi-echelon input dict; overrides ``tokens`` when given.
        pad_masks:
            Padding masks for multi-echelon mode.
        deterministic:
            When ``True``, return the distribution mean instead of a sample.

        Returns
        -------
        actions   : torch.Tensor — shape ``(B, action_dim)``
        log_probs : torch.Tensor — shape ``(B,)``
        """
        if tokens_per_echelon is not None:
            fused = self._encode(
                tokens_per_echelon, pad_masks, card=card, card_vec=card_vec
            )
        else:
            fused = self._encode_single(tokens, echelon, pad_mask, card, card_vec)

        mean = self.actor_head(fused)
        std = self.log_std.exp().expand_as(mean)
        dist = Normal(mean, std)
        actions = mean if deterministic else dist.sample()
        log_probs = dist.log_prob(actions).sum(dim=-1)
        return actions, log_probs

    def get_value(
        self,
        tokens: torch.Tensor,
        pad_mask: Optional[torch.Tensor] = None,
        echelon: int = ECHELON_BATTALION,
        card: Optional[ScenarioCard] = None,
        card_vec: Optional[torch.Tensor] = None,
        tokens_per_echelon: Optional[Dict[int, torch.Tensor]] = None,
        pad_masks: Optional[Dict[int, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Compute scalar value estimates.

        Returns
        -------
        values : torch.Tensor — shape ``(B,)``
        """
        if tokens_per_echelon is not None:
            fused = self._encode(
                tokens_per_echelon, pad_masks, card=card, card_vec=card_vec
            )
        else:
            fused = self._encode_single(tokens, echelon, pad_mask, card, card_vec)

        return self.critic_head(fused).squeeze(-1)  # (B,)

    def evaluate_actions(
        self,
        tokens: torch.Tensor,
        actions: torch.Tensor,
        pad_mask: Optional[torch.Tensor] = None,
        echelon: int = ECHELON_BATTALION,
        card: Optional[ScenarioCard] = None,
        card_vec: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate log-probs, entropy, and values for given actions.

        Used during PPO update.

        Returns
        -------
        log_probs : torch.Tensor — shape ``(B,)``
        entropy   : torch.Tensor — shape ``(B,)``
        values    : torch.Tensor — shape ``(B,)``
        """
        fused = self._encode_single(tokens, echelon, pad_mask, card, card_vec)

        mean = self.actor_head(fused)
        std = self.log_std.exp().expand_as(mean)
        dist = Normal(mean, std)

        log_probs = dist.log_prob(actions).sum(dim=-1)
        entropy = dist.entropy().sum(dim=-1)
        values = self.critic_head(fused).squeeze(-1)
        return log_probs, entropy, values

    # ------------------------------------------------------------------
    # Fine-tuning adapter API
    # ------------------------------------------------------------------

    def adapter_parameters(self) -> List[nn.Parameter]:
        """Return only the FiLM adapter parameters.

        Call this to get a parameter group for lightweight scenario-specific
        fine-tuning (adapter-only gradient updates leave the base transformer
        frozen).

        Returns
        -------
        list of :class:`torch.nn.Parameter`
        """
        return list(self.card_encoder.parameters())

    def base_parameters(self) -> List[nn.Parameter]:
        """Return all non-adapter (base model) parameters."""
        adapter_ids = {id(p) for p in self.adapter_parameters()}
        return [p for p in self.parameters() if id(p) not in adapter_ids]

    def freeze_base(self) -> None:
        """Freeze all base-model parameters; only the adapter remains trainable."""
        for p in self.base_parameters():
            p.requires_grad_(False)

    def unfreeze_base(self) -> None:
        """Unfreeze all base-model parameters."""
        for p in self.parameters():
            p.requires_grad_(True)

    # ------------------------------------------------------------------
    # Checkpoint utilities
    # ------------------------------------------------------------------

    def save_checkpoint(self, path: str | Path) -> Path:
        """Save model state dict and constructor config to *path* (``.pt``).

        The ``config`` entry contains the exact constructor arguments
        recorded in :meth:`__init__` (including ``use_spatial_pe``, which an
        earlier revision failed to persist), so :meth:`load_checkpoint` can
        rebuild an identical model.
        """
        out = Path(path)
        out.parent.mkdir(parents=True, exist_ok=True)
        torch.save(
            {
                "state_dict": self.state_dict(),
                "config": {
                    "token_dim": self.token_dim,
                    "action_dim": self.action_dim,
                    "d_model": self.d_model,
                    "n_heads": self.n_heads,
                    "n_echelon_layers": self.n_echelon_layers,
                    "n_cross_layers": self.n_cross_layers,
                    "actor_hidden_sizes": self.actor_hidden_sizes,
                    "critic_hidden_sizes": self.critic_hidden_sizes,
                    "dropout": self.dropout,
                    "use_spatial_pe": self.use_spatial_pe,
                    "share_echelon_encoders": self.share_echelon_encoders,
                    "card_hidden_size": self.card_hidden_size,
                },
            },
            out,
        )
        return out

    @classmethod
    def load_checkpoint(
        cls,
        path: str | Path,
        map_location: Optional[torch.device] = None,
        **kwargs,
    ) -> "WFM1Policy":
        """Load a WFM-1 checkpoint produced by :meth:`save_checkpoint`.

        Extra keyword arguments override the saved configuration.  Older
        checkpoints lacking newly persisted keys (e.g. ``use_spatial_pe``)
        still load, falling back to the constructor defaults.

        NOTE(review): ``torch.load`` unpickles arbitrary objects — only load
        checkpoints from trusted sources.
        """
        ckpt = torch.load(path, map_location=map_location)
        cfg = {**ckpt["config"], **kwargs}
        policy = cls(**cfg)
        policy.load_state_dict(ckpt["state_dict"])
        return policy

    def finetune_loss(
        self,
        batch: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Compute a supervised fine-tuning loss from a demonstration batch.

        The batch dict must contain:
        * ``"tokens"``   — shape ``(B, N, token_dim)``
        * ``"actions"``  — shape ``(B, action_dim)``

        Optional keys:
        * ``"pad_mask"`` — shape ``(B, N)``
        * ``"echelon"``  — scalar int (default: ``ECHELON_BATTALION``)
        * ``"card_vec"`` — shape ``(B, card_raw_dim)`` or ``(card_raw_dim,)``

        Returns
        -------
        loss : torch.Tensor — scalar behaviour-cloning MSE loss
        """
        tokens = batch["tokens"]
        target_actions = batch["actions"]
        pad_mask = batch.get("pad_mask")
        echelon = int(batch["echelon"]) if "echelon" in batch else ECHELON_BATTALION
        card_vec = batch.get("card_vec")

        fused = self._encode_single(tokens, echelon, pad_mask, card_vec=card_vec)
        pred_mean = self.actor_head(fused)
        return nn.functional.mse_loss(pred_mean, target_actions)

act(tokens, pad_mask=None, echelon=ECHELON_BATTALION, card=None, card_vec=None, tokens_per_echelon=None, pad_masks=None, deterministic=False)

Sample actions (no gradient).

Can be called with a single-echelon tokens tensor or with a full tokens_per_echelon dict for multi-echelon inputs.

Parameters:

Name Type Description Default
tokens Optional[Tensor]

Single-echelon entity tokens (B, N, token_dim). Ignored when tokens_per_echelon is provided.

required
pad_mask Optional[Tensor]

Padding mask for tokens.

None
echelon int

Active echelon level for single-echelon mode.

ECHELON_BATTALION
card Optional[ScenarioCard]

Scenario conditioning card.

None
card_vec Optional[Tensor]

Pre-computed scenario card tensor (alternative to card).

None
tokens_per_echelon Optional[Dict[int, Tensor]]

Multi-echelon input dict; overrides tokens when given.

None
pad_masks Optional[Dict[int, Tensor]]

Padding masks for multi-echelon mode.

None
deterministic bool

When True, return the distribution mean instead of a sample.

False

Returns:

Name Type Description
actions torch.Tensor — shape (B, action_dim)
log_probs torch.Tensor — shape (B,)
Source code in models/wfm1.py
@torch.no_grad()
def act(
    self,
    tokens: Optional[torch.Tensor],
    pad_mask: Optional[torch.Tensor] = None,
    echelon: int = ECHELON_BATTALION,
    card: Optional[ScenarioCard] = None,
    card_vec: Optional[torch.Tensor] = None,
    tokens_per_echelon: Optional[Dict[int, torch.Tensor]] = None,
    pad_masks: Optional[Dict[int, torch.Tensor]] = None,
    deterministic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Draw an action batch without tracking gradients.

    Accepts either a single-echelon ``tokens`` tensor or a full
    ``tokens_per_echelon`` mapping; the latter takes precedence.

    Parameters
    ----------
    tokens:
        Single-echelon entity tokens ``(B, N, token_dim)``; ignored when
        ``tokens_per_echelon`` is supplied.
    pad_mask:
        Padding mask paired with ``tokens``.
    echelon:
        Echelon level used in single-echelon mode.
    card:
        Scenario conditioning card.
    card_vec:
        Pre-computed scenario card tensor (alternative to ``card``).
    tokens_per_echelon:
        Multi-echelon input mapping; overrides ``tokens`` when given.
    pad_masks:
        Padding masks paired with ``tokens_per_echelon``.
    deterministic:
        Return the Gaussian mean instead of sampling when ``True``.

    Returns
    -------
    actions   : torch.Tensor — shape ``(B, action_dim)``
    log_probs : torch.Tensor — shape ``(B,)``
    """
    if tokens_per_echelon is None:
        fused = self._encode_single(tokens, echelon, pad_mask, card, card_vec)
    else:
        fused = self._encode(
            tokens_per_echelon, pad_masks, card=card, card_vec=card_vec
        )

    mu = self.actor_head(fused)
    dist = Normal(mu, self.log_std.exp().expand_as(mu))
    chosen = mu if deterministic else dist.sample()
    return chosen, dist.log_prob(chosen).sum(dim=-1)

adapter_parameters()

Return only the FiLM adapter parameters.

Call this to get a parameter group for lightweight scenario-specific fine-tuning (adapter-only gradient updates leave the base transformer frozen).

Returns:

Type Description
list of torch.nn.Parameter
Source code in models/wfm1.py
def adapter_parameters(self) -> List[nn.Parameter]:
    """Parameters belonging to the ScenarioCard FiLM adapter only.

    Use the returned list as an optimizer parameter group for lightweight
    scenario-specific fine-tuning, leaving the base transformer untouched.

    Returns
    -------
    list of :class:`torch.nn.Parameter`
    """
    return [p for p in self.card_encoder.parameters()]

base_parameters()

Return all non-adapter (base model) parameters.

Source code in models/wfm1.py
def base_parameters(self) -> List[nn.Parameter]:
    """Return every parameter that is NOT part of the FiLM adapter."""
    # Compare by object identity: tensor equality is ambiguous for params.
    excluded = set(map(id, self.adapter_parameters()))
    return [p for p in self.parameters() if id(p) not in excluded]

evaluate_actions(tokens, actions, pad_mask=None, echelon=ECHELON_BATTALION, card=None, card_vec=None)

Evaluate log-probs, entropy, and values for given actions.

Used during PPO update.

Returns:

Name Type Description
log_probs torch.Tensor — shape ``(B,)``
entropy torch.Tensor — shape ``(B,)``
values torch.Tensor — shape ``(B,)``
Source code in models/wfm1.py
def evaluate_actions(
    self,
    tokens: torch.Tensor,
    actions: torch.Tensor,
    pad_mask: Optional[torch.Tensor] = None,
    echelon: int = ECHELON_BATTALION,
    card: Optional[ScenarioCard] = None,
    card_vec: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Score *actions* under the current policy (used in the PPO update).

    Returns
    -------
    log_probs : torch.Tensor — shape ``(B,)``
    entropy   : torch.Tensor — shape ``(B,)``
    values    : torch.Tensor — shape ``(B,)``
    """
    fused = self._encode_single(tokens, echelon, pad_mask, card, card_vec)

    mu = self.actor_head(fused)
    dist = Normal(mu, self.log_std.exp().expand_as(mu))

    return (
        dist.log_prob(actions).sum(dim=-1),
        dist.entropy().sum(dim=-1),
        self.critic_head(fused).squeeze(-1),
    )

finetune_loss(batch)

Compute a supervised fine-tuning loss from a demonstration batch.

The batch dict must contain: * "tokens" — shape (B, N, token_dim) * "actions" — shape (B, action_dim)

Optional keys: * "pad_mask" — shape (B, N) * "echelon" — scalar int (default: ECHELON_BATTALION) * "card_vec" — shape (B, card_raw_dim) or (card_raw_dim,)

Returns:

Name Type Description
loss torch.Tensor — scalar behaviour-cloning MSE loss
Source code in models/wfm1.py
def finetune_loss(
    self,
    batch: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Behaviour-cloning MSE loss on a demonstration batch.

    Required keys:
    * ``"tokens"``   — shape ``(B, N, token_dim)``
    * ``"actions"``  — shape ``(B, action_dim)``

    Optional keys:
    * ``"pad_mask"`` — shape ``(B, N)``
    * ``"echelon"``  — scalar int (default: ``ECHELON_BATTALION``)
    * ``"card_vec"`` — shape ``(B, card_raw_dim)`` or ``(card_raw_dim,)``

    Returns
    -------
    loss : torch.Tensor — scalar behaviour-cloning MSE loss
    """
    echelon = ECHELON_BATTALION if "echelon" not in batch else int(batch["echelon"])
    fused = self._encode_single(
        batch["tokens"],
        echelon,
        batch.get("pad_mask"),
        card_vec=batch.get("card_vec"),
    )
    # MSE between the deterministic policy mean and demonstrated actions.
    return nn.functional.mse_loss(self.actor_head(fused), batch["actions"])

freeze_base()

Freeze all base-model parameters; only the adapter remains trainable.

Source code in models/wfm1.py
def freeze_base(self) -> None:
    """Disable gradients on every base-model parameter (adapter stays trainable)."""
    for param in self.base_parameters():
        param.requires_grad = False

get_value(tokens, pad_mask=None, echelon=ECHELON_BATTALION, card=None, card_vec=None, tokens_per_echelon=None, pad_masks=None)

Compute scalar value estimates.

Returns:

Name Type Description
values torch.Tensor — shape ``(B,)``
Source code in models/wfm1.py
def get_value(
    self,
    tokens: torch.Tensor,
    pad_mask: Optional[torch.Tensor] = None,
    echelon: int = ECHELON_BATTALION,
    card: Optional[ScenarioCard] = None,
    card_vec: Optional[torch.Tensor] = None,
    tokens_per_echelon: Optional[Dict[int, torch.Tensor]] = None,
    pad_masks: Optional[Dict[int, torch.Tensor]] = None,
) -> torch.Tensor:
    """Critic value estimates for a batch of observations.

    Returns
    -------
    values : torch.Tensor — shape ``(B,)``
    """
    if tokens_per_echelon is None:
        fused = self._encode_single(tokens, echelon, pad_mask, card, card_vec)
    else:
        fused = self._encode(
            tokens_per_echelon, pad_masks, card=card, card_vec=card_vec
        )

    # Drop the trailing singleton dim: (B, 1) -> (B,)
    return self.critic_head(fused).squeeze(-1)

load_checkpoint(path, map_location=None, **kwargs) classmethod

Load a WFM-1 checkpoint produced by :meth:save_checkpoint.

Extra keyword arguments override the saved configuration.

Source code in models/wfm1.py
@classmethod
def load_checkpoint(
    cls,
    path: str | Path,
    map_location: Optional[torch.device] = None,
    **kwargs,
) -> "WFM1Policy":
    """Load a WFM-1 checkpoint produced by :meth:`save_checkpoint`.

    Extra keyword arguments override the saved configuration.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources (consider weights_only=True on
    # torch >= 1.13; confirm compatibility before changing).
    ckpt = torch.load(path, map_location=map_location)
    # Saved config first, caller overrides second.
    cfg = {**ckpt["config"], **kwargs}
    policy = cls(**cfg)
    policy.load_state_dict(ckpt["state_dict"])
    return policy

save_checkpoint(path)

Save model state dict to path (.pt).

Source code in models/wfm1.py
def save_checkpoint(self, path: str | Path) -> Path:
    """Save model state dict to *path* (``.pt``).

    The checkpoint bundles the ``state_dict`` with a ``config`` dict that
    :meth:`load_checkpoint` feeds back into the constructor.  Returns the
    resolved output :class:`~pathlib.Path`.
    """
    out = Path(path)
    # Create any missing parent directories so callers can pass fresh paths.
    out.parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            "state_dict": self.state_dict(),
            # Constructor arguments are reconstructed by introspecting
            # submodules rather than being stored at construction time.
            # NOTE(review): ``use_spatial_pe`` is not persisted here, so a
            # model built with a non-default value cannot round-trip through
            # load_checkpoint — confirm and fix upstream.
            "config": {
                "token_dim": self.token_dim,
                "action_dim": self.action_dim,
                "d_model": self.d_model,
                "n_heads": self.echelon_encoders[0].encoder.n_heads,
                "n_echelon_layers": self.echelon_encoders[0].encoder.n_layers,
                "n_cross_layers": self.cross_echelon.transformer.num_layers,
                # Hidden sizes = out_features of every Linear except the
                # final output projection, hence the [:-1] slice.
                "actor_hidden_sizes": tuple(
                    layer.out_features
                    for layer in self.actor_head
                    if isinstance(layer, nn.Linear)
                )[:-1],
                "critic_hidden_sizes": tuple(
                    layer.out_features
                    for layer in self.critic_head
                    if isinstance(layer, nn.Linear)
                )[:-1],
                "dropout": self.dropout,
                "share_echelon_encoders": self.share_echelon_encoders,
                "card_hidden_size": self.card_encoder.mlp[0].out_features,
            },
        },
        out,
    )
    return out

unfreeze_base()

Unfreeze all base-model parameters.

Source code in models/wfm1.py
def unfreeze_base(self) -> None:
    """Re-enable gradients on every parameter of the model (adapter included)."""
    for param in self.parameters():
        param.requires_grad = True

models.wfm1.ScenarioCard dataclass

Metadata descriptor for a training/evaluation scenario.

Used to condition WFM-1 via FiLM modulation. All numeric fields are expected to be pre-normalised to reasonable ranges; the policy encodes them into a conditioning vector via a small MLP.

Attributes:

Name Type Description
map_scale float

Map area normalised to [0, 1] (0 = smallest battalion map, 1 = largest corps map).

echelon_level int

Primary echelon of the scenario. One of :data:ECHELON_BATTALION, :data:ECHELON_BRIGADE, :data:ECHELON_DIVISION, :data:ECHELON_CORPS.

weather_code int

Active weather condition. One of :data:WEATHER_CLEAR, :data:WEATHER_RAIN, :data:WEATHER_FOG, :data:WEATHER_SNOW.

n_blue_units float

Number of blue (friendly) units, normalised to [0, 1] before encoding (divided by max_units; default 64).

n_red_units float

Number of red (enemy) units, similarly normalised.

terrain_type int

Integer terrain identifier. 0 = procedural; 1–4 = GIS sites.

cavalry_fraction float

Fraction of units that are cavalry ([0, 1]).

artillery_fraction float

Fraction of units that are artillery ([0, 1]).

supply_pressure float

Supply depletion index ([0, 1]; 0 = full supply, 1 = exhausted).

time_of_day float

Normalised time in [0, 1] (0 = dawn, 1 = dusk).

max_units float

Normalisation constant for unit counts. Not encoded — used only during the :meth:to_tensor conversion.

Source code in models/wfm1.py
@dataclass
class ScenarioCard:
    """Scenario metadata used to condition WFM-1 through FiLM modulation.

    Numeric fields are expected to already sit in sensible ranges; a small
    MLP in the policy turns the encoded card into FiLM (γ, β) parameters.

    Attributes
    ----------
    map_scale:
        Map area normalised to ``[0, 1]`` (0 = smallest battalion map,
        1 = largest corps map).
    echelon_level:
        Primary echelon of the scenario.  One of :data:`ECHELON_BATTALION`,
        :data:`ECHELON_BRIGADE`, :data:`ECHELON_DIVISION`, :data:`ECHELON_CORPS`.
    weather_code:
        Active weather condition.  One of :data:`WEATHER_CLEAR`,
        :data:`WEATHER_RAIN`, :data:`WEATHER_FOG`, :data:`WEATHER_SNOW`.
    n_blue_units:
        Raw blue (friendly) unit count; divided by ``max_units`` when encoded.
    n_red_units:
        Raw red (enemy) unit count, normalised the same way.
    terrain_type:
        Integer terrain identifier.  0 = procedural; 1–4 = GIS sites.
    cavalry_fraction:
        Fraction of units that are cavalry (``[0, 1]``).
    artillery_fraction:
        Fraction of units that are artillery (``[0, 1]``).
    supply_pressure:
        Supply depletion index (``[0, 1]``; 0 = full supply, 1 = exhausted).
    time_of_day:
        Normalised time in ``[0, 1]`` (0 = dawn, 1 = dusk).
    max_units:
        Normalisation constant for unit counts.  Not encoded — used only
        during the :meth:`to_tensor` conversion.
    """

    map_scale: float = 0.5
    echelon_level: int = ECHELON_BATTALION
    weather_code: int = WEATHER_CLEAR
    n_blue_units: float = 8.0
    n_red_units: float = 8.0
    terrain_type: int = TERRAIN_PROCEDURAL
    cavalry_fraction: float = 0.0
    artillery_fraction: float = 0.0
    supply_pressure: float = 0.0
    time_of_day: float = 0.5
    max_units: float = 64.0

    def to_tensor(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """Serialise the card into a 1-D float tensor of shape ``(_SCENARIO_CARD_RAW_DIM,)``.

        Layout (12 dims):
        * [0]   map_scale                    float [0, 1]
        * [1:5] echelon one-hot              4 dims
        * [5:9] weather one-hot              4 dims
        * [9]   n_blue_units / max_units     float [0, 1]
        * [10]  n_red_units / max_units      float [0, 1]
        * [11]  terrain_type / 4             float [0, 1]

        Unit counts are divided by ``max_units`` here, so callers pass raw
        counts.  ``cavalry_fraction``, ``artillery_fraction``,
        ``supply_pressure`` and ``time_of_day`` are *not* part of this
        12-dim encoding; append them manually for experimental extensions.
        """
        card = torch.zeros(_SCENARIO_CARD_RAW_DIM)
        card[0] = float(self.map_scale)

        # Clamp the categorical ids into range before writing one-hot slots.
        card[1 + min(max(self.echelon_level, 0), _N_ECHELONS - 1)] = 1.0
        card[5 + min(max(self.weather_code, 0), 3)] = 1.0

        denom = max(self.max_units, 1.0)
        card[9] = float(self.n_blue_units) / denom
        card[10] = float(self.n_red_units) / denom
        card[11] = float(self.terrain_type) / 4.0

        return card if device is None else card.to(device)

to_tensor(device=None)

Encode the card as a 1-D float tensor of shape (_SCENARIO_CARD_RAW_DIM,).

Encoding layout (12 dims): * [0] map_scale float [0, 1] * [1:5] echelon one-hot 4 dims * [5:9] weather one-hot 4 dims * [9] n_blue_units / max_units float [0, 1] * [10] n_red_units / max_units float [0, 1] * [11] terrain_type / 4 float [0, 1]

Unit counts (n_blue_units, n_red_units) are normalised internally by dividing by max_units; callers do not need to pre-normalise them.

Note: the remaining floating-point fields (cavalry_fraction, artillery_fraction, supply_pressure, time_of_day) are stored on the dataclass but are not included in this 12-dim vector. They can be appended manually for experimental extensions.

Source code in models/wfm1.py
def to_tensor(self, device: Optional[torch.device] = None) -> torch.Tensor:
    """Encode the card as a 1-D float tensor of shape ``(_SCENARIO_CARD_RAW_DIM,)``.

    Encoding layout (12 dims):
    * [0]   map_scale                    float [0, 1]
    * [1:5] echelon one-hot              4 dims
    * [5:9] weather one-hot              4 dims
    * [9]   n_blue_units / max_units     float [0, 1]
    * [10]  n_red_units / max_units      float [0, 1]
    * [11]  terrain_type / 4             float [0, 1]

    Unit counts (``n_blue_units``, ``n_red_units``) are normalised here by
    dividing by ``max_units``; callers pass raw counts.

    Note: the remaining floating-point fields (``cavalry_fraction``,
    ``artillery_fraction``, ``supply_pressure``, ``time_of_day``) live on
    the dataclass but are **not** part of this 12-dim vector.  They can be
    appended manually for experimental extensions.
    """
    out = torch.zeros(_SCENARIO_CARD_RAW_DIM)
    out[0] = float(self.map_scale)

    # One-hot echelon into slots [1:5]; clamp out-of-range ids to valid ones.
    out[1 + min(max(self.echelon_level, 0), _N_ECHELONS - 1)] = 1.0

    # One-hot weather into slots [5:9]; clamp out-of-range codes to valid ones.
    out[5 + min(max(self.weather_code, 0), 3)] = 1.0

    # Normalise unit counts; guard against max_units == 0.
    denom = max(self.max_units, 1.0)
    out[9] = float(self.n_blue_units) / denom
    out[10] = float(self.n_red_units) / denom
    out[11] = float(self.terrain_type) / 4.0

    return out if device is None else out.to(device)

models.wfm1.EchelonEncoder

Bases: Module

Per-echelon entity encoder based on :class:~models.entity_encoder.EntityEncoder.

Wraps an :class:~models.entity_encoder.EntityEncoder and optionally adds an echelon embedding that is summed into the token embeddings before the transformer layers.

Parameters:

Name Type Description Default
token_dim int

Entity token dimensionality.

ENTITY_TOKEN_DIM
d_model int

Transformer hidden dimension.

128
n_heads int

Number of attention heads.

8
n_layers int

Number of transformer encoder layers.

4
dropout float

Dropout probability in transformer layers.

0.0
use_spatial_pe bool

Whether to add 2-D Fourier positional encoding.

True
n_freq_bands int

Fourier frequency bands for the spatial PE.

8
use_echelon_embedding bool

When True, learn a separate embedding per echelon level that is added to the projected token features before the transformer.

True
Source code in models/wfm1.py
class EchelonEncoder(nn.Module):
    """Per-echelon entity encoder based on :class:`~models.entity_encoder.EntityEncoder`.

    Wraps an :class:`~models.entity_encoder.EntityEncoder` and optionally
    adds an echelon embedding that is summed into the token embeddings before
    the transformer layers.

    Parameters
    ----------
    token_dim:
        Entity token dimensionality.
    d_model:
        Transformer hidden dimension.
    n_heads:
        Number of attention heads.
    n_layers:
        Number of transformer encoder layers.
    dropout:
        Dropout probability in transformer layers.
    use_spatial_pe:
        Whether to add 2-D Fourier positional encoding.
    n_freq_bands:
        Fourier frequency bands for the spatial PE.
    use_echelon_embedding:
        When ``True``, learn a separate embedding per echelon level that is
        added to the projected token features before the transformer.
    """

    def __init__(
        self,
        token_dim: int = ENTITY_TOKEN_DIM,
        d_model: int = 128,
        n_heads: int = 8,
        n_layers: int = 4,
        dropout: float = 0.0,
        use_spatial_pe: bool = True,
        n_freq_bands: int = 8,
        use_echelon_embedding: bool = True,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.use_echelon_embedding = use_echelon_embedding

        self.encoder = EntityEncoder(
            token_dim=token_dim,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            dropout=dropout,
            use_spatial_pe=use_spatial_pe,
            n_freq_bands=n_freq_bands,
        )

        if use_echelon_embedding:
            # Small-std init keeps the echelon bias gentle at training start.
            self.echelon_embed = nn.Embedding(_N_ECHELONS, d_model)
            nn.init.normal_(self.echelon_embed.weight, std=0.02)

    @property
    def output_dim(self) -> int:
        """Dimensionality of the pooled output."""
        return self.d_model

    def forward(
        self,
        tokens: torch.Tensor,
        echelon: int,
        pad_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode entity tokens for a given echelon level.

        Parameters
        ----------
        tokens:
            Entity token tensor of shape ``(B, N, token_dim)``.
        echelon:
            Integer echelon identifier (0–3).
        pad_mask:
            Boolean padding mask ``(B, N)``.  ``True`` = ignored position.

        Returns
        -------
        enc : torch.Tensor — shape ``(B, d_model)``

        Raises
        ------
        ValueError
            If ``echelon`` is out of range while the echelon embedding is
            enabled.
        """
        # Project raw entity tokens into the model dimension.
        x = self.encoder.token_embed(tokens)  # (B, N, d_model)

        # Optional 2-D Fourier positional encoding on entity positions.
        if self.encoder.use_spatial_pe:
            # Function-scope import — presumably avoids a circular import
            # with models.entity_encoder; TODO confirm before hoisting.
            from models.entity_encoder import _SLICE_POSITION
            xy = tokens[..., _SLICE_POSITION]
            x = x + self.encoder.spatial_pe(xy)

        # Echelon embedding: one vector broadcast over the entity sequence.
        if self.use_echelon_embedding:
            echelon_int = int(echelon)
            num_echelons = self.echelon_embed.num_embeddings
            if not (0 <= echelon_int < num_echelons):
                raise ValueError(
                    f"Invalid echelon id {echelon!r}; expected integer in "
                    f"[0, {num_echelons - 1}]."
                )
            ech_idx = torch.tensor(echelon_int, device=tokens.device, dtype=torch.long)
            ech_emb = self.echelon_embed(ech_idx)  # (d_model,)
            x = x + ech_emb.unsqueeze(0).unsqueeze(0)  # (1, 1, d_model)

        # Transformer over the entity sequence.
        x = self.encoder.transformer(x, src_key_padding_mask=pad_mask)

        # Mean-pool over non-padded positions; clamp avoids division by zero
        # when a whole row is padding.
        if pad_mask is not None:
            keep = ~pad_mask  # (B, N)
            n_valid = keep.float().sum(dim=1, keepdim=True).clamp(min=1.0)
            x = (x * keep.unsqueeze(-1).float()).sum(dim=1) / n_valid
        else:
            x = x.mean(dim=1)

        return self.encoder.out_proj(x)  # (B, d_model)

output_dim property

Dimensionality of the pooled output.

forward(tokens, echelon, pad_mask=None)

Encode entity tokens for a given echelon level.

Parameters:

Name Type Description Default
tokens Tensor

Entity token tensor of shape (B, N, token_dim).

required
echelon int

Integer echelon identifier (0–3).

required
pad_mask Optional[Tensor]

Boolean padding mask (B, N). True = ignored position.

None

Returns:

Name Type Description
enc torch.Tensor — shape (B, d_model)
Source code in models/wfm1.py
def forward(
    self,
    tokens: torch.Tensor,
    echelon: int,
    pad_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Encode entity tokens for a given echelon level.

    Parameters
    ----------
    tokens:
        Entity token tensor of shape ``(B, N, token_dim)``.
    echelon:
        Integer echelon identifier (0–3).
    pad_mask:
        Boolean padding mask ``(B, N)``.  ``True`` = ignored position.

    Returns
    -------
    enc : torch.Tensor — shape ``(B, d_model)``

    Raises
    ------
    ValueError
        If ``echelon`` is out of range while the echelon embedding is
        enabled.
    """
    # Project raw entity tokens into the model dimension.
    x = self.encoder.token_embed(tokens)  # (B, N, d_model)

    # Optional 2-D Fourier positional encoding on entity positions.
    if self.encoder.use_spatial_pe:
        # Function-scope import — presumably avoids a circular import
        # with models.entity_encoder; TODO confirm before hoisting.
        from models.entity_encoder import _SLICE_POSITION
        xy = tokens[..., _SLICE_POSITION]
        x = x + self.encoder.spatial_pe(xy)

    # Echelon embedding: one vector broadcast over the entity sequence.
    if self.use_echelon_embedding:
        echelon_int = int(echelon)
        num_echelons = self.echelon_embed.num_embeddings
        if not (0 <= echelon_int < num_echelons):
            raise ValueError(
                f"Invalid echelon id {echelon!r}; expected integer in "
                f"[0, {num_echelons - 1}]."
            )
        ech_idx = torch.tensor(echelon_int, device=tokens.device, dtype=torch.long)
        ech_emb = self.echelon_embed(ech_idx)  # (d_model,)
        x = x + ech_emb.unsqueeze(0).unsqueeze(0)  # (1, 1, d_model)

    # Transformer over the entity sequence.
    x = self.encoder.transformer(x, src_key_padding_mask=pad_mask)

    # Mean-pool over non-padded positions; clamp avoids division by zero
    # when a whole row is padding.
    if pad_mask is not None:
        keep = ~pad_mask  # (B, N)
        n_valid = keep.float().sum(dim=1, keepdim=True).clamp(min=1.0)
        x = (x * keep.unsqueeze(-1).float()).sum(dim=1) / n_valid
    else:
        x = x.mean(dim=1)

    return self.encoder.out_proj(x)  # (B, d_model)

models.wfm1.CrossEchelonTransformer

Bases: Module

Transformer that integrates encodings from multiple echelon levels.

Takes a sequence of echelon encodings (B, E, d_model) where E is the number of active echelons, and applies multi-head self-attention to fuse information across echelon boundaries.

Parameters:

Name Type Description Default
d_model int

Hidden dimension (must match the echelon encoder output).

128
n_heads int

Number of attention heads.

8
n_layers int

Depth of the cross-echelon transformer.

2
dropout float

Dropout probability.

0.0
Source code in models/wfm1.py
class CrossEchelonTransformer(nn.Module):
    """Self-attention module that fuses encodings across echelon levels.

    Consumes a stacked sequence of echelon encodings ``(B, E, d_model)``
    (E = number of active echelons) and lets multi-head attention exchange
    information between echelon boundaries before pooling.

    Parameters
    ----------
    d_model:
        Hidden dimension (must match the echelon encoder output).
    n_heads:
        Number of attention heads.
    n_layers:
        Depth of the cross-echelon transformer.
    dropout:
        Dropout probability.
    """

    def __init__(
        self,
        d_model: int = 128,
        n_heads: int = 8,
        n_layers: int = 2,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.d_model = d_model

        # Learned positional embedding for echelon order (battalion < brigade < …).
        self.echelon_pos_embed = nn.Embedding(_N_ECHELONS, d_model)
        nn.init.normal_(self.echelon_pos_embed.weight, std=0.02)

        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer=layer,
            num_layers=n_layers,
            enable_nested_tensor=False,
        )

        self.out_norm = nn.LayerNorm(d_model)

    def forward(
        self,
        echelon_encs: torch.Tensor,
        echelon_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Fuse multi-echelon encodings into a single vector per batch item.

        Parameters
        ----------
        echelon_encs:
            Stacked echelon encodings of shape ``(B, E, d_model)`` where
            E ≤ 4 is the number of active echelons.
        echelon_ids:
            Integer echelon identifiers of shape ``(E,)`` used to add
            order-aware positional embeddings.

        Returns
        -------
        fused : torch.Tensor — shape ``(B, d_model)``
            Mean-pooled fused representation.
        """
        # Inject echelon-order information so attention can tell levels apart.
        order_emb = self.echelon_pos_embed(echelon_ids)      # (E, d_model)
        fused = self.transformer(echelon_encs + order_emb.unsqueeze(0))
        # Collapse the echelon axis, then normalise the pooled vector.
        return self.out_norm(fused.mean(dim=1))

forward(echelon_encs, echelon_ids)

Fuse multi-echelon encodings.

Parameters:

Name Type Description Default
echelon_encs Tensor

Stacked echelon encodings of shape (B, E, d_model) where E ≤ 4 is the number of active echelons.

required
echelon_ids Tensor

Integer echelon identifiers of shape (E,) used to add order-aware positional embeddings.

required

Returns:

Name Type Description
fused torch.Tensor — shape (B, d_model)

Mean-pooled fused representation.

Source code in models/wfm1.py
def forward(
    self,
    echelon_encs: torch.Tensor,
    echelon_ids: torch.Tensor,
) -> torch.Tensor:
    """Fuse multi-echelon encodings into a single vector per batch item.

    Parameters
    ----------
    echelon_encs:
        Stacked echelon encodings of shape ``(B, E, d_model)`` where
        E ≤ 4 is the number of active echelons.
    echelon_ids:
        Integer echelon identifiers of shape ``(E,)`` used to add
        order-aware positional embeddings.

    Returns
    -------
    fused : torch.Tensor — shape ``(B, d_model)``
        Mean-pooled fused representation.
    """
    # Inject echelon-order information so attention can tell levels apart.
    order_emb = self.echelon_pos_embed(echelon_ids)      # (E, d_model)
    fused = self.transformer(echelon_encs + order_emb.unsqueeze(0))
    # Collapse the echelon axis, then normalise the pooled vector.
    return self.out_norm(fused.mean(dim=1))