Skip to content

Analysis API

The analysis package exposes course-of-action generation and policy saliency analysis as a clean Python API.

Quick-start

from analysis import COAGenerator, generate_coas, SaliencyAnalyzer

# Generate ranked courses of action
coas = generate_coas(env=env, policy=model, n_episodes=20)
for coa in coas:
    print(f"{coa.label}: win_rate={coa.score.win_rate:.1%}")

# Compute saliency for a policy
analyzer = SaliencyAnalyzer(policy=model, env=env)
saliency = analyzer.compute(obs, method="integrated_gradients")
analyzer.plot(saliency)

COA generation

analysis.coa_generator.COAGenerator

Generate and rank candidate Courses of Action via Monte-Carlo rollout.

Parameters:

Name Type Description Default
env BattalionEnv

A BattalionEnv (envs.battalion_env.BattalionEnv) instance. The environment is reset before each rollout; the caller retains ownership and is responsible for closing it.

required
n_rollouts int

Number of Monte-Carlo rollouts per COA (default 20). More rollouts give more stable estimates but take longer.

20
n_coas int

Number of distinct COAs to generate (default 5, maximum is the number of built-in strategy archetypes, i.e. 7).

5
seed Optional[int]

Base random seed for reproducibility. Each COA uses a deterministic derived seed.

None
strategies Optional[Sequence[str]]

Explicit list of strategy labels to use. When None (the default) the first n_coas strategies from STRATEGY_LABELS are used.

None
Source code in analysis/coa_generator.py
class COAGenerator:
    """Generate and rank candidate Courses of Action (COAs) via Monte-Carlo rollout.

    Each candidate COA corresponds to one strategy archetype.  The
    archetype biases a base policy (``_BiasedPolicy``), the biased policy
    is rolled out ``n_rollouts`` times in ``env``, and the rollout
    statistics are aggregated into a composite score used for ranking.

    Parameters
    ----------
    env:
        A :class:`~envs.battalion_env.BattalionEnv` instance.  The
        environment is reset before each rollout; the caller retains
        ownership and is responsible for closing it.
    n_rollouts:
        Monte-Carlo rollouts per COA (default 20).  More rollouts give
        more stable estimates at the cost of runtime.
    n_coas:
        Number of distinct COAs to generate (default 5; capped at the
        number of built-in strategy archetypes, i.e. 7).
    seed:
        Base random seed for reproducibility.  Each COA derives its own
        deterministic seed from it.
    strategies:
        Explicit strategy labels to use.  When ``None`` (the default)
        the first ``n_coas`` entries of :data:`STRATEGY_LABELS` are used.
    """

    def __init__(
        self,
        env: BattalionEnv,
        n_rollouts: int = 20,
        n_coas: int = 5,
        seed: Optional[int] = None,
        strategies: Optional[Sequence[str]] = None,
    ) -> None:
        # Guard clauses: reject out-of-range counts before touching state.
        if n_rollouts < 1:
            raise ValueError(f"n_rollouts must be >= 1, got {n_rollouts}")
        if n_coas < 1:
            raise ValueError(f"n_coas must be >= 1, got {n_coas}")
        if n_coas > len(STRATEGY_LABELS):
            raise ValueError(
                f"n_coas ({n_coas}) exceeds the number of built-in strategy "
                f"archetypes ({len(STRATEGY_LABELS)}).  Pass a custom "
                f"'strategies' list or reduce n_coas."
            )

        self.env = env
        self.n_rollouts = int(n_rollouts)
        self.n_coas = int(n_coas)
        self.seed = seed

        if strategies is None:
            # Default: the first n_coas built-in archetypes, in order.
            self._strategies: List[str] = list(STRATEGY_LABELS[: self.n_coas])
            return

        unknown = [label for label in strategies if label not in _STRATEGY_BIASES]
        if unknown:
            raise ValueError(
                f"Unknown strategy labels: {unknown}. "
                f"Valid: {sorted(_STRATEGY_BIASES)}"
            )
        # dict.fromkeys keeps first-seen order while dropping repeats,
        # guaranteeing the generated COAs are distinct.
        distinct = list(dict.fromkeys(strategies))
        if len(distinct) < self.n_coas:
            raise ValueError(
                "Insufficient distinct strategy labels provided: "
                f"expected at least {self.n_coas}, "
                f"got {len(distinct)} from {list(strategies)!r}"
            )
        self._strategies = distinct[: self.n_coas]

    def generate(
        self,
        policy: Optional[Any] = None,
        deterministic: bool = False,
    ) -> List[CourseOfAction]:
        """Generate and rank candidate COAs.

        Parameters
        ----------
        policy:
            Optional trained policy (e.g. a Stable-Baselines3 ``PPO`` model)
            exposing ``predict(obs, deterministic) -> (action, state)``.
            When ``None`` a random action policy is used (handy for smoke
            tests and scenario exploration without a trained model).
        deterministic:
            Forwarded to the base policy's ``predict`` call.  Ignored when
            *policy* is ``None``.

        Returns
        -------
        list of :class:`CourseOfAction`
            Exactly ``n_coas`` entries, ordered best to worst by
            composite score.
        """
        # Deterministic seeding: the base seed (when given) fans out into
        # one derived seed per COA so repeated calls are reproducible.
        base = 0 if self.seed is None else self.seed * 1000
        candidates: List[CourseOfAction] = []

        for idx, label in enumerate(self._strategies):
            derived_seed = base + idx * 37
            wrapped = _BiasedPolicy(
                base_policy=policy,
                strategy=label,
                rng=np.random.default_rng(derived_seed),
            )

            # Monte-Carlo evaluation: one seeded episode per rollout.
            results = [
                _run_single_rollout(
                    self.env, wrapped,
                    seed=derived_seed + i,
                    deterministic=deterministic,
                )
                for i in range(self.n_rollouts)
            ]

            score, summary = _aggregate_rollouts(results)
            candidates.append(
                CourseOfAction(
                    label=label,
                    rank=0,   # placeholder; real rank assigned below
                    score=score,
                    action_summary=summary,
                    seed=derived_seed,
                )
            )

        # Rank by composite score, best first.
        ranked = sorted(candidates, key=lambda c: c.score.composite, reverse=True)
        for position, coa in enumerate(ranked, start=1):
            coa.rank = position

        return ranked

generate(policy=None, deterministic=False)

Generate and rank candidate COAs.

Parameters:

Name Type Description Default
policy Optional[Any]

Optional trained policy (e.g. a Stable-Baselines3 PPO model). Must expose predict(obs, deterministic) -> (action, state). When None a random action policy is used (useful for smoke tests and scenario exploration without a trained model).

None
deterministic bool

Passed through to the base policy's predict call. Ignored when policy is None.

False

Returns:

Type Description
list of :class:`CourseOfAction`

Ordered from best to worst by composite score. The list has exactly n_coas entries.

Source code in analysis/coa_generator.py
def generate(
    self,
    policy: Optional[Any] = None,
    deterministic: bool = False,
) -> List[CourseOfAction]:
    """Generate and rank candidate COAs.

    Parameters
    ----------
    policy:
        Optional trained policy (e.g. a Stable-Baselines3 ``PPO`` model).
        Must expose ``predict(obs, deterministic) -> (action, state)``.
        When ``None`` a random action policy is used (useful for smoke
        tests and scenario exploration without a trained model).
    deterministic:
        Passed through to the base policy's ``predict`` call.  Ignored
        when *policy* is ``None``.

    Returns
    -------
    list of :class:`CourseOfAction`
        Ordered from best to worst by composite score.  The list has
        exactly ``n_coas`` entries.
    """
    coa_list: List[CourseOfAction] = []

    for coa_idx, strategy in enumerate(self._strategies):
        # Deterministic per-COA seed: derived from the base seed (when
        # given) so repeated calls reproduce identical rollouts.
        coa_seed = (
            (self.seed * 1000 + coa_idx * 37) if self.seed is not None
            else coa_idx * 37
        )
        rng = np.random.default_rng(coa_seed)
        # Bias the base policy (or a random policy when None) towards
        # this strategy archetype.
        biased = _BiasedPolicy(
            base_policy=policy,
            strategy=strategy,
            rng=rng,
        )

        # Monte-Carlo evaluation: n_rollouts episodes, each with its own
        # derived episode seed.
        rollout_results: List[dict] = []
        for rollout_i in range(self.n_rollouts):
            ep_seed = coa_seed + rollout_i
            result = _run_single_rollout(self.env, biased, seed=ep_seed,
                                        deterministic=deterministic)
            rollout_results.append(result)

        # Collapse per-rollout stats into one score + action summary.
        score, action_summary = _aggregate_rollouts(rollout_results)
        coa_list.append(
            CourseOfAction(
                label=strategy,
                rank=0,   # assigned after sorting
                score=score,
                action_summary=action_summary,
                seed=coa_seed,
            )
        )

    # Sort by composite score (descending) and assign ranks.
    coa_list.sort(key=lambda c: c.score.composite, reverse=True)
    for rank, coa in enumerate(coa_list, start=1):
        coa.rank = rank

    return coa_list

analysis.coa_generator.CorpsCOAGenerator

Generate and rank corps-level Courses of Action via Monte-Carlo rollout.

Satisfies the E9.2 requirements: up to 10 COAs generated via generate; COA explanation via explain_coa (≥ 3 key decisions per COA); COA modification and re-evaluation via modify_and_evaluate.

Parameters:

Name Type Description Default
env Any

A CorpsEnv (envs.corps_env.CorpsEnv) instance. The caller retains ownership and is responsible for closing it.

required
n_rollouts int

Number of Monte-Carlo rollouts per COA (default 10).

10
n_coas int

Number of distinct COAs to generate (1–10, default 10).

10
seed Optional[int]

Base random seed for reproducibility.

None
strategies Optional[Sequence[str]]

Explicit list of strategy labels to evaluate. When None the first n_coas labels from CORPS_STRATEGY_LABELS are used.

None
Source code in analysis/coa_generator.py
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
class CorpsCOAGenerator:
    """Generate and rank corps-level Courses of Action via Monte-Carlo rollout.

    Satisfies the E9.2 requirements:
    * Up to 10 COAs generated via :meth:`generate`.
    * COA explanation via :meth:`explain_coa` (≥ 3 key decisions per COA).
    * COA modification and re-evaluation via :meth:`modify_and_evaluate`.

    Parameters
    ----------
    env:
        A :class:`~envs.corps_env.CorpsEnv` instance.  The caller retains
        ownership and is responsible for closing it.
    n_rollouts:
        Number of Monte-Carlo rollouts per COA (default 10).
    n_coas:
        Number of distinct COAs to generate (1–10, default 10).
    seed:
        Base random seed for reproducibility.
    strategies:
        Explicit list of strategy labels to evaluate.  When ``None`` the
        first ``n_coas`` labels from :data:`CORPS_STRATEGY_LABELS` are used.
    """

    def __init__(
        self,
        env: Any,
        n_rollouts: int = 10,
        n_coas: int = 10,
        seed: Optional[int] = None,
        strategies: Optional[Sequence[str]] = None,
    ) -> None:
        if n_rollouts < 1:
            raise ValueError(f"n_rollouts must be >= 1, got {n_rollouts}")
        if n_coas < 1:
            raise ValueError(f"n_coas must be >= 1, got {n_coas}")
        if n_coas > len(CORPS_STRATEGY_LABELS):
            raise ValueError(
                f"n_coas ({n_coas}) exceeds the number of built-in corps strategy "
                f"archetypes ({len(CORPS_STRATEGY_LABELS)}).  Pass a custom "
                f"'strategies' list or reduce n_coas."
            )
        self.env = env
        self.n_rollouts = int(n_rollouts)
        self.n_coas = int(n_coas)
        self.seed = seed

        if strategies is not None:
            invalid = [s for s in strategies if s not in _CORPS_STRATEGY_PATTERNS]
            if invalid:
                raise ValueError(
                    f"Unknown corps strategy labels: {invalid}.  "
                    f"Valid: {sorted(_CORPS_STRATEGY_PATTERNS)}"
                )
            # Deduplicate while preserving order so COAs stay distinct.
            unique = list(dict.fromkeys(strategies))
            if len(unique) < self.n_coas:
                raise ValueError(
                    f"Insufficient distinct strategy labels: expected >= {self.n_coas}, "
                    f"got {len(unique)}"
                )
            self._strategies: List[str] = unique[: self.n_coas]
        else:
            self._strategies = list(CORPS_STRATEGY_LABELS[: self.n_coas])

        # Cache env properties needed for policy construction.
        self._n_divisions: int = int(getattr(env, "n_divisions", 3))
        self._n_corps_options: int = int(getattr(env, "n_corps_options", 6))

        # Store raw rollout results per strategy for explanation purposes.
        self._last_rollout_results: Dict[str, List[dict]] = {}

    def generate(
        self,
        policy: Optional[Any] = None,
        deterministic: bool = False,
    ) -> List[CorpsCourseOfAction]:
        """Generate and rank candidate corps-level COAs.

        Parameters
        ----------
        policy:
            Optional trained policy with
            ``predict(obs, deterministic) -> (action, state)``.  When
            ``None`` a strategy-biased random policy is used.
        deterministic:
            Passed through to the policy's ``predict`` call (only relevant
            for learned base policies).

        Returns
        -------
        list of :class:`CorpsCourseOfAction`
            Ordered best to worst by composite score.
        """
        coa_list: List[CorpsCourseOfAction] = []
        self._last_rollout_results = {}

        for coa_idx, strategy in enumerate(self._strategies):
            # Deterministic per-COA seed derived from the base seed.
            coa_seed = (
                (self.seed * 1000 + coa_idx * 37) if self.seed is not None
                else coa_idx * 37
            )
            rng = np.random.default_rng(coa_seed)
            corps_policy = _CorpsStrategyPolicy(
                n_divisions=self._n_divisions,
                n_corps_options=self._n_corps_options,
                strategy=strategy,
                rng=rng,
            )

            # If a trained base policy is supplied, wrap it with override
            # capability rather than ignoring it.
            if policy is not None:
                corps_policy = _WrappedCorpsPolicy(
                    base_policy=policy,
                    strategy_policy=corps_policy,
                    n_divisions=self._n_divisions,
                    n_corps_options=self._n_corps_options,
                    bias_strength=_CORPS_BIAS_STRENGTH,
                    rng=rng,
                )

            rollout_results: List[dict] = []
            for rollout_i in range(self.n_rollouts):
                ep_seed = coa_seed + rollout_i
                result = _run_corps_rollout(
                    self.env, corps_policy, seed=ep_seed,
                    deterministic=deterministic,
                )
                rollout_results.append(result)

            # Keep raw rollouts so explain_coa can analyse them later.
            self._last_rollout_results[strategy] = rollout_results

            score, action_summary = _aggregate_corps_rollouts(
                rollout_results, self._n_divisions, self._n_corps_options
            )
            coa_list.append(
                CorpsCourseOfAction(
                    label=strategy,
                    rank=0,
                    score=score,
                    action_summary=action_summary,
                    seed=coa_seed,
                )
            )

        coa_list.sort(key=lambda c: c.score.composite, reverse=True)
        for rank, coa in enumerate(coa_list, start=1):
            coa.rank = rank

        return coa_list

    def explain_coa(self, coa: CorpsCourseOfAction) -> COAExplanation:
        """Explain the key decisions that drive a COA's outcome.

        Analyses the stored rollout results for ``coa.label`` and returns
        a :class:`COAExplanation` with ≥ 3 key decisions.

        Parameters
        ----------
        coa:
            A :class:`CorpsCourseOfAction` previously returned by
            :meth:`generate`.

        Returns
        -------
        :class:`COAExplanation`
        """
        import ast  # stdlib; hoisted here (was re-imported inside a loop)

        results = self._last_rollout_results.get(coa.label, [])
        if not results:
            # No cached rollouts (e.g. after pickle/restore) — fall back
            # to a score-only explanation; nothing is re-simulated.
            return self._explain_from_scratch(coa)

        wins = [r for r in results if r["outcome"] == 1]

        cmd_names = [_CMD_NAMES.get(c, str(c)) for c in range(self._n_corps_options)]

        # ── Command frequency (overall) ───────────────────────────────────
        cmd_freq: Dict[str, float] = {name: 0.0 for name in cmd_names}
        total_cmds = 0
        for r in results:
            arr = r["actions"]
            if len(arr) == 0:
                continue
            for c_idx, cname in enumerate(cmd_names):
                cmd_freq[cname] += float(np.sum(arr == c_idx))
            total_cmds += arr.size
        if total_cmds > 0:
            cmd_freq = {k: round(v / total_cmds, 4) for k, v in cmd_freq.items()}

        # ── Objective reward timeline per quartile ─────────────────────────
        obj_timeline: Dict[str, float] = {"q1": 0.0, "q2": 0.0, "q3": 0.0, "q4": 0.0}
        for qi, qkey in enumerate(["q1", "q2", "q3", "q4"], start=1):
            vals: List[float] = []
            for r in results:
                arr = r["objective_rewards"]
                T = len(arr)
                if T == 0:
                    continue
                q = max(T // 4, 1)
                slices = [arr[:q], arr[q: 2*q], arr[2*q: 3*q], arr[3*q:]]
                sl = slices[qi - 1]
                if len(sl) == 0:
                    sl = arr
                vals.append(float(np.sum(sl)))
            obj_timeline[qkey] = round(float(np.mean(vals)), 4) if vals else 0.0

        # ── Winning command patterns ───────────────────────────────────────
        # For each winning rollout, extract the most frequent command per div.
        def _dominant_commands(r: dict) -> List[str]:
            arr = r["actions"]
            if len(arr) == 0:
                return []
            cmds: List[str] = []
            for div_i in range(self._n_divisions):
                col = arr[:, div_i]
                dom = int(np.bincount(col, minlength=self._n_corps_options).argmax())
                cmds.append(_CMD_NAMES.get(dom, str(dom)))
            return cmds

        pattern_counts: Dict[str, int] = {}
        for r in wins:
            pat = tuple(_dominant_commands(r))
            key = str(list(pat))
            pattern_counts[key] = pattern_counts.get(key, 0) + 1

        # Sort by frequency; return top-3 as lists.
        sorted_patterns = sorted(pattern_counts.items(), key=lambda x: x[1], reverse=True)
        winning_patterns: List[List[str]] = []
        for key, _cnt in sorted_patterns[:3]:
            try:
                winning_patterns.append(ast.literal_eval(key))
            except (ValueError, SyntaxError):
                winning_patterns.append([key])

        # ── Key decisions ─────────────────────────────────────────────────
        key_decisions: List[str] = []

        # Decision 1: dominant strategy command and its impact on win rate.
        dominant_cmd = max(cmd_freq, key=lambda k: cmd_freq[k])
        key_decisions.append(
            f"Issuing '{dominant_cmd}' most frequently ({cmd_freq[dominant_cmd]*100:.1f}% "
            f"of all orders) is the defining action of this COA."
        )

        # Decision 2: phase where objectives are maximally gained.
        best_phase = max(obj_timeline, key=lambda k: obj_timeline[k])
        key_decisions.append(
            f"Highest objective gain occurs in {best_phase} (score: "
            f"{obj_timeline[best_phase]:.3f}), suggesting the critical push "
            f"happens in the {'first' if best_phase == 'q1' else 'second' if best_phase == 'q2' else 'third' if best_phase == 'q3' else 'final'} quarter of the episode."
        )

        # Decision 3: casualty trade-off.
        blue_cas = coa.score.blue_casualties
        red_cas = coa.score.red_casualties
        trade_off = "favourable" if red_cas > blue_cas else "costly"
        key_decisions.append(
            f"Casualty trade-off is {trade_off}: Blue loses "
            f"{blue_cas*100:.1f}% vs Red loses {red_cas*100:.1f}% of units."
        )

        # Decision 4: supply impact.
        key_decisions.append(
            f"Mean Blue supply efficiency is {coa.score.supply_efficiency*100:.1f}%; "
            f"{'supply is well maintained — sustaining this COA is feasible.' if coa.score.supply_efficiency >= 0.5 else 'supply is degraded — this COA strains logistics.'}"
        )

        # Decision 5: win-rate context.
        win_pct = coa.score.win_rate * 100
        key_decisions.append(
            f"This COA wins {win_pct:.1f}% of rollouts "
            f"({'high' if win_pct >= 60 else 'moderate' if win_pct >= 40 else 'low'} reliability)."
        )

        return COAExplanation(
            coa_label=coa.label,
            key_decisions=key_decisions,
            command_frequency=cmd_freq,
            winning_patterns=winning_patterns,
            objective_timeline=obj_timeline,
        )

    def _explain_from_scratch(self, coa: CorpsCourseOfAction) -> COAExplanation:
        """Fallback explanation when rollout results are unavailable."""
        return COAExplanation(
            coa_label=coa.label,
            key_decisions=[
                f"Strategy '{coa.label}' achieves composite score {coa.score.composite:.4f}.",
                f"Win rate: {coa.score.win_rate*100:.1f}%, "
                f"casualty efficiency: {(coa.score.red_casualties - coa.score.blue_casualties + 1)/2*100:.1f}%.",
                f"Objective completion: {coa.score.objective_completion*100:.1f}%, "
                f"supply efficiency: {coa.score.supply_efficiency*100:.1f}%.",
            ],
            command_frequency={},
            winning_patterns=[],
            objective_timeline={"q1": 0.0, "q2": 0.0, "q3": 0.0, "q4": 0.0},
        )

    def modify_and_evaluate(
        self,
        coa: CorpsCourseOfAction,
        modification: COAModification,
    ) -> CorpsCourseOfAction:
        """Apply user modifications to a COA and re-simulate it.

        Parameters
        ----------
        coa:
            The original :class:`CorpsCourseOfAction` to modify.
        modification:
            A :class:`COAModification` describing the changes.

        Returns
        -------
        A new :class:`CorpsCourseOfAction` with updated score and the
        modified strategy label (or original label if no override).
        """
        strategy = modification.strategy_override or coa.label
        if strategy not in _CORPS_STRATEGY_PATTERNS:
            raise ValueError(
                f"Unknown strategy override '{strategy}'.  "
                f"Valid: {sorted(_CORPS_STRATEGY_PATTERNS)}"
            )
        raw_n_rollouts = modification.n_rollouts
        if raw_n_rollouts is None:
            n_rollouts = self.n_rollouts
        else:
            if raw_n_rollouts < 1:
                raise ValueError(
                    f"n_rollouts must be at least 1, got {raw_n_rollouts!r}"
                )
            n_rollouts = raw_n_rollouts
        overrides = modification.division_command_overrides or {}

        rng = np.random.default_rng(coa.seed + 999)  # distinct seed from original
        corps_policy = _CorpsStrategyPolicy(
            n_divisions=self._n_divisions,
            n_corps_options=self._n_corps_options,
            strategy=strategy,
            rng=rng,
            division_command_overrides=overrides,
        )

        rollout_results: List[dict] = []
        for rollout_i in range(n_rollouts):
            ep_seed = coa.seed + 999 + rollout_i
            result = _run_corps_rollout(
                self.env, corps_policy, seed=ep_seed, deterministic=False
            )
            rollout_results.append(result)

        # Store for explain_coa use.
        self._last_rollout_results[strategy] = rollout_results

        score, action_summary = _aggregate_corps_rollouts(
            rollout_results, self._n_divisions, self._n_corps_options
        )
        return CorpsCourseOfAction(
            label=strategy,
            rank=0,  # caller should re-rank if needed
            score=score,
            action_summary=action_summary,
            seed=coa.seed + 999,
        )

explain_coa(coa)

Explain the key decisions that drive a COA's outcome.

Analyses the stored rollout results for coa.label and returns a COAExplanation with ≥ 3 key decisions.

Parameters:

Name Type Description Default
coa CorpsCourseOfAction

A CorpsCourseOfAction previously returned by generate.

required

Returns:

Type Description
COAExplanation
Source code in analysis/coa_generator.py
def explain_coa(self, coa: CorpsCourseOfAction) -> COAExplanation:
    """Explain the key decisions that drive a COA's outcome.

    Analyses the stored rollout results for ``coa.label`` and returns
    a :class:`COAExplanation` with ≥ 3 key decisions.

    Parameters
    ----------
    coa:
        A :class:`CorpsCourseOfAction` previously returned by
        :meth:`generate`.

    Returns
    -------
    :class:`COAExplanation`
    """
    import ast  # stdlib; hoisted here (was re-imported inside a loop)

    results = self._last_rollout_results.get(coa.label, [])
    if not results:
        # No cached rollouts (e.g. after pickle/restore) — fall back to
        # a score-only explanation; nothing is re-simulated.
        return self._explain_from_scratch(coa)

    wins = [r for r in results if r["outcome"] == 1]

    cmd_names = [_CMD_NAMES.get(c, str(c)) for c in range(self._n_corps_options)]

    # ── Command frequency (overall) ───────────────────────────────────
    cmd_freq: Dict[str, float] = {name: 0.0 for name in cmd_names}
    total_cmds = 0
    for r in results:
        arr = r["actions"]
        if len(arr) == 0:
            continue
        for c_idx, cname in enumerate(cmd_names):
            cmd_freq[cname] += float(np.sum(arr == c_idx))
        total_cmds += arr.size
    if total_cmds > 0:
        cmd_freq = {k: round(v / total_cmds, 4) for k, v in cmd_freq.items()}

    # ── Objective reward timeline per quartile ─────────────────────────
    obj_timeline: Dict[str, float] = {"q1": 0.0, "q2": 0.0, "q3": 0.0, "q4": 0.0}
    for qi, qkey in enumerate(["q1", "q2", "q3", "q4"], start=1):
        vals: List[float] = []
        for r in results:
            arr = r["objective_rewards"]
            T = len(arr)
            if T == 0:
                continue
            q = max(T // 4, 1)
            slices = [arr[:q], arr[q: 2*q], arr[2*q: 3*q], arr[3*q:]]
            sl = slices[qi - 1]
            if len(sl) == 0:
                sl = arr
            vals.append(float(np.sum(sl)))
        obj_timeline[qkey] = round(float(np.mean(vals)), 4) if vals else 0.0

    # ── Winning command patterns ───────────────────────────────────────
    # For each winning rollout, extract the most frequent command per div.
    def _dominant_commands(r: dict) -> List[str]:
        arr = r["actions"]
        if len(arr) == 0:
            return []
        cmds: List[str] = []
        for div_i in range(self._n_divisions):
            col = arr[:, div_i]
            dom = int(np.bincount(col, minlength=self._n_corps_options).argmax())
            cmds.append(_CMD_NAMES.get(dom, str(dom)))
        return cmds

    pattern_counts: Dict[str, int] = {}
    for r in wins:
        pat = tuple(_dominant_commands(r))
        key = str(list(pat))
        pattern_counts[key] = pattern_counts.get(key, 0) + 1

    # Sort by frequency; return top-3 as lists.
    sorted_patterns = sorted(pattern_counts.items(), key=lambda x: x[1], reverse=True)
    winning_patterns: List[List[str]] = []
    for key, _cnt in sorted_patterns[:3]:
        try:
            winning_patterns.append(ast.literal_eval(key))
        except (ValueError, SyntaxError):
            winning_patterns.append([key])

    # ── Key decisions ─────────────────────────────────────────────────
    key_decisions: List[str] = []

    # Decision 1: dominant strategy command and its impact on win rate.
    dominant_cmd = max(cmd_freq, key=lambda k: cmd_freq[k])
    key_decisions.append(
        f"Issuing '{dominant_cmd}' most frequently ({cmd_freq[dominant_cmd]*100:.1f}% "
        f"of all orders) is the defining action of this COA."
    )

    # Decision 2: phase where objectives are maximally gained.
    best_phase = max(obj_timeline, key=lambda k: obj_timeline[k])
    key_decisions.append(
        f"Highest objective gain occurs in {best_phase} (score: "
        f"{obj_timeline[best_phase]:.3f}), suggesting the critical push "
        f"happens in the {'first' if best_phase == 'q1' else 'second' if best_phase == 'q2' else 'third' if best_phase == 'q3' else 'final'} quarter of the episode."
    )

    # Decision 3: casualty trade-off.
    blue_cas = coa.score.blue_casualties
    red_cas = coa.score.red_casualties
    trade_off = "favourable" if red_cas > blue_cas else "costly"
    key_decisions.append(
        f"Casualty trade-off is {trade_off}: Blue loses "
        f"{blue_cas*100:.1f}% vs Red loses {red_cas*100:.1f}% of units."
    )

    # Decision 4: supply impact.
    key_decisions.append(
        f"Mean Blue supply efficiency is {coa.score.supply_efficiency*100:.1f}%; "
        f"{'supply is well maintained — sustaining this COA is feasible.' if coa.score.supply_efficiency >= 0.5 else 'supply is degraded — this COA strains logistics.'}"
    )

    # Decision 5: win-rate context.
    win_pct = coa.score.win_rate * 100
    key_decisions.append(
        f"This COA wins {win_pct:.1f}% of rollouts "
        f"({'high' if win_pct >= 60 else 'moderate' if win_pct >= 40 else 'low'} reliability)."
    )

    return COAExplanation(
        coa_label=coa.label,
        key_decisions=key_decisions,
        command_frequency=cmd_freq,
        winning_patterns=winning_patterns,
        objective_timeline=obj_timeline,
    )

generate(policy=None, deterministic=False)

Generate and rank candidate corps-level COAs.

Parameters:

Name Type Description Default
policy Optional[Any]

Optional trained policy with predict(obs, deterministic) -> (action, state). When None a strategy-biased random policy is used.

None
deterministic bool

Passed through to the policy's predict call (only relevant for learned base policies).

False

Returns:

Type Description
list of :class:`CorpsCourseOfAction`

Ordered best to worst by composite score.

Source code in analysis/coa_generator.py
def generate(
    self,
    policy: Optional[Any] = None,
    deterministic: bool = False,
) -> List[CorpsCourseOfAction]:
    """Build and rank candidate corps-level COAs.

    Parameters
    ----------
    policy:
        Optional trained policy exposing
        ``predict(obs, deterministic) -> (action, state)``.  When
        ``None`` a strategy-biased random policy drives the rollouts.
    deterministic:
        Forwarded to the policy's ``predict`` call (only meaningful for
        learned base policies).

    Returns
    -------
    list of :class:`CorpsCourseOfAction`
        Sorted best-first by composite score.
    """
    ranked: List[CorpsCourseOfAction] = []
    self._last_rollout_results = {}

    for idx, label in enumerate(self._strategies):
        # Derive a per-COA seed deterministically from the base seed.
        if self.seed is None:
            base_seed = idx * 37
        else:
            base_seed = self.seed * 1000 + idx * 37
        rng = np.random.default_rng(base_seed)

        driver = _CorpsStrategyPolicy(
            n_divisions=self._n_divisions,
            n_corps_options=self._n_corps_options,
            strategy=label,
            rng=rng,
        )
        # A supplied trained policy is wrapped, not discarded, so the
        # strategy bias can still steer its decisions.
        if policy is not None:
            driver = _WrappedCorpsPolicy(
                base_policy=policy,
                strategy_policy=driver,
                n_divisions=self._n_divisions,
                n_corps_options=self._n_corps_options,
                bias_strength=_CORPS_BIAS_STRENGTH,
                rng=rng,
            )

        results = [
            _run_corps_rollout(
                self.env, driver, seed=base_seed + i,
                deterministic=deterministic,
            )
            for i in range(self.n_rollouts)
        ]
        self._last_rollout_results[label] = results

        score, summary = _aggregate_corps_rollouts(
            results, self._n_divisions, self._n_corps_options
        )
        ranked.append(
            CorpsCourseOfAction(
                label=label,
                rank=0,
                score=score,
                action_summary=summary,
                seed=base_seed,
            )
        )

    # Best composite score first, then assign 1-based ranks.
    ranked.sort(key=lambda c: c.score.composite, reverse=True)
    for position, coa in enumerate(ranked, start=1):
        coa.rank = position

    return ranked

modify_and_evaluate(coa, modification)

Apply user modifications to a COA and re-simulate it.

Parameters:

Name Type Description Default
coa CorpsCourseOfAction

The original :class:CorpsCourseOfAction to modify.

required
modification COAModification

A :class:COAModification describing the changes.

required

Returns:

Type Description
A new :class:`CorpsCourseOfAction` with updated score and the
modified strategy label (or original label if no override).
Source code in analysis/coa_generator.py
def modify_and_evaluate(
    self,
    coa: CorpsCourseOfAction,
    modification: COAModification,
) -> CorpsCourseOfAction:
    """Re-simulate *coa* with the user's modifications applied.

    Parameters
    ----------
    coa:
        The original :class:`CorpsCourseOfAction` to modify.
    modification:
        A :class:`COAModification` describing the changes.

    Returns
    -------
    A new :class:`CorpsCourseOfAction` with updated score and the
    modified strategy label (or original label if no override).
    """
    label = modification.strategy_override or coa.label
    if label not in _CORPS_STRATEGY_PATTERNS:
        raise ValueError(
            f"Unknown strategy override '{label}'.  "
            f"Valid: {sorted(_CORPS_STRATEGY_PATTERNS)}"
        )

    requested = modification.n_rollouts
    if requested is None:
        rollout_count = self.n_rollouts
    elif requested < 1:
        raise ValueError(
            f"n_rollouts must be at least 1, got {requested!r}"
        )
    else:
        rollout_count = requested

    # Offset by 999 so re-evaluation never reuses the original seeds.
    mod_seed = coa.seed + 999
    rng = np.random.default_rng(mod_seed)
    driver = _CorpsStrategyPolicy(
        n_divisions=self._n_divisions,
        n_corps_options=self._n_corps_options,
        strategy=label,
        rng=rng,
        division_command_overrides=(
            modification.division_command_overrides or {}
        ),
    )

    results = [
        _run_corps_rollout(
            self.env, driver, seed=mod_seed + i, deterministic=False
        )
        for i in range(rollout_count)
    ]
    # Cache so explain_coa can inspect the fresh rollouts.
    self._last_rollout_results[label] = results

    score, summary = _aggregate_corps_rollouts(
        results, self._n_divisions, self._n_corps_options
    )
    return CorpsCourseOfAction(
        label=label,
        rank=0,  # caller should re-rank if needed
        score=score,
        action_summary=summary,
        seed=mod_seed,
    )

analysis.coa_generator.COAScore dataclass

Scalar metrics summarising the outcomes of one COA's Monte-Carlo rollouts.

Attributes:

Name Type Description
win_rate float

Fraction of rollouts won by Blue (0–1).

draw_rate float

Fraction of rollouts that ended as a draw (0–1).

loss_rate float

Fraction of rollouts lost by Blue (0–1).

blue_casualties float

Mean normalised Blue strength loss across rollouts (0–1). Higher means more Blue casualties; 0 = no damage taken.

red_casualties float

Mean normalised Red strength loss across rollouts (0–1). Higher means Blue dealt more damage to Red.

terrain_control float

Mean fraction of steps in which Blue held terrain advantage (closer to map centre than Red), averaged across rollouts (0–1).

composite float

Weighted composite score used for ranking (higher is better).

n_rollouts int

Number of rollouts used to compute these statistics.

Source code in analysis/coa_generator.py
@dataclasses.dataclass
class COAScore:
    """Aggregate outcome statistics for one COA's Monte-Carlo rollouts.

    Attributes
    ----------
    win_rate:
        Fraction of rollouts Blue won (0–1).
    draw_rate:
        Fraction of rollouts ending in a draw (0–1).
    loss_rate:
        Fraction of rollouts Blue lost (0–1).
    blue_casualties:
        Mean normalised Blue strength *loss* over all rollouts (0–1);
        higher means heavier Blue casualties, 0 = no damage taken.
    red_casualties:
        Mean normalised Red strength *loss* over all rollouts (0–1);
        higher means Blue inflicted more damage on Red.
    terrain_control:
        Mean per-rollout fraction of steps in which Blue held the
        terrain advantage (closer to map centre than Red), in [0, 1].
    composite:
        Weighted composite used for ranking (higher is better).
    n_rollouts:
        Number of rollouts behind these statistics.
    """

    win_rate: float
    draw_rate: float
    loss_rate: float
    blue_casualties: float
    red_casualties: float
    terrain_control: float
    composite: float
    n_rollouts: int

    def as_dict(self) -> dict:
        """Return these metrics as a plain ``dict``."""
        return dataclasses.asdict(self)

as_dict()

Return a plain dict representation.

Source code in analysis/coa_generator.py
def as_dict(self) -> dict:
    """Return a plain ``dict`` representation.

    Thin wrapper over :func:`dataclasses.asdict`; note the conversion is
    recursive, so any nested dataclass fields would become dicts too.
    """
    return dataclasses.asdict(self)

analysis.coa_generator.CourseOfAction dataclass

A candidate tactical plan with associated outcome predictions.

Attributes:

Name Type Description
label str

Human-readable name of the tactical archetype.

rank int

Rank among all generated COAs (1 = best, higher = worse).

score COAScore

Aggregated outcome statistics from Monte-Carlo rollouts.

action_summary dict

Aggregate action statistics across rollouts: mean move, rotate, and fire per quartile of the episode.

seed int

Base random seed used to initialise this COA's rollouts.

Source code in analysis/coa_generator.py
@dataclasses.dataclass
class CourseOfAction:
    """One candidate tactical plan together with its predicted outcomes.

    Attributes
    ----------
    label:
        Human-readable name of the tactical archetype.
    rank:
        Position among all generated COAs (1 = best, larger = worse).
    score:
        Aggregated outcome statistics from the Monte-Carlo rollouts.
    action_summary:
        Aggregate action statistics across rollouts:
        mean ``move``, ``rotate``, and ``fire`` per quartile of the episode.
    seed:
        Base random seed used to initialise this COA's rollouts.
    """

    label: str
    rank: int
    score: COAScore
    action_summary: dict
    seed: int

    def as_dict(self) -> dict:
        """Return a JSON-serialisable ``dict``."""
        payload = {
            "label": self.label,
            "rank": self.rank,
            "score": self.score.as_dict(),
            "action_summary": self.action_summary,
            "seed": self.seed,
        }
        return payload

as_dict()

Return a JSON-serialisable dict.

Source code in analysis/coa_generator.py
def as_dict(self) -> dict:
    """Serialise this COA into a JSON-friendly ``dict``."""
    payload = dict(
        label=self.label,
        rank=self.rank,
        score=self.score.as_dict(),
        action_summary=self.action_summary,
        seed=self.seed,
    )
    return payload

analysis.coa_generator.generate_coas(env=None, policy=None, n_rollouts=20, n_coas=5, seed=None, strategies=None, env_kwargs=None)

Generate COAs, optionally creating a temporary environment.

This is a thin convenience wrapper around :class:COAGenerator.

Parameters:

Name Type Description Default
env Optional[BattalionEnv]

An existing :class:~envs.battalion_env.BattalionEnv. When None a new environment is created using env_kwargs and closed automatically when generation completes.

None
policy Optional[Any]

Optional trained policy (see :meth:COAGenerator.generate).

None
n_rollouts int

Number of Monte-Carlo rollouts per COA.

20
n_coas int

Number of distinct COAs to generate (1–7).

5
seed Optional[int]

Base random seed.

None
strategies Optional[Sequence[str]]

Explicit ordered list of strategy labels to evaluate.

None
env_kwargs Optional[dict]

Keyword arguments forwarded to :class:BattalionEnv when env is None.

None

Returns:

Type Description
list of :class:`CourseOfAction`, best first.
Source code in analysis/coa_generator.py
def generate_coas(
    env: Optional[BattalionEnv] = None,
    policy: Optional[Any] = None,
    n_rollouts: int = 20,
    n_coas: int = 5,
    seed: Optional[int] = None,
    strategies: Optional[Sequence[str]] = None,
    env_kwargs: Optional[dict] = None,
) -> List[CourseOfAction]:
    """Convenience front-end for :class:`COAGenerator`.

    Creates a throw-away :class:`~envs.battalion_env.BattalionEnv` when
    *env* is ``None`` (configured via *env_kwargs*) and guarantees it is
    closed afterwards; a caller-supplied *env* is left open.

    Parameters
    ----------
    env:
        Existing environment to roll out in, or ``None`` to create one.
    policy:
        Optional trained policy (see :meth:`COAGenerator.generate`).
    n_rollouts:
        Number of Monte-Carlo rollouts per COA.
    n_coas:
        Number of distinct COAs to generate (1–7).
    seed:
        Base random seed.
    strategies:
        Explicit ordered list of strategy labels to evaluate.
    env_kwargs:
        Keyword arguments forwarded to :class:`BattalionEnv` when *env*
        is ``None``.

    Returns
    -------
    list of :class:`CourseOfAction`, best first.
    """
    created_here = env is None
    if created_here:
        active_env = BattalionEnv(**(env_kwargs or {}))
    else:
        active_env = env
    try:
        return COAGenerator(
            env=active_env,
            n_rollouts=n_rollouts,
            n_coas=n_coas,
            seed=seed,
            strategies=strategies,
        ).generate(policy=policy)
    finally:
        # Only close what we created; caller-owned envs stay open.
        if created_here:
            active_env.close()

analysis.coa_generator.generate_corps_coas(env=None, policy=None, n_rollouts=10, n_coas=10, seed=None, strategies=None, env_kwargs=None, explain=False)

Generate corps-level COAs, optionally creating a temporary CorpsEnv.

Parameters:

Name Type Description Default
env Optional[Any]

An existing :class:~envs.corps_env.CorpsEnv. When None a new environment is created using env_kwargs and closed automatically.

None
policy Optional[Any]

Optional trained policy.

None
n_rollouts int

Monte-Carlo rollouts per COA (default 10). 10 COAs × 10 rollouts runs comfortably within the 120 s budget on CPU.

10
n_coas int

Number of COAs to generate (1–10, default 10).

10
seed Optional[int]

Base random seed.

None
strategies Optional[Sequence[str]]

Explicit ordered list of corps strategy labels to evaluate.

None
env_kwargs Optional[dict]

Keyword arguments forwarded to :class:~envs.corps_env.CorpsEnv when env is None.

None
explain bool

If True, populate the explanation field of each :class:CorpsCourseOfAction via :meth:CorpsCOAGenerator.explain_coa.

False

Returns:

Type Description
list of :class:`CorpsCourseOfAction`, best first.
Source code in analysis/coa_generator.py
def generate_corps_coas(
    env: Optional[Any] = None,
    policy: Optional[Any] = None,
    n_rollouts: int = 10,
    n_coas: int = 10,
    seed: Optional[int] = None,
    strategies: Optional[Sequence[str]] = None,
    env_kwargs: Optional[dict] = None,
    explain: bool = False,
) -> List[CorpsCourseOfAction]:
    """Convenience front-end for :class:`CorpsCOAGenerator`.

    When *env* is ``None`` a temporary :class:`~envs.corps_env.CorpsEnv`
    is created from *env_kwargs* and closed on exit; a caller-supplied
    environment is never closed here.

    Parameters
    ----------
    env:
        Existing corps environment, or ``None`` to create one.
    policy:
        Optional trained policy.
    n_rollouts:
        Monte-Carlo rollouts per COA (default 10).  10 COAs × 10 rollouts
        runs comfortably within the 120 s budget on CPU.
    n_coas:
        Number of COAs to generate (1–10, default 10).
    seed:
        Base random seed.
    strategies:
        Explicit ordered list of corps strategy labels to evaluate.
    env_kwargs:
        Keyword arguments forwarded to :class:`~envs.corps_env.CorpsEnv`
        when *env* is ``None``.
    explain:
        When ``True``, fill each COA's ``explanation`` field via
        :meth:`CorpsCOAGenerator.explain_coa`.

    Returns
    -------
    list of :class:`CorpsCourseOfAction`, best first.
    """
    from envs.corps_env import CorpsEnv

    created_here = env is None
    active_env: Any = CorpsEnv(**(env_kwargs or {})) if created_here else env
    try:
        generator = CorpsCOAGenerator(
            env=active_env,
            n_rollouts=n_rollouts,
            n_coas=n_coas,
            seed=seed,
            strategies=strategies,
        )
        ranked = generator.generate(policy=policy)
        if explain:
            for candidate in ranked:
                candidate.explanation = generator.explain_coa(candidate)
        return ranked
    finally:
        if created_here:
            active_env.close()

Saliency analysis

analysis.saliency.SaliencyAnalyzer

Convenience wrapper that bundles all explainability methods.

Parameters:

Name Type Description Default
policy Any

Trained policy. Accepts an SB3 PPO model, an ActorCriticPolicy, a :class:~models.mappo_policy.MAPPOPolicy, or any plain nn.Module.

required
feature_names Optional[Tuple[str, ...]]

Override the default :data:OBSERVATION_FEATURES labels.

None

Examples:

::

analyzer = SaliencyAnalyzer(ppo_model)
obs, _ = env.reset(seed=0)

sal  = analyzer.gradient_saliency(obs)
ig   = analyzer.integrated_gradients(obs)
shap = analyzer.shap_importance(obs)

print(analyzer.top_features(sal, k=3))
fig = analyzer.plot_saliency(sal)
Source code in analysis/saliency.py
class SaliencyAnalyzer:
    """One-stop wrapper around the module's explainability functions.

    Parameters
    ----------
    policy:
        Trained policy.  Accepts an SB3 ``PPO`` model, an
        ``ActorCriticPolicy``, a :class:`~models.mappo_policy.MAPPOPolicy`,
        or any plain ``nn.Module``.
    feature_names:
        Override the default :data:`OBSERVATION_FEATURES` labels.

    Examples
    --------
    ::

        analyzer = SaliencyAnalyzer(ppo_model)
        obs, _ = env.reset(seed=0)

        sal  = analyzer.gradient_saliency(obs)
        ig   = analyzer.integrated_gradients(obs)
        shap = analyzer.shap_importance(obs)

        print(analyzer.top_features(sal, k=3))
        fig = analyzer.plot_saliency(sal)
    """

    def __init__(
        self,
        policy: Any,
        feature_names: Optional[Tuple[str, ...]] = None,
    ) -> None:
        self._policy = policy
        if feature_names is None:
            self.feature_names: Tuple[str, ...] = OBSERVATION_FEATURES
        else:
            self.feature_names = feature_names

    # ── Core methods ──────────────────────────────────────────────────────

    def gradient_saliency(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        *,
        reduce: str = "mean_abs",
    ) -> np.ndarray:
        """Gradient saliency scores; delegates to :func:`compute_gradient_saliency`."""
        return compute_gradient_saliency(self._policy, obs, reduce=reduce)

    def integrated_gradients(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        *,
        baseline: Optional[Union[np.ndarray, torch.Tensor]] = None,
        n_steps: int = 50,
    ) -> np.ndarray:
        """Integrated-gradient attributions; delegates to :func:`compute_integrated_gradients`."""
        return compute_integrated_gradients(
            self._policy, obs, baseline=baseline, n_steps=n_steps
        )

    def shap_importance(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        *,
        background: Optional[Union[np.ndarray, torch.Tensor]] = None,
        n_samples: int = 100,
    ) -> np.ndarray:
        """SHAP feature importance; delegates to :func:`compute_shap_importance`."""
        return compute_shap_importance(
            self._policy, obs, background=background, n_samples=n_samples
        )

    # ── Utilities ─────────────────────────────────────────────────────────

    def top_features(
        self,
        scores: np.ndarray,
        k: int = 3,
    ) -> List[Tuple[str, float]]:
        """Return the *k* highest-scoring features as ``(name, score)`` pairs.

        Parameters
        ----------
        scores:
            1-D importance / saliency array.
        k:
            Number of top features to return.

        Returns
        -------
        list of (feature_name, score) tuples, sorted descending.
        """
        ranked = np.argsort(scores)[::-1]
        return [
            (self.feature_names[idx], float(scores[idx]))
            for idx in ranked[:k]
        ]

    def summary(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        *,
        n_steps: int = 20,
        n_samples: int = 50,
    ) -> Dict[str, np.ndarray]:
        """Run all three importance metrics and collect them in one dict.

        Keys: ``"gradient_saliency"``, ``"integrated_gradients"``,
        ``"shap_importance"``.
        """
        metrics: Dict[str, np.ndarray] = {}
        metrics["gradient_saliency"] = self.gradient_saliency(obs)
        metrics["integrated_gradients"] = self.integrated_gradients(obs, n_steps=n_steps)
        metrics["shap_importance"] = self.shap_importance(obs, n_samples=n_samples)
        return metrics

    # ── Plot helpers ──────────────────────────────────────────────────────

    def plot_saliency(
        self,
        saliency: Optional[np.ndarray] = None,
        obs: Optional[Union[np.ndarray, torch.Tensor]] = None,
        *,
        title: str = "Gradient Saliency",
        **kwargs: Any,
    ) -> "matplotlib.figure.Figure":  # type: ignore[name-defined]
        """Render a saliency map.

        Pass pre-computed ``saliency`` scores, or pass ``obs`` to have
        them computed via :meth:`gradient_saliency`.
        """
        if saliency is None and obs is None:
            raise ValueError("Either saliency or obs must be provided.")
        scores = saliency if saliency is not None else self.gradient_saliency(obs)
        return plot_saliency_map(
            scores,
            feature_names=self.feature_names,
            title=title,
            **kwargs,
        )

    def plot_importance(
        self,
        importances: Optional[np.ndarray] = None,
        obs: Optional[Union[np.ndarray, torch.Tensor]] = None,
        *,
        title: str = "Feature Importance (SHAP)",
        **kwargs: Any,
    ) -> "matplotlib.figure.Figure":  # type: ignore[name-defined]
        """Render a feature-importance bar chart.

        Pass pre-computed ``importances``, or pass ``obs`` to have SHAP
        importance computed via :meth:`shap_importance`.
        """
        if importances is None and obs is None:
            raise ValueError("Either importances or obs must be provided.")
        scores = importances if importances is not None else self.shap_importance(obs)
        return plot_feature_importance(
            scores,
            feature_names=self.feature_names,
            title=title,
            **kwargs,
        )

gradient_saliency(obs, *, reduce='mean_abs')

Return gradient saliency scores. See :func:compute_gradient_saliency.

Source code in analysis/saliency.py
def gradient_saliency(
    self,
    obs: Union[np.ndarray, torch.Tensor],
    *,
    reduce: str = "mean_abs",
) -> np.ndarray:
    """Gradient saliency for *obs*; thin wrapper over :func:`compute_gradient_saliency`."""
    scores = compute_gradient_saliency(self._policy, obs, reduce=reduce)
    return scores

integrated_gradients(obs, *, baseline=None, n_steps=50)

Return integrated gradient attributions. See :func:compute_integrated_gradients.

Source code in analysis/saliency.py
def integrated_gradients(
    self,
    obs: Union[np.ndarray, torch.Tensor],
    *,
    baseline: Optional[Union[np.ndarray, torch.Tensor]] = None,
    n_steps: int = 50,
) -> np.ndarray:
    """Integrated-gradient attributions; thin wrapper over :func:`compute_integrated_gradients`."""
    attributions = compute_integrated_gradients(
        self._policy,
        obs,
        baseline=baseline,
        n_steps=n_steps,
    )
    return attributions

plot_importance(importances=None, obs=None, *, title='Feature Importance (SHAP)', **kwargs)

Plot feature importance bar chart.

Either pass pre-computed importances scores, or pass obs directly to compute SHAP importance automatically.

Source code in analysis/saliency.py
def plot_importance(
    self,
    importances: Optional[np.ndarray] = None,
    obs: Optional[Union[np.ndarray, torch.Tensor]] = None,
    *,
    title: str = "Feature Importance (SHAP)",
    **kwargs: Any,
) -> "matplotlib.figure.Figure":  # type: ignore[name-defined]
    """Render a feature-importance bar chart.

    Pass pre-computed ``importances``, or pass ``obs`` to have SHAP
    importance computed automatically.
    """
    if importances is None and obs is None:
        raise ValueError("Either importances or obs must be provided.")
    scores = importances if importances is not None else self.shap_importance(obs)
    return plot_feature_importance(
        scores,
        feature_names=self.feature_names,
        title=title,
        **kwargs,
    )

plot_saliency(saliency=None, obs=None, *, title='Gradient Saliency', **kwargs)

Plot saliency map.

Either pass pre-computed saliency scores, or pass obs directly to compute them automatically.

Source code in analysis/saliency.py
def plot_saliency(
    self,
    saliency: Optional[np.ndarray] = None,
    obs: Optional[Union[np.ndarray, torch.Tensor]] = None,
    *,
    title: str = "Gradient Saliency",
    **kwargs: Any,
) -> "matplotlib.figure.Figure":  # type: ignore[name-defined]
    """Render a saliency map.

    Pass pre-computed ``saliency`` scores, or pass ``obs`` to have them
    computed automatically.
    """
    if saliency is None and obs is None:
        raise ValueError("Either saliency or obs must be provided.")
    scores = saliency if saliency is not None else self.gradient_saliency(obs)
    return plot_saliency_map(
        scores,
        feature_names=self.feature_names,
        title=title,
        **kwargs,
    )

shap_importance(obs, *, background=None, n_samples=100)

Return SHAP feature importance. See :func:compute_shap_importance.

Source code in analysis/saliency.py
def shap_importance(
    self,
    obs: Union[np.ndarray, torch.Tensor],
    *,
    background: Optional[Union[np.ndarray, torch.Tensor]] = None,
    n_samples: int = 100,
) -> np.ndarray:
    """SHAP feature importance; thin wrapper over :func:`compute_shap_importance`."""
    importance = compute_shap_importance(
        self._policy,
        obs,
        background=background,
        n_samples=n_samples,
    )
    return importance

summary(obs, *, n_steps=20, n_samples=50)

Compute all three importance metrics and return them as a dict.

Keys: "gradient_saliency", "integrated_gradients", "shap_importance".

Source code in analysis/saliency.py
def summary(
    self,
    obs: Union[np.ndarray, torch.Tensor],
    *,
    n_steps: int = 20,
    n_samples: int = 50,
) -> Dict[str, np.ndarray]:
    """Run all three importance metrics on *obs* and collect them in a dict.

    Keys: ``"gradient_saliency"``, ``"integrated_gradients"``,
    ``"shap_importance"``.
    """
    metrics: Dict[str, np.ndarray] = {}
    metrics["gradient_saliency"] = self.gradient_saliency(obs)
    metrics["integrated_gradients"] = self.integrated_gradients(obs, n_steps=n_steps)
    metrics["shap_importance"] = self.shap_importance(obs, n_samples=n_samples)
    return metrics

top_features(scores, k=3)

Return the top-k feature names and their scores, sorted descending.

Parameters:

Name Type Description Default
scores ndarray

1-D importance / saliency array.

required
k int

Number of top features to return.

3

Returns:

Type Description
list of (feature_name, score) tuples
Source code in analysis/saliency.py
def top_features(
    self,
    scores: np.ndarray,
    k: int = 3,
) -> List[Tuple[str, float]]:
    """Return the *k* highest-scoring features as ``(name, score)`` pairs.

    Parameters
    ----------
    scores:
        1-D importance / saliency array.
    k:
        Number of top features to return.

    Returns
    -------
    list of (feature_name, score) tuples, sorted descending.
    """
    ranked = np.argsort(scores)[::-1]
    return [
        (self.feature_names[idx], float(scores[idx]))
        for idx in ranked[:k]
    ]

analysis.saliency.compute_gradient_saliency(policy, obs, *, reduce='mean_abs')

Compute gradient-based saliency scores for each observation dimension.

Back-propagates through the policy network and uses the absolute gradient of the summed action output with respect to each input dimension as an importance proxy.

Parameters:

Name Type Description Default
policy Any

Trained policy. Accepts SB3 PPO, ActorCriticPolicy, or a plain nn.Module.

required
obs Union[ndarray, Tensor]

Observation array of shape (obs_dim,) or (N, obs_dim).

required
reduce str

How to aggregate across the action dimension and batch: - "mean_abs" (default) — mean of absolute gradients. - "max_abs" — max of absolute gradients. - "sum_abs" — sum of absolute gradients.

'mean_abs'

Returns:

Type Description
ndarray

Saliency scores of shape (obs_dim,) (after batch aggregation), or (N, obs_dim) if reduce="none".

Source code in analysis/saliency.py
def compute_gradient_saliency(
    policy: Any,
    obs: Union[np.ndarray, torch.Tensor],
    *,
    reduce: str = "mean_abs",
) -> np.ndarray:
    """Compute gradient-based saliency scores for each observation dimension.

    Back-propagates through the policy network and uses the absolute gradient
    of the summed action output with respect to each input dimension as an
    importance proxy.

    Parameters
    ----------
    policy:
        Trained policy.  Accepts SB3 ``PPO``, ``ActorCriticPolicy``, or a
        plain ``nn.Module``.
    obs:
        Observation array of shape ``(obs_dim,)`` or ``(N, obs_dim)``.
    reduce:
        How to aggregate across the action dimension and batch:
        - ``"mean_abs"`` (default) — mean of absolute gradients.
        - ``"max_abs"`` — max of absolute gradients.
        - ``"sum_abs"`` — sum of absolute gradients.
        - ``"none"`` — no batch aggregation; per-sample absolute gradients.

    Returns
    -------
    np.ndarray
        Saliency scores of shape ``(obs_dim,)`` (after batch aggregation),
        or ``(N, obs_dim)`` if ``reduce="none"``.

    Raises
    ------
    ValueError
        If *reduce* is not one of the modes listed above.
    """
    net = _extract_mlp_network(policy)
    net.eval()

    # Fall back to a CPU dummy when the network has no parameters.
    device = next(net.parameters(), torch.tensor(0)).device
    x = _to_tensor(obs, device=device)
    x = x.requires_grad_(True)

    output = net(x)
    # Scalar target: sum of action outputs over all actions and batch.
    target = output.sum()
    # Use torch.autograd.grad (as compute_integrated_gradients does) so the
    # model's parameter .grad buffers are never written — safe to call
    # during or after training without side effects.
    (grad_t,) = torch.autograd.grad(target, x)
    grad = grad_t.detach().cpu().numpy()  # (N, obs_dim)

    if reduce == "none":
        return np.abs(grad)
    elif reduce == "mean_abs":
        return np.abs(grad).mean(axis=0)
    elif reduce == "max_abs":
        return np.abs(grad).max(axis=0)
    elif reduce == "sum_abs":
        return np.abs(grad).sum(axis=0)
    else:
        raise ValueError(f"Unknown reduce mode: {reduce!r}")

analysis.saliency.compute_integrated_gradients(policy, obs, *, baseline=None, n_steps=50)

Compute integrated gradient attributions.

Follows the method of Sundararajan et al. (2017): accumulate gradients along the straight-line path from baseline to obs, then multiply element-wise by (obs - baseline). The result satisfies the completeness axiom: attributions sum to f(obs) - f(baseline).

Parameters:

Name Type Description Default
policy Any

Trained policy (same supported types as :func:compute_gradient_saliency).

required
obs Union[ndarray, Tensor]

Single observation of shape (obs_dim,) or batch (N, obs_dim).

required
baseline Optional[Union[ndarray, Tensor]]

Reference point. Defaults to the all-zeros observation.

None
n_steps int

Number of interpolation steps along the path (higher → more accurate). Must be >= 1.

50

Returns:

Type Description
ndarray

Attribution scores of shape (obs_dim,) (mean over batch if N>1).

Source code in analysis/saliency.py
def compute_integrated_gradients(
    policy: Any,
    obs: Union[np.ndarray, torch.Tensor],
    *,
    baseline: Optional[Union[np.ndarray, torch.Tensor]] = None,
    n_steps: int = 50,
) -> np.ndarray:
    """Compute integrated-gradient attributions (Sundararajan et al., 2017).

    Gradients are accumulated along the straight-line path from ``baseline``
    to ``obs`` and the mean gradient is multiplied element-wise by
    ``(obs - baseline)``; the result satisfies the *completeness* axiom:
    attributions sum to ``f(obs) - f(baseline)``.

    Parameters
    ----------
    policy:
        Trained policy (same supported types as :func:`compute_gradient_saliency`).
    obs:
        Single observation of shape ``(obs_dim,)`` or batch ``(N, obs_dim)``.
    baseline:
        Reference point.  Defaults to the all-zeros observation.
    n_steps:
        Number of interpolation steps along the path (higher → more accurate).
        Must be >= 1.

    Returns
    -------
    np.ndarray
        Attribution scores of shape ``(obs_dim,)`` (mean over batch if N>1).
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    net = _extract_mlp_network(policy)
    net.eval()

    device = next(net.parameters(), torch.tensor(0)).device
    x = _to_tensor(obs, device=device)  # (N, obs_dim)
    N, obs_dim = x.shape

    if baseline is None:
        base = torch.zeros_like(x)
    else:
        base = _to_tensor(baseline, device=device)
        if base.shape != x.shape:
            base = base.expand_as(x)

    # Gradients are taken via torch.autograd.grad so the model's parameter
    # .grad buffers are never written — safe during/after training.
    grad_sum = torch.zeros_like(x)
    for alpha in np.linspace(0.0, 1.0, n_steps):
        point = (base + alpha * (x - base)).detach().requires_grad_(True)
        scalar = net(point).sum()
        (step_grad,) = torch.autograd.grad(scalar, point)
        grad_sum = grad_sum + step_grad.detach()

    # IG: (obs - baseline) element-wise times the mean path gradient.
    attributions = ((x - base) * grad_sum / n_steps).detach().cpu().numpy()
    return attributions.mean(axis=0)  # (obs_dim,)

analysis.saliency.compute_shap_importance(policy, obs, *, background=None, n_samples=100)

Compute SHAP-based feature importance scores.

Attempts to use the shap library (GradientExplainer for differentiable models). When shap is not installed or the model is not compatible, falls back to a lightweight permutation importance approximation that estimates each feature's marginal contribution by masking it with the background mean.

Parameters:

Name Type Description Default
policy Any

Trained policy.

required
obs Union[ndarray, Tensor]

Observations of shape (obs_dim,) or (N, obs_dim).

required
background Optional[Union[ndarray, Tensor]]

Background dataset for SHAP / permutation baseline. Defaults to the all-zeros observation (single sample).

None
n_samples int

Number of samples used in the permutation fallback.

100

Returns:

Type Description
ndarray

Absolute SHAP values / permutation importances of shape (obs_dim,).

Source code in analysis/saliency.py
def compute_shap_importance(
    policy: Any,
    obs: Union[np.ndarray, torch.Tensor],
    *,
    background: Optional[Union[np.ndarray, torch.Tensor]] = None,
    n_samples: int = 100,
) -> np.ndarray:
    """Estimate per-feature importance via SHAP, with a permutation fallback.

    The preferred path uses ``shap.GradientExplainer`` on the policy's
    extracted MLP.  If the ``shap`` package is missing, or the explainer
    raises for any reason, a permutation-importance approximation is used
    instead (each feature masked with the background mean).

    Parameters
    ----------
    policy:
        Trained policy.
    obs:
        Observations of shape ``(obs_dim,)`` or ``(N, obs_dim)``.
    background:
        Background dataset for SHAP / permutation baseline.  Defaults to
        a single all-zeros observation.
    n_samples:
        Sample count for the permutation fallback.

    Returns
    -------
    np.ndarray
        Absolute SHAP values / permutation importances of shape ``(obs_dim,)``.
    """
    net = _extract_mlp_network(policy)
    net.eval()

    # Move inputs through the project tensor helper, then back to numpy so
    # both the SHAP and the permutation paths see the same arrays.
    device = next(net.parameters(), torch.tensor(0)).device
    obs_np = _to_tensor(obs, device=device).detach().cpu().numpy()  # (N, obs_dim)

    if background is None:
        bg_np = np.zeros((1, obs_np.shape[1]), dtype=np.float32)
    else:
        bg_np = _to_tensor(background, device=device).detach().cpu().numpy()

    # ── Try shap library ──────────────────────────────────────────────────
    try:
        import shap  # type: ignore
    except ImportError:
        shap = None  # silently fall through to the permutation fallback

    if shap is not None:
        try:
            explainer = shap.GradientExplainer(net, _to_tensor(bg_np, device=device))
            values = explainer.shap_values(_to_tensor(obs_np, device=device))
            # ``shap_values`` returns either one array per model output or a
            # single array; reduce both forms to a (N, obs_dim) magnitude map.
            if isinstance(values, list):
                magnitudes = np.mean([np.abs(v) for v in values], axis=0)
            else:
                magnitudes = np.abs(np.array(values))
            return magnitudes.mean(axis=0)  # (obs_dim,)
        except Exception as exc:  # noqa: BLE001
            warnings.warn(
                f"Falling back to permutation importance because SHAP computation "
                f"failed with: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )

    # ── Permutation-based fallback ────────────────────────────────────────
    return _permutation_importance(net, obs_np, bg_np, n_samples=n_samples)

analysis.saliency.plot_saliency_map(saliency, *, feature_names=None, title='Gradient Saliency', normalise=True, figsize=(9.0, 4.0))

Render a horizontal bar chart of saliency scores.

Parameters:

Name Type Description Default
saliency ndarray

1-D array of shape (obs_dim,) with non-negative saliency scores.

required
feature_names Optional[Tuple[str, ...]]

Feature labels for each dimension. Defaults to :data:OBSERVATION_FEATURES.

None
title str

Plot title string.

'Gradient Saliency'
normalise bool

When True (default), divide by the maximum value so all bars lie in [0, 1].

True
figsize Tuple[float, float]

Matplotlib figure size (width, height) in inches.

(9.0, 4.0)

Returns:

Type Description
Figure

The created matplotlib figure containing the bar chart, ready for display or saving.
Source code in analysis/saliency.py
def plot_saliency_map(
    saliency: np.ndarray,
    *,
    feature_names: Optional[Tuple[str, ...]] = None,
    title: str = "Gradient Saliency",
    normalise: bool = True,
    figsize: Tuple[float, float] = (9.0, 4.0),
) -> "matplotlib.figure.Figure":  # type: ignore[name-defined]
    """Render a horizontal bar chart of saliency scores.

    Parameters
    ----------
    saliency:
        1-D array of shape ``(obs_dim,)`` with non-negative saliency scores.
    feature_names:
        Feature labels for each dimension.  Defaults to
        :data:`OBSERVATION_FEATURES`.
    title:
        Plot title string.
    normalise:
        When ``True`` (default), divide by the maximum value so all bars lie
        in [0, 1].
    figsize:
        Matplotlib figure size ``(width, height)`` in inches.

    Returns
    -------
    matplotlib.figure.Figure
    """
    import matplotlib.pyplot as plt

    if feature_names is None:
        feature_names = OBSERVATION_FEATURES

    scores = np.array(saliency, dtype=np.float64)
    if normalise and scores.max() > 0:
        scores = scores / scores.max()

    n = len(scores)
    y = np.arange(n)

    fig, ax = plt.subplots(figsize=figsize)
    bars = ax.barh(y, scores, color="steelblue", edgecolor="white")
    ax.set_yticks(y)
    # NOTE(review): assumes feature_names has at least n entries — a shorter
    # tuple would leave some ticks unlabelled; verify against callers.
    ax.set_yticklabels(feature_names[:n], fontsize=9)
    ax.set_xlabel("Normalised saliency" if normalise else "Saliency")
    ax.set_title(title)
    ax.invert_yaxis()  # highest at top

    # Annotate values.  Text offsets are expressed relative to the data span:
    # the previous hard-coded ``min(val + 0.01, 0.95)`` assumed scores in
    # [0, 1], so with ``normalise=False`` and raw scores > 1 every label was
    # clamped to x=0.95 (far left of the bars).  With span scaling the
    # normalised path (max == 1) behaves exactly as before.
    span = float(scores.max()) if scores.max() > 0 else 1.0
    for bar, val in zip(bars, scores):
        ax.text(
            min(val + 0.01 * span, 0.95 * span),
            bar.get_y() + bar.get_height() / 2,
            f"{val:.3f}",
            va="center",
            ha="left",
            fontsize=7,
            color="black",
        )

    fig.tight_layout()
    return fig

analysis.saliency.plot_feature_importance(importances, *, feature_names=None, title='Feature Importance', top_k=12, figsize=(9.0, 4.5))

Render a horizontal bar chart of feature importances, sorted by value.

Parameters:

Name Type Description Default
importances ndarray

1-D importance array of shape (obs_dim,).

required
feature_names Optional[Tuple[str, ...]]

Feature labels. Defaults to :data:OBSERVATION_FEATURES.

None
title str

Plot title.

'Feature Importance'
top_k int

Show at most this many features (sorted descending).

12
figsize Tuple[float, float]

Matplotlib figure size.

(9.0, 4.5)

Returns:

Type Description
Figure

The created matplotlib figure containing the sorted bar chart, ready for display or saving.
Source code in analysis/saliency.py
def plot_feature_importance(
    importances: np.ndarray,
    *,
    feature_names: Optional[Tuple[str, ...]] = None,
    title: str = "Feature Importance",
    top_k: int = 12,
    figsize: Tuple[float, float] = (9.0, 4.5),
) -> "matplotlib.figure.Figure":  # type: ignore[name-defined]
    """Render a horizontal bar chart of the top-k feature importances.

    Features are sorted by importance (descending) and at most ``top_k``
    of them are shown, each bar annotated with its numeric value.

    Parameters
    ----------
    importances:
        1-D importance array of shape ``(obs_dim,)``.
    feature_names:
        Feature labels.  Defaults to :data:`OBSERVATION_FEATURES`.
    title:
        Plot title.
    top_k:
        Show at most this many features (sorted descending).
    figsize:
        Matplotlib figure size.

    Returns
    -------
    matplotlib.figure.Figure
    """
    import matplotlib.pyplot as plt

    labels = OBSERVATION_FEATURES if feature_names is None else feature_names

    values = np.array(importances, dtype=np.float64)
    k = min(top_k, len(values))

    # Indices of the k largest importances, largest first.
    top_idx = np.argsort(values)[::-1][:k]
    top_vals = values[top_idx]
    top_labels = [labels[i] for i in top_idx]

    fig, ax = plt.subplots(figsize=figsize)
    positions = np.arange(k)
    bars = ax.barh(positions, top_vals, color="darkorange", edgecolor="white")
    ax.set_yticks(positions)
    ax.set_yticklabels(top_labels, fontsize=9)
    ax.set_xlabel("Importance score")
    ax.set_title(title)
    ax.invert_yaxis()  # most important feature on top

    # Write each bar's value just past its right edge.
    for rect, score in zip(bars, top_vals):
        ax.text(
            rect.get_width() * 1.01,
            rect.get_y() + rect.get_height() / 2,
            f"{score:.4f}",
            va="center",
            ha="left",
            fontsize=7,
            color="black",
        )

    fig.tight_layout()
    return fig