Training API¶
The training package exposes the full training workflow as a clean
programmatic Python API. All stable symbols are importable from the
top-level training namespace.
Quick-start¶
```python
from training import train, TrainingConfig

# Minimal run with defaults
model = train(total_timesteps=500_000, n_envs=4, enable_wandb=False)

# Full config object
config = TrainingConfig(
    total_timesteps=1_000_000,
    n_envs=8,
    curriculum_level=3,
    enable_self_play=True,
    wandb_project="my_project",
)
model = train(config)
```
Training runners¶
training.train.TrainingConfig
dataclass
¶
Configuration for a single PPO training run on `BattalionEnv`.
All fields are optional; the defaults match `configs/default.yaml`.
Instances can be passed directly to `train`, or individual fields
can be overridden via `**kwargs` at the call site.
Examples:
Minimal run with defaults:
```python
from training import train, TrainingConfig

model = train(TrainingConfig(total_timesteps=500_000))
```
Override specific fields at call time without constructing a config:
```python
model = train(total_timesteps=200_000, n_envs=4, enable_wandb=False)
```
Source code in training/train.py
training.train.train(config=None, *, extra_callbacks=None, resume=None, **override_kwargs)
¶
Train a PPO policy on `BattalionEnv`.
This is the programmatic entry point for training, fully decoupled from Hydra/YAML so it can be called from any Python script or notebook.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `Optional[TrainingConfig]` | `TrainingConfig` describing the run; defaults are used when `None`. | `None` |
| `extra_callbacks` | `Optional[List[BaseCallback]]` | Additional SB3 `BaseCallback` instances to run during training. | `None` |
| `resume` | `Optional[Union[str, Path]]` | Path to an existing checkpoint to resume training from. | `None` |
| `**override_kwargs` | `Any` | Keyword arguments that override individual `TrainingConfig` fields. | `{}` |
Returns:
| Type | Description |
|---|---|
| `PPO` | The trained model. Periodic checkpoints, the best model, and a manifest are written to `config.checkpoint_dir` during training. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If any configuration value is invalid or an unrecognised override keyword is passed. |
| `FileNotFoundError` | If `resume` points to a path that does not exist on disk. |
Examples:
Quickstart with defaults:
```python
from training import train

model = train(total_timesteps=200_000, n_envs=4, enable_wandb=False)
```
Full config:
```python
from training import train, TrainingConfig

config = TrainingConfig(
    total_timesteps=1_000_000,
    n_envs=8,
    curriculum_level=3,
    enable_self_play=True,
)
model = train(config)
```
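The call-site override behaviour can be pictured with a plain dataclass. This is an illustrative sketch, not the library's actual merge logic; `Cfg`, `merge`, and the field defaults are invented for the example:

```python
from dataclasses import dataclass, fields, replace

@dataclass
class Cfg:
    """Stand-in for TrainingConfig with invented defaults."""
    total_timesteps: int = 500_000
    n_envs: int = 4
    enable_wandb: bool = True

def merge(config=None, **overrides):
    """Apply call-site overrides on top of a config, rejecting unknown keys."""
    cfg = config if config is not None else Cfg()
    valid = {f.name for f in fields(cfg)}
    unknown = set(overrides) - valid
    if unknown:
        raise ValueError(f"unrecognised override(s): {sorted(unknown)}")
    return replace(cfg, **overrides)

cfg = merge(n_envs=8)
print(cfg)  # Cfg(total_timesteps=500000, n_envs=8, enable_wandb=True)
```

The same pattern explains why a bad keyword raises `ValueError` instead of being silently ignored.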
Source code in training/train.py
Callbacks¶
training.train.WandbCallback
¶
Bases: BaseCallback
Logs SB3 training metrics to an active W&B run.
Emits episode-level rollout statistics (mean reward and episode length)
every log_freq environment steps, and policy-update losses (if
available from the SB3 logger) at the end of each rollout.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `log_freq` | `int` | How often (in environment steps) to log rollout statistics. | `1000` |
| `verbose` | `int` | Verbosity level (0 = silent, 1 = info). | `0` |
Source code in training/train.py
training.train.RewardBreakdownCallback
¶
Bases: BaseCallback
Logs per-component reward breakdown to W&B at episode boundaries.
Accumulates reward components from `info` dicts (populated by `BattalionEnv`) across all parallel environments every step and rolls them into per-episode totals when an episode ends. The episode means are logged to W&B every `log_freq` timesteps; any remaining episodes at the end of training are flushed in `_on_training_end()`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `log_freq` | `int` | How often (in environment steps) to flush accumulated episode means to W&B. | `1000` |
| `verbose` | `int` | Verbosity level (0 = silent, 1 = info). | `0` |
Source code in training/train.py
training.train.EloEvalCallback
¶
Bases: BaseCallback
Evaluate the current policy vs scripted opponents and log Elo to W&B.
Every eval_freq environment steps the callback runs n_eval_episodes
episodes against each opponent in opponents using the live model,
updates the :class:~training.elo.EloRegistry, persists it to disk, and
logs per-opponent Elo ratings and win rates to W&B.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `opponents` | `list[str]` | List of opponent identifiers (e.g. `scripted_l3`). | required |
| `n_eval_episodes` | `int` | Number of episodes to run per opponent per evaluation. | required |
| `registry` | `EloRegistry` | `EloRegistry` updated and persisted after each evaluation. | required |
| `agent_name` | `str` | Key used to identify this training run in the registry. | required |
| `eval_freq` | `int` | How often (in environment steps) to trigger evaluation. | required |
| `env_kwargs` | `Optional[dict]` | Keyword arguments forwarded to the `BattalionEnv` constructor. | `None` |
| `seed` | `Optional[int]` | Base random seed for evaluation episodes. | `None` |
| `verbose` | `int` | Verbosity level (0 = silent, 1 = info). | `0` |
Source code in training/train.py
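For background, the standard Elo update that rating registries of this kind build on can be sketched in isolation. This is the textbook formula only; the project's `EloRegistry` may use different constants or bookkeeping:

```python
def elo_update(rating_a, rating_b, score_a, k=32.0):
    """One-game Elo update. score_a is 1.0 for a win, 0.5 for a draw, 0.0 for a loss."""
    # Expected score of A under the logistic Elo model.
    expected_a = 1.0 / (1.0 + 10.0 ** ((rating_b - rating_a) / 400.0))
    delta = k * (score_a - expected_a)
    return rating_a + delta, rating_b - delta

# Equal ratings, A wins: A gains k/2 = 16 points.
print(elo_update(1200.0, 1200.0, 1.0))  # (1216.0, 1184.0)
```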
training.train.ManifestCheckpointCallback
¶
Bases: CheckpointCallback
Checkpoint callback that appends periodic saves to the manifest immediately.
Source code in training/train.py
training.train.ManifestEvalCallback
¶
Bases: EvalCallback
Eval callback that materializes best-model metadata at creation time.
Source code in training/train.py
Evaluation¶
training.evaluate
¶
Evaluate a saved PPO checkpoint against a configurable opponent.
Loads a Stable-Baselines3 PPO model from a `.zip` checkpoint, runs it
against a chosen opponent in `BattalionEnv` for a configurable number of
episodes, and reports the Blue win rate and optional Elo delta to stdout.
Supported opponent identifiers:
- `scripted_l1` … `scripted_l5`: built-in scripted Red opponent at the specified curriculum level.
- `random`: a Red opponent that samples uniformly random actions every step.
- `<path>`: any file-system path to an SB3 `.zip` checkpoint; that model drives Red.
A win is defined as Red routing or being destroyed without Blue having routed or been destroyed in the same step. A draw occurs when both sides lose simultaneously or the episode reaches the step limit with neither side eliminated.
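The win/draw rule above amounts to a small truth table, which can be sketched as follows (`classify` is a hypothetical helper written for illustration, not a function exported by the package):

```python
def classify(blue_lost: bool, red_lost: bool) -> str:
    """Blue's result for one episode, per the rule described above."""
    if red_lost and not blue_lost:
        return "win"
    if blue_lost and not red_lost:
        return "loss"
    # Both sides eliminated in the same step, or the step limit was
    # reached with neither side eliminated.
    return "draw"

print(classify(blue_lost=False, red_lost=True))  # win
print(classify(blue_lost=True, red_lost=True))   # draw
```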
Usage:
```shell
python training/evaluate.py --checkpoint checkpoints/run/final \
    --opponent scripted_l3

python training/evaluate.py --checkpoint checkpoints/run/final \
    --opponent scripted_l3 --n-episodes 100 --seed 0 \
    --elo-registry checkpoints/elo_registry.json \
    --agent-name my_run_v1
```
EvaluationResult
¶
Bases: NamedTuple
Structured result from an evaluation run.
Attributes:
| Name | Type | Description |
|---|---|---|
| `wins` | `int` | Number of episodes Blue won. |
| `draws` | `int` | Number of episodes that ended as a draw (both sides lost or timeout). |
| `losses` | `int` | Number of episodes Blue lost. |
| `n_episodes` | `int` | Total episodes evaluated (`wins + draws + losses`). |
| `win_rate` | `float` | Fraction of episodes won (`wins / n_episodes`). |
| `draw_rate` | `float` | Fraction of episodes drawn (`draws / n_episodes`). |
| `loss_rate` | `float` | Fraction of episodes lost (`losses / n_episodes`). |
Source code in training/evaluate.py
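How the rate fields relate to the raw counts can be illustrated with a minimal stand-in (`EvalCounts` and `to_rates` are invented for this sketch; the real `EvaluationResult` carries the rates as precomputed tuple fields):

```python
from typing import NamedTuple

class EvalCounts(NamedTuple):
    """Illustrative stand-in for the raw episode counts."""
    wins: int
    draws: int
    losses: int

def to_rates(c: EvalCounts):
    """Derive (win_rate, draw_rate, loss_rate) from the counts."""
    n = c.wins + c.draws + c.losses  # n_episodes
    return c.wins / n, c.draws / n, c.losses / n

win_rate, draw_rate, loss_rate = to_rates(EvalCounts(wins=30, draws=5, losses=15))
print(win_rate, draw_rate, loss_rate)  # 0.6 0.1 0.3
```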
evaluate(checkpoint_path, n_episodes=50, deterministic=True, seed=None, opponent='scripted_l5')
¶
Load a checkpoint, run n_episodes, and return the Blue win rate.
This is a thin wrapper around `evaluate_detailed`, kept for backward compatibility.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `checkpoint_path` | `str` | Path to the saved `.zip` checkpoint. | required |
| `n_episodes` | `int` | Number of evaluation episodes (must be ≥ 1). | `50` |
| `deterministic` | `bool` | Whether the policy acts deterministically. | `True` |
| `seed` | `Optional[int]` | Base random seed used to derive per-episode seeds. | `None` |
| `opponent` | `str` | Opponent identifier; see the module docstring for valid values. | `'scripted_l5'` |
Returns:
| Type | Description |
|---|---|
| `float` | Blue win rate in `[0.0, 1.0]`. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If `n_episodes` is less than 1. |
Source code in training/evaluate.py
evaluate_detailed(checkpoint_path, n_episodes=50, deterministic=True, seed=None, opponent='scripted_l5', env_kwargs=None)
¶
Load a checkpoint and run n_episodes, returning a full result struct.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `checkpoint_path` | `str` | Path to the saved `.zip` checkpoint. | required |
| `n_episodes` | `int` | Number of evaluation episodes (must be ≥ 1). | `50` |
| `deterministic` | `bool` | Whether the policy acts deterministically. | `True` |
| `seed` | `Optional[int]` | Base random seed used to derive per-episode seeds. | `None` |
| `opponent` | `str` | Opponent identifier; see the module docstring for valid values. | `'scripted_l5'` |
| `env_kwargs` | `Optional[dict]` | Extra keyword arguments forwarded to the `BattalionEnv` constructor. | `None` |
Returns:
| Type | Description |
|---|---|
| `EvaluationResult` | The full result struct for the run. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If `n_episodes` < 1. |
Source code in training/evaluate.py
main(argv=None)
¶
CLI entry point.
Source code in training/evaluate.py
run_episodes_with_model(model, opponent='scripted_l5', n_episodes=50, deterministic=True, seed=None, env=None, env_kwargs=None)
¶
Run evaluation episodes using an already-loaded model object.
This is useful for in-training callbacks that have direct access to a PPO model without needing to save and reload a checkpoint file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Any` | Any object exposing an SB3-style `predict` method. | required |
| `opponent` | `str` | Opponent identifier; see the module docstring for valid values. | `'scripted_l5'` |
| `n_episodes` | `int` | Number of episodes to run (must be ≥ 1). | `50` |
| `deterministic` | `bool` | Whether the policy acts deterministically. | `True` |
| `seed` | `Optional[int]` | Base random seed used to derive per-episode seeds. | `None` |
| `env` | `Optional[BattalionEnv]` | Pre-built `BattalionEnv` to run episodes in; one is constructed when `None`. | `None` |
| `env_kwargs` | `Optional[dict]` | Extra keyword arguments forwarded to the `BattalionEnv` constructor. | `None` |
Returns:
| Type | Description |
|---|---|
| `EvaluationResult` | The full result struct for the run. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If `n_episodes` < 1. |
Source code in training/evaluate.py
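Because only the `predict` interface is required, any duck-typed policy works here. A minimal hypothetical stub (`RandomPolicy` is invented for illustration, not part of the package):

```python
import random

class RandomPolicy:
    """Duck-typed stand-in exposing the SB3-style predict() interface."""

    def __init__(self, n_actions: int, seed=None):
        self._rng = random.Random(seed)
        self.n_actions = n_actions

    def predict(self, obs, deterministic: bool = True):
        # SB3's predict returns (action, recurrent_state); this stub has
        # no recurrent state, so the second element is always None.
        return self._rng.randrange(self.n_actions), None

policy = RandomPolicy(n_actions=5, seed=0)
action, state = policy.predict(obs=None)
assert 0 <= action < 5 and state is None
```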
Self-play¶
training.self_play
¶
Self-play training utilities.
Provides:
- `OpponentPool` — a fixed-size pool of frozen policy snapshots that can be sampled uniformly as opponents during self-play training.
- `SelfPlayCallback` — SB3 callback that periodically snapshots the current policy into the pool and swaps the Red opponent in the vectorized training environment.
- `WinRateVsPoolCallback` — SB3 callback that evaluates the current policy against a random opponent from the pool and logs the win rate to W&B.
- `evaluate_vs_pool` — standalone helper that runs n evaluation episodes against an opponent sampled from the pool and returns the win rate.
Multi-agent (MAPPO) additions:
- `TeamOpponentPool` — fixed-size pool of frozen `MAPPOPolicy` snapshots for team self-play.
- `evaluate_team_vs_pool` — evaluate a MAPPO policy (Blue) against a frozen team opponent (Red) and return the win rate.
- `nash_exploitability_proxy` — estimate exploitability as `max(opp_win_rates) − mean(opp_win_rates)` across all pool members.
Typical usage:
```python
from training.self_play import OpponentPool, SelfPlayCallback, WinRateVsPoolCallback
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from envs.battalion_env import BattalionEnv

pool = OpponentPool(pool_dir="checkpoints/pool", max_size=10)
env = make_vec_env(BattalionEnv, n_envs=8)
model = PPO("MlpPolicy", env)
sp_cb = SelfPlayCallback(pool=pool, snapshot_freq=50_000, vec_env=env)
wr_cb = WinRateVsPoolCallback(pool=pool, eval_freq=50_000)
model.learn(total_timesteps=1_000_000, callback=[sp_cb, wr_cb])
```
OpponentPool
¶
Fixed-size pool of frozen PPO policy snapshots.
Snapshots are stored as Stable-Baselines3 .zip files under
pool_dir. The pool keeps at most max_size snapshots; when full,
the oldest snapshot is evicted to make room for the newest one.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pool_dir` | `str \| Path` | Directory where snapshot `.zip` files are stored. | required |
| `max_size` | `int` | Maximum number of snapshots to retain (default 10). | `10` |
Attributes:
| Name | Type | Description |
|---|---|---|
| `pool_dir` | `Path` | Resolved path of the snapshot directory. |
| `max_size` | `int` | Maximum number of snapshots retained in the pool. |
Source code in training/self_play.py
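The eviction behaviour described above is plain first-in, first-out. A minimal in-memory sketch (the real pool persists `.zip` files on disk; `FifoPool` is a made-up name for illustration):

```python
from collections import deque

class FifoPool:
    """In-memory sketch of a fixed-size, evict-oldest snapshot pool."""

    def __init__(self, max_size: int = 10):
        self.max_size = max_size
        self._items = deque()

    def add(self, item):
        if len(self._items) >= self.max_size:
            self._items.popleft()  # evict the oldest snapshot
        self._items.append(item)

    @property
    def size(self) -> int:
        return len(self._items)

pool = FifoPool(max_size=3)
for v in range(5):
    pool.add(f"snapshot_v{v}")
print(pool.size, list(pool._items))  # 3 ['snapshot_v2', 'snapshot_v3', 'snapshot_v4']
```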
size
property
¶
Current number of snapshots in the pool.
snapshot_paths
property
¶
Ordered list of snapshot paths (oldest first, read-only copy).
add(model, version)
¶
Save model as a new snapshot and add it to the pool.
If the pool already contains max_size snapshots the oldest is removed from disk before saving the new one.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `PPO` | The current PPO model to snapshot. | required |
| `version` | `int` | Monotonically increasing version number embedded in the snapshot file name for traceability. | required |
Returns:
| Type | Description |
|---|---|
| `Path` | Path of the newly saved snapshot `.zip` file. |
Source code in training/self_play.py
sample(rng=None)
¶
Load and return a uniformly sampled snapshot as a PPO model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `rng` | `Optional[Generator]` | Optional NumPy random generator for reproducible sampling; a default randomness source is used when `None`. | `None` |
Returns:
| Type | Description |
|---|---|
| `PPO or None` | A loaded PPO model, or `None` if the pool is empty. |
Source code in training/self_play.py
sample_latest()
¶
Load and return the most recently added snapshot.
Returns:
| Type | Description |
|---|---|
| `PPO or None` | The latest PPO model, or `None` if the pool is empty. |
Source code in training/self_play.py
SelfPlayCallback
¶
Bases: BaseCallback
Periodically snapshots the current policy and updates the Red opponent.
Every `snapshot_freq` environment steps the current model is saved to the `OpponentPool`. If the pool contains at least one snapshot, a uniformly sampled opponent is loaded and injected into each environment in `vec_env` via `BattalionEnv.set_red_policy`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pool` | `OpponentPool` | The `OpponentPool` that receives snapshots and supplies opponents. | required |
| `snapshot_freq` | `int` | How often (in environment steps) to take a snapshot. | `50000` |
| `vec_env` | `Optional[VecEnv]` | The vectorized training environment whose Red opponents should be updated after each snapshot. | `None` |
| `verbose` | `int` | Verbosity level (0 = silent, 1 = info). | `0` |
Source code in training/self_play.py
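The "every `snapshot_freq` environment steps" cadence follows the usual periodic-callback pattern, sketched here in isolation (illustrative only; `PeriodicTrigger` is not the callback's actual code):

```python
class PeriodicTrigger:
    """Fire once every `freq` environment steps, tracking the last firing."""

    def __init__(self, freq: int):
        self.freq = freq
        self._last = 0

    def should_fire(self, num_timesteps: int) -> bool:
        if num_timesteps - self._last >= self.freq:
            self._last = num_timesteps
            return True
        return False

t = PeriodicTrigger(freq=50_000)
fired = [ts for ts in range(0, 200_001, 10_000) if t.should_fire(ts)]
print(fired)  # [50000, 100000, 150000, 200000]
```

The same pattern underlies `eval_freq` in the evaluation callbacks below.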
TeamOpponentPool
¶
Fixed-size pool of frozen `MAPPOPolicy` snapshots.
Snapshots are stored as PyTorch `.pt` files under `pool_dir`. Each
file contains the policy `state_dict` plus the constructor kwargs
(`obs_dim`, `action_dim`, `state_dim`, `n_agents`,
`share_parameters`) needed to reconstruct the policy at load time.
The pool keeps at most `max_size` snapshots; when full, the oldest is
evicted to make room for the newest.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pool_dir` | `str \| Path` | Directory where snapshot `.pt` files are stored. | required |
| `max_size` | `int` | Maximum number of snapshots to retain (default 10). | `10` |
Source code in training/self_play.py
size
property
¶
Current number of snapshots in the pool.
snapshot_paths
property
¶
Ordered list of snapshot paths (oldest first, read-only copy).
add(policy, version)
¶
Save policy as a new snapshot and add it to the pool.
The snapshot stores both the state_dict and the constructor
kwargs so the policy can be fully reconstructed later.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `policy` | `'MAPPOPolicy'` | The `MAPPOPolicy` to snapshot. | required |
| `version` | `int` | Monotonically increasing version number embedded in the file name for traceability. | required |
Returns:
| Type | Description |
|---|---|
| `Path` | Path of the newly saved snapshot file. |
Source code in training/self_play.py
sample(rng=None, device=None)
¶
Load and return a uniformly sampled snapshot.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `rng` | `Optional[Generator]` | Optional NumPy random generator for reproducible sampling; a default randomness source is used when `None`. | `None` |
| `device` | `Optional[str]` | Optional PyTorch device string (e.g. `"cpu"`, `"cuda"`). | `None` |
Returns:
| Type | Description |
|---|---|
| `MAPPOPolicy or None` | A loaded policy in evaluation mode, or `None` if the pool is empty. |
Source code in training/self_play.py
sample_latest(device=None)
¶
Load and return the most recently added snapshot.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `device` | `Optional[str]` | Optional PyTorch device string. When provided, the loaded policy is moved to `device` before being returned. | `None` |
Returns:
| Type | Description |
|---|---|
| `MAPPOPolicy or None` | The latest policy, or `None` if the pool is empty. |
Source code in training/self_play.py
WinRateVsPoolCallback
¶
Bases: BaseCallback
Evaluates the current policy vs. a pool opponent and logs win rate.
Every `eval_freq` environment steps, runs `n_eval_episodes` episodes in a temporary `BattalionEnv` where Red is driven by an opponent sampled from `pool`. The resulting win rate is logged to W&B (if available) and to the SB3 logger.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pool` | `OpponentPool` | `OpponentPool` to sample evaluation opponents from. | required |
| `eval_freq` | `int` | How often (in environment steps) to run the evaluation. | `50000` |
| `n_eval_episodes` | `int` | Number of episodes per evaluation (default 20). | `20` |
| `deterministic` | `bool` | Whether the policy acts deterministically during evaluation (default `True`). | `True` |
| `use_latest` | `bool` | When `True`, evaluate against the most recent snapshot instead of a uniform sample. | `False` |
| `verbose` | `int` | Verbosity level (0 = silent, 1 = info). | `0` |
Source code in training/self_play.py
evaluate_team_vs_pool(policy, opponent, n_blue=2, n_red=2, n_episodes=20, deterministic=True, seed=None, env_kwargs=None)
¶
Evaluate a MAPPO policy (Blue) against opponent (Red) in self-play.
Runs `n_episodes` episodes of `MultiBattalionEnv` where Blue is
driven by `policy` and Red is driven by `opponent`.
For symmetric self-play (`n_blue == n_red`) the opponent is used
directly as a Red policy. When team sizes differ, the opponent's
shared actor is applied to each Red agent in turn.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `policy` | `'MAPPOPolicy'` | The Blue `MAPPOPolicy` under evaluation. | required |
| `opponent` | `'MAPPOPolicy'` | The frozen Red opponent policy. | required |
| `n_blue` | `int` | Blue team size (must match the training configuration). | `2` |
| `n_red` | `int` | Red team size (must match the training configuration). | `2` |
| `n_episodes` | `int` | Number of evaluation episodes (default 20). | `20` |
| `deterministic` | `bool` | Blue acts deterministically when `True`. | `True` |
| `seed` | `Optional[int]` | Base random seed used to derive per-episode seeds. | `None` |
| `env_kwargs` | `Optional[Dict]` | Extra keyword arguments forwarded to the `MultiBattalionEnv` constructor. | `None` |
Returns:
| Type | Description |
|---|---|
| `float` | Blue win rate in `[0.0, 1.0]`. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If `n_episodes` < 1. |
Source code in training/self_play.py
evaluate_vs_pool(model, opponent, n_episodes=20, deterministic=True, seed=None)
¶
Evaluate `model` against `opponent` in self-play episodes.
Runs `n_episodes` episodes of `BattalionEnv` where Blue is driven by `model` and Red is driven by `opponent`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `PPO` | The policy under evaluation (controls Blue). | required |
| `opponent` | `PPO` | The frozen snapshot policy (controls Red). | required |
| `n_episodes` | `int` | Number of evaluation episodes (default 20). | `20` |
| `deterministic` | `bool` | Whether `model` acts deterministically (default `True`). | `True` |
| `seed` | `Optional[int]` | Base random seed used to derive per-episode seeds. | `None` |
Returns:
| Type | Description |
|---|---|
| `float` | Win rate in `[0.0, 1.0]`. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If `n_episodes` < 1. |
Source code in training/self_play.py
nash_exploitability_proxy(policy, pool, n_blue=2, n_red=2, n_episodes_per_opponent=10, seed=None, env_kwargs=None)
¶
Estimate Nash exploitability using the current self-play pool.
Evaluates `policy` (as Blue) against every snapshot in `pool` and returns the nemesis gap:
.. math::

    \text{ExplProxy} = \max_i (1 - \text{wr}_i) - \text{mean}_i (1 - \text{wr}_i)

where :math:`\text{wr}_i` is the Blue win rate against opponent i.
Interpretation:
- `0.0` — the policy performs equally well (or poorly) against every pool member; it is hard to exploit from within the pool.
- High value — one pool member significantly outperforms the average against the policy; the policy has an exploitable weakness.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `policy` | `'MAPPOPolicy'` | The Blue `MAPPOPolicy` under evaluation. | required |
| `pool` | `TeamOpponentPool` | `TeamOpponentPool` whose snapshots serve as opponents. | required |
| `n_blue` | `int` | Blue team size. | `2` |
| `n_red` | `int` | Red team size. | `2` |
| `n_episodes_per_opponent` | `int` | Episodes per pool member (default 10). Smaller than the regular evaluation budget to keep the cost tractable. | `10` |
| `seed` | `Optional[int]` | Base random seed. | `None` |
| `env_kwargs` | `Optional[Dict]` | Extra kwargs forwarded to the `MultiBattalionEnv` constructor. | `None` |
Returns:
| Type | Description |
|---|---|
| `float` | Exploitability proxy in `[0.0, 1.0]`. |
Source code in training/self_play.py
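The nemesis-gap arithmetic can be reproduced standalone. This sketch only mirrors the formula above; `exploitability_proxy` is a hypothetical name, and in real use the per-opponent win rates would come from `evaluate_team_vs_pool`:

```python
def exploitability_proxy(win_rates):
    """max_i(1 - wr_i) - mean_i(1 - wr_i) over per-opponent Blue win rates."""
    losses = [1.0 - wr for wr in win_rates]
    return max(losses) - sum(losses) / len(losses)

print(exploitability_proxy([0.8, 0.8, 0.8]))  # 0.0: uniformly hard to exploit
print(exploitability_proxy([0.9, 0.9, 0.3]))  # ~0.4: one nemesis in the pool
```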
training.self_play.OpponentPool
¶
Fixed-size pool of frozen PPO policy snapshots.
Snapshots are stored as Stable-Baselines3 .zip files under
pool_dir. The pool keeps at most max_size snapshots; when full,
the oldest snapshot is evicted to make room for the newest one.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pool_dir | str \| Path | Directory where snapshot .zip files are stored. | required |
| max_size | int | Maximum number of snapshots to retain (default 10). | 10 |
Attributes:
| Name | Type | Description |
|---|---|---|
| pool_dir | Path | Resolved path of the snapshot directory. |
| max_size | int | Maximum number of snapshots retained in the pool. |
Source code in training/self_play.py
size
property
¶
Current number of snapshots in the pool.
snapshot_paths
property
¶
Ordered list of snapshot paths (oldest first, read-only copy).
add(model, version)
¶
Save model as a new snapshot and add it to the pool.
If the pool already contains max_size snapshots the oldest is removed from disk before saving the new one.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | PPO | The current PPO model to snapshot. | required |
| version | int | Monotonically increasing version number embedded in the snapshot file name for traceability. | required |
Returns:
| Type | Description |
|---|---|
| Path | Path of the newly saved snapshot file (including …). |
Source code in training/self_play.py
sample(rng=None)
¶
Load and return a uniformly sampled snapshot as a PPO model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng | Optional[Generator] | Optional NumPy random generator for reproducible sampling. When … | None |
Returns:
| Type | Description |
|---|---|
| PPO or None | A loaded PPO model, or None if the pool is empty. |
Source code in training/self_play.py
sample_latest()
¶
Load and return the most recently added snapshot.
Returns:
| Type | Description |
|---|---|
| PPO or None | The latest PPO model, or None if the pool is empty. |
Source code in training/self_play.py
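The fixed-size eviction and sampling semantics documented above can be sketched standalone. The `TinyPool` class below is an illustrative stand-in, not the real `OpponentPool`: plain text files replace SB3 `.zip` checkpoints, and the snapshot naming is assumed.

```python
import tempfile
from pathlib import Path

class TinyPool:
    """Toy fixed-size FIFO snapshot pool mirroring the documented
    semantics: at most max_size files; when full, the oldest
    snapshot is deleted from disk before the newest is saved."""

    def __init__(self, pool_dir, max_size=10):
        self.pool_dir = Path(pool_dir)
        self.pool_dir.mkdir(parents=True, exist_ok=True)
        self.max_size = max_size
        self._paths = []  # oldest first

    @property
    def size(self):
        return len(self._paths)

    def add(self, payload, version):
        # Evict the oldest snapshot once the pool is full.
        if len(self._paths) >= self.max_size:
            oldest = self._paths.pop(0)
            oldest.unlink(missing_ok=True)
        path = self.pool_dir / f"snapshot_v{version}.zip"
        path.write_text(payload)
        self._paths.append(path)
        return path

    def sample_latest(self):
        # Newest snapshot, or None when the pool is empty.
        return self._paths[-1] if self._paths else None

with tempfile.TemporaryDirectory() as d:
    pool = TinyPool(d, max_size=2)
    for v in range(3):
        pool.add(f"model-{v}", version=v)
    print(pool.size)                  # 2 (v0 was evicted)
    print(pool.sample_latest().name)  # snapshot_v2.zip
```

The real pool loads a PPO model from the sampled file; the FIFO bookkeeping is the same idea.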
training.self_play.SelfPlayCallback
¶
Bases: BaseCallback
Periodically snapshots the current policy and updates the Red opponent.
Every snapshot_freq environment steps the current model is saved to
the :class:OpponentPool. If the pool contains at least one snapshot,
a uniformly sampled opponent is loaded and injected into each
environment in vec_env via :meth:~envs.battalion_env.BattalionEnv.set_red_policy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pool | OpponentPool | The :class:OpponentPool that snapshots are saved to. | required |
| snapshot_freq | int | How often (in environment steps) to take a snapshot. | 50000 |
| vec_env | Optional[VecEnv] | The vectorized training environment whose Red opponents should be updated. When … | None |
| verbose | int | Verbosity level (0 = silent, 1 = info). | 0 |
Source code in training/self_play.py
training.self_play.WinRateVsPoolCallback
¶
Bases: BaseCallback
Evaluates the current policy vs. a pool opponent and logs win rate.
Every eval_freq environment steps, runs n_eval_episodes episodes
in a temporary :class:~envs.battalion_env.BattalionEnv where Red is
driven by an opponent sampled from pool. The resulting win rate is
logged to W&B (if available) and to the SB3 logger.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pool | OpponentPool | :class:OpponentPool to draw evaluation opponents from. | required |
| eval_freq | int | How often (in environment steps) to run the evaluation. | 50000 |
| n_eval_episodes | int | Number of episodes per evaluation (default 20). | 20 |
| deterministic | bool | Whether the policy acts deterministically during evaluation (default True). | True |
| use_latest | bool | When True, evaluate against the latest snapshot instead of a uniform sample. | False |
| verbose | int | Verbosity level (0 = silent, 1 = info). | 0 |
Source code in training/self_play.py
training.self_play.evaluate_vs_pool(model, opponent, n_episodes=20, deterministic=True, seed=None)
¶
Evaluate model against opponent in self-play episodes.
Runs n_episodes episodes of :class:~envs.battalion_env.BattalionEnv
where Blue is driven by model and Red is driven by opponent.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | PPO | The policy under evaluation (controls Blue). | required |
| opponent | PPO | The frozen snapshot policy (controls Red). | required |
| n_episodes | int | Number of evaluation episodes (default 20). | 20 |
| deterministic | bool | Whether model acts deterministically (default True). | True |
| seed | Optional[int] | Base random seed; episode i uses seed + i. | None |
Returns:
| Type | Description |
|---|---|
| float | Win rate in [0, 1]. |
Raises:
| Type | Description |
|---|---|
| ValueError | If n_episodes < 1. |
Source code in training/self_play.py
training.self_play.TeamOpponentPool
¶
Fixed-size pool of frozen :class:~models.mappo_policy.MAPPOPolicy snapshots.
Snapshots are stored as PyTorch .pt files under pool_dir. Each
file contains the policy state_dict plus the constructor kwargs
(obs_dim, action_dim, state_dim, n_agents,
share_parameters) needed to reconstruct the policy at load time.
The pool keeps at most max_size snapshots; when full, the oldest is
evicted to make room for the newest.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pool_dir | str \| Path | Directory where snapshot .pt files are stored. | required |
| max_size | int | Maximum number of snapshots to retain (default 10). | 10 |
Source code in training/self_play.py
size
property
¶
Current number of snapshots in the pool.
snapshot_paths
property
¶
Ordered list of snapshot paths (oldest first, read-only copy).
add(policy, version)
¶
Save policy as a new snapshot and add it to the pool.
The snapshot stores both the state_dict and the constructor
kwargs so the policy can be fully reconstructed later.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| policy | 'MAPPOPolicy' | The :class:MAPPOPolicy to snapshot. | required |
| version | int | Monotonically increasing version number embedded in the file name for traceability. | required |
Returns:
| Type | Description |
|---|---|
| Path | Path of the newly saved snapshot file. |
Source code in training/self_play.py
sample(rng=None, device=None)
¶
Load and return a uniformly sampled snapshot.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng | Optional[Generator] | Optional NumPy random generator for reproducible sampling. When … | None |
| device | Optional[str] | Optional PyTorch device string. When provided the loaded policy is moved to device before being returned. | None |
Returns:
| Type | Description |
|---|---|
| MAPPOPolicy or None | A loaded policy in evaluation mode, or None if the pool is empty. |
Source code in training/self_play.py
sample_latest(device=None)
¶
Load and return the most recently added snapshot.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| device | Optional[str] | Optional PyTorch device string. When provided the loaded policy is moved to device before being returned. | None |
Returns:
| Type | Description |
|---|---|
| MAPPOPolicy or None | The latest policy, or None if the pool is empty. |
Source code in training/self_play.py
training.self_play.evaluate_team_vs_pool(policy, opponent, n_blue=2, n_red=2, n_episodes=20, deterministic=True, seed=None, env_kwargs=None)
¶
Evaluate a MAPPO policy (Blue) against opponent (Red) in self-play.
Runs n_episodes episodes of
:class:~envs.multi_battalion_env.MultiBattalionEnv where Blue is
driven by policy and Red is driven by opponent.
For symmetric self-play (n_blue == n_red) the opponent is used
directly as a Red policy. When team sizes differ, the opponent's
shared actor is applied to each Red agent in turn.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| policy | 'MAPPOPolicy' | The Blue :class:MAPPOPolicy under evaluation. | required |
| opponent | 'MAPPOPolicy' | The frozen Red opponent policy. | required |
| n_blue | int | Team sizes (must match the training configuration). | 2 |
| n_red | int | Team sizes (must match the training configuration). | 2 |
| n_episodes | int | Number of evaluation episodes (default 20). | 20 |
| deterministic | bool | Blue acts deterministically when True. | True |
| seed | Optional[int] | Base random seed; episode i uses seed + i. | None |
| env_kwargs | Optional[Dict] | Extra keyword arguments forwarded to :class:MultiBattalionEnv. | None |
Returns:
| Type | Description |
|---|---|
| float | Blue win rate in [0, 1]. |
Raises:
| Type | Description |
|---|---|
| ValueError | If n_episodes < 1. |
Source code in training/self_play.py
training.self_play.nash_exploitability_proxy(policy, pool, n_blue=2, n_red=2, n_episodes_per_opponent=10, seed=None, env_kwargs=None)
¶
Estimate Nash exploitability using the current self-play pool.
Evaluates policy (as Blue) against every snapshot in pool and returns the nemesis gap:
.. math::
\text{ExplProxy} = \max_i (1 - \text{wr}_i) - \text{mean}_i (1 - \text{wr}_i)
where :math:\text{wr}_i is the Blue win rate against opponent i.
Interpretation:
- 0.0 — policy performs equally well (or poorly) against every pool member; hard to exploit from the pool.
- High value — one pool member significantly outperforms the average against the policy; the policy has an exploitable weakness.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| policy | 'MAPPOPolicy' | The Blue :class:MAPPOPolicy under evaluation. | required |
| pool | TeamOpponentPool | :class:TeamOpponentPool of opponent snapshots. | required |
| n_blue | int | Team sizes. | 2 |
| n_red | int | Team sizes. | 2 |
| n_episodes_per_opponent | int | Episodes per pool member (default 10). Smaller than the regular evaluation budget to keep the cost tractable. | 10 |
| seed | Optional[int] | Base random seed. | None |
| env_kwargs | Optional[Dict] | Extra kwargs forwarded to :class:MultiBattalionEnv. | None |
Returns:
| Type | Description |
|---|---|
| float | Exploitability proxy in [0, 1]. |
Source code in training/self_play.py
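The nemesis-gap formula above reduces to a few lines of arithmetic. The sketch below (`exploitability_proxy` is an illustrative helper, not the package function, and the win rates are made up) shows the two regimes described in the interpretation notes:

```python
def exploitability_proxy(win_rates):
    """Nemesis gap over Blue win rates wr_i against each pool snapshot:
    ExplProxy = max_i (1 - wr_i) - mean_i (1 - wr_i)."""
    losses = [1.0 - wr for wr in win_rates]
    return max(losses) - sum(losses) / len(losses)

# Equally strong against every snapshot -> proxy is 0.0
print(exploitability_proxy([0.5, 0.5, 0.5]))           # 0.0

# One snapshot beats Blue far more often -> large nemesis gap
print(exploitability_proxy([0.75, 0.75, 0.75, 0.25]))  # 0.375
```

Because the max is never below the mean and losses live in [0, 1], the proxy is bounded by [0, 1].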
Curriculum¶
training.curriculum_scheduler.CurriculumScheduler
¶
Tracks episode outcomes and decides when to promote the curriculum stage.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| promote_threshold | float | Rolling win rate (in [0, 1]) required to promote to the next stage. | 0.7 |
| win_rate_window | int | Number of most-recent episodes used to compute the rolling win rate. Defaults to 50. | 50 |
| initial_stage | CurriculumStage | The curriculum stage to begin from. Defaults to :attr:CurriculumStage.STAGE_1V1. | STAGE_1V1 |
Source code in training/curriculum_scheduler.py
is_final_stage
property
¶
True when the scheduler is already at the last stage (2v2).
stage
property
¶
The current curriculum stage.
stage_label
property
¶
Human-readable label for the current stage (e.g. "2v1").
total_episodes
property
¶
Total episodes recorded since creation.
env_kwargs()
¶
Return the environment kwargs for the current stage.
These match the n_blue/n_red values in the scenario YAML files
under configs/scenarios/.
Source code in training/curriculum_scheduler.py
log_promotion_event(total_steps, wandb_run=None)
¶
Log a curriculum stage-promotion event to W&B.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| total_steps | int | Current total environment step count (used as the W&B x-axis). | required |
| wandb_run | object | An active wandb run to log the event to. | None |
Source code in training/curriculum_scheduler.py
promote()
¶
Advance to the next curriculum stage and reset the outcome window.
Returns:
| Type | Description |
|---|---|
| CurriculumStage | The new :class:`CurriculumStage` after promotion. |
Raises:
| Type | Description |
|---|---|
| RuntimeError | If called when already at the final stage. |
Source code in training/curriculum_scheduler.py
record_episode(win)
¶
Record the outcome of a completed episode.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| win | bool | True if the episode was won. | required |
Source code in training/curriculum_scheduler.py
should_promote()
¶
Return True if promotion criteria are met.
Criteria:
* At least win_rate_window episodes have been recorded since the
last promotion (or since creation).
* The rolling win rate meets or exceeds promote_threshold.
* The current stage is not already the final stage.
Source code in training/curriculum_scheduler.py
wandb_metrics()
¶
Return a dict of W&B metrics for the current state.
Keys:
* curriculum/stage — integer stage index
* curriculum/stage_label — e.g. "2v1"
* curriculum/win_rate — rolling win rate in [0, 1]
* curriculum/total_episodes — cumulative episode count
Source code in training/curriculum_scheduler.py
win_rate()
¶
Return the rolling win rate over the last win_rate_window episodes.
Returns 0.0 if no episodes have been recorded yet.
Source code in training/curriculum_scheduler.py
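The promotion logic documented above (rolling window, threshold, window reset on promotion) can be sketched standalone. `TinyCurriculum` is an illustrative toy, not the real scheduler; the stage labels come from the stage descriptions in this page:

```python
from collections import deque

class TinyCurriculum:
    """Toy promotion logic: promote when the rolling win rate over
    the last win_rate_window episodes reaches promote_threshold and
    the final stage has not been reached; reset the window after
    each promotion."""

    STAGES = ["1v1", "2v1", "2v2"]  # assumed ordering for the sketch

    def __init__(self, promote_threshold=0.7, win_rate_window=50):
        self.promote_threshold = promote_threshold
        self.win_rate_window = win_rate_window
        self._outcomes = deque(maxlen=win_rate_window)
        self.stage = 0

    @property
    def is_final_stage(self):
        return self.stage == len(self.STAGES) - 1

    def win_rate(self):
        if not self._outcomes:
            return 0.0
        return sum(self._outcomes) / len(self._outcomes)

    def record_episode(self, win):
        self._outcomes.append(1.0 if win else 0.0)
        if (len(self._outcomes) >= self.win_rate_window
                and self.win_rate() >= self.promote_threshold
                and not self.is_final_stage):
            self.stage += 1
            self._outcomes.clear()  # reset the outcome window

sched = TinyCurriculum(promote_threshold=0.7, win_rate_window=10)
for _ in range(10):
    sched.record_episode(win=True)  # 10 straight wins -> promote
print(TinyCurriculum.STAGES[sched.stage])  # 2v1
```

The real scheduler additionally exposes `env_kwargs()` so the training loop can rebuild environments for the new stage.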
training.curriculum_scheduler.CurriculumStage
¶
Bases: IntEnum
Ordered curriculum stages.
The integer value doubles as the stage index used when indexing into the :data:STAGE_ENV_KWARGS mapping.
Source code in training/curriculum_scheduler.py
training.curriculum_scheduler.load_v1_weights_into_mappo(v1_checkpoint_path, mappo_policy, *, strict=False)
¶
Copy shared-trunk weights from a SB3 PPO v1 checkpoint into a MAPPO actor.
SB3 BattalionMlpPolicy (net_arch=[128, 128]) stores its actor
network under mlp_extractor.policy_net.* (even indices are Linear
layers; odd indices are activation functions with no parameters) and the
final action head under action_net.*.
MAPPO MAPPOActor (hidden_sizes=(128, 64)) stores its trunk under
actor.trunk.* where Linear layers are at indices 0, 3, 6, …
(interleaved with LayerNorm at 1, 4, … and Tanh at 2, 5, …) and the
action head under actor.action_mean.*.
Mapping strategy
~~~~~~~~~~~~~~~~
Linear layers are matched positionally — the i-th Linear layer in
mlp_extractor.policy_net maps to the i-th Linear layer in
actor.trunk. Layers whose weight shapes do not match are silently
skipped (or raise ValueError when strict=True).
Additionally log_std → actor.log_std and
action_net.* → actor.action_mean.* are transferred when shapes
match.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| v1_checkpoint_path | str \| Path | Path to the SB3 .zip checkpoint. | required |
| mappo_policy | Module | A :class:MAPPOPolicy whose actor receives the weights. | required |
| strict | bool | When True, raise ValueError on a shape mismatch instead of skipping the layer. | False |
Returns:
| Type | Description |
|---|---|
| dict | A dict with keys ``"loaded"`` (list of transferred MAPPO layer names) and ``"skipped"`` (list of skipped layer names). |
Source code in training/curriculum_scheduler.py
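The positional mapping strategy described above can be illustrated without torch: pair up Linear layers by position, copy when shapes agree, and skip (or raise, under `strict`) when they differ. Layer names and shapes below are illustrative; weights are reduced to `(name, shape)` pairs, so this is a sketch of the decision rule, not a real tensor transfer:

```python
def transfer_positional(src_linears, dst_linears, strict=False):
    """Positionally map source Linear layers onto destination Linear
    layers. Shape mismatches are silently skipped, or raise when
    strict=True, mirroring the documented mapping strategy."""
    loaded, skipped = [], []
    for (src_name, src_shape), (dst_name, dst_shape) in zip(src_linears, dst_linears):
        if src_shape == dst_shape:
            loaded.append(dst_name)   # a real impl would copy the tensor here
        elif strict:
            raise ValueError(
                f"shape mismatch: {src_name} {src_shape} vs {dst_name} {dst_shape}"
            )
        else:
            skipped.append(dst_name)  # silently skipped
    return {"loaded": loaded, "skipped": skipped}

# v1 trunk is 128x128 while the MAPPO trunk is (128, 64), so only the
# first Linear layer matches positionally (input dim 64 is assumed).
src = [("mlp_extractor.policy_net.0", (128, 64)),
       ("mlp_extractor.policy_net.2", (128, 128))]
dst = [("actor.trunk.0", (128, 64)),
       ("actor.trunk.3", (64, 128))]
report = transfer_positional(src, dst)
print(report)  # {'loaded': ['actor.trunk.0'], 'skipped': ['actor.trunk.3']}
```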
Policy Registry¶
training.policy_registry.PolicyRegistry
¶
Versioned policy registry backed by a JSON manifest.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | Union[str, Path, None] | Path to the JSON manifest file. The parent directory is created on :meth:save. | 'checkpoints/policy_registry.json' |
Source code in training/policy_registry.py
can_save
property
¶
True when the registry has a backing file path.
get(echelon, version)
¶
Return the :class:PolicyEntry for echelon / version.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| echelon | Union[Echelon, str] | Echelon name or :class:Echelon member. | required |
| version | str | Version string. | required |
Returns:
| Type | Description |
|---|---|
| PolicyEntry | The matching :class:PolicyEntry. |
Raises:
| Type | Description |
|---|---|
| KeyError | If no matching entry is found. |
Source code in training/policy_registry.py
list(echelon=None)
¶
Return all registered entries, optionally filtered by echelon.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| echelon | Union[Echelon, str, None] | When provided, only entries for this echelon are returned. | None |
Returns:
| Type | Description |
|---|---|
| list[PolicyEntry] | A fresh list; mutating it does not affect the registry. |
Source code in training/policy_registry.py
load(echelon, version, device='cpu', **mappo_kwargs)
¶
Load and return a frozen policy for the given echelon / version.
The checkpoint format is inferred from the echelon:
- battalion → MAPPO .pt checkpoint loaded via :func:~training.utils.freeze_policy.load_and_freeze_mappo. mappo_kwargs (obs_dim, action_dim, state_dim, n_agents) are forwarded to that function and are required for battalion policies.
- brigade / division → SB3 PPO .zip checkpoint loaded via :func:~training.utils.freeze_policy.load_and_freeze_sb3.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| echelon | Union[Echelon, str] | Echelon name or :class:Echelon member. | required |
| version | str | Version string. | required |
| device | str | PyTorch device string (default 'cpu'). | 'cpu' |
| **mappo_kwargs | Any | Extra keyword arguments forwarded to :func:~training.utils.freeze_policy.load_and_freeze_mappo. | {} |
Returns:
| Type | Description |
|---|---|
| Any | A frozen :class:MAPPOPolicy or SB3 PPO model, depending on echelon. |
Raises:
| Type | Description |
|---|---|
| KeyError | If no entry is registered for this echelon+version. |
| FileNotFoundError | If the checkpoint file does not exist. |
Source code in training/policy_registry.py
register(echelon, version, path, run_id=None, overwrite=False)
¶
Register a new policy checkpoint.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| echelon | Union[Echelon, str] | The HRL echelon: battalion, brigade, or division. | required |
| version | str | Caller-assigned version string. | required |
| path | Union[str, Path] | File-system path to the checkpoint. | required |
| run_id | Optional[str] | Optional W&B run ID linked to this checkpoint. | None |
| overwrite | bool | When True, replace any existing entry with the same echelon+version. | False |
Returns:
| Type | Description |
|---|---|
| PolicyEntry | The newly created entry. |
Raises:
| Type | Description |
|---|---|
| ValueError | If an entry with the same echelon+version already exists and overwrite is False. |
Source code in training/policy_registry.py
remove(echelon, version)
¶
Remove an entry from the registry.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| echelon | Union[Echelon, str] | Echelon name or :class:Echelon member. | required |
| version | str | Version string. | required |
Raises:
| Type | Description |
|---|---|
| KeyError | If no matching entry is found. |
Source code in training/policy_registry.py
save()
¶
Persist the registry to its JSON manifest file.
Raises:
| Type | Description |
|---|---|
| ValueError | If the registry was created without a file path. |
Source code in training/policy_registry.py
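The register/get/save surface documented above amounts to a keyed dict persisted as JSON. `TinyRegistry` below is an illustrative sketch of that shape, not the real `PolicyRegistry` (field names and the duplicate-entry check mirror the docs; everything else is assumed):

```python
import json
import tempfile
from pathlib import Path

class TinyRegistry:
    """Toy versioned policy registry backed by a JSON manifest."""

    def __init__(self, path):
        self.path = Path(path)
        self._entries = {}  # (echelon, version) -> entry dict

    def register(self, echelon, version, path, overwrite=False):
        key = (echelon, version)
        if key in self._entries and not overwrite:
            raise ValueError(f"{echelon}/{version} already registered")
        entry = {"echelon": echelon, "version": version, "path": str(path)}
        self._entries[key] = entry
        return entry

    def get(self, echelon, version):
        # Raises KeyError when no matching entry exists.
        return self._entries[(echelon, version)]

    def save(self):
        # Create the parent directory, then persist the manifest.
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self.path.write_text(json.dumps(list(self._entries.values()), indent=2))

with tempfile.TemporaryDirectory() as d:
    reg = TinyRegistry(Path(d) / "policy_registry.json")
    reg.register("battalion", "v1", "checkpoints/battalion_v1.pt")
    reg.save()
    print(reg.get("battalion", "v1")["path"])  # checkpoints/battalion_v1.pt
```

The real class additionally resolves each entry to a frozen policy via `load()`.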
training.policy_registry.Echelon
¶
Bases: str, Enum
Supported HRL echelon levels.
Source code in training/policy_registry.py
from_str(value)
classmethod
¶
Case-insensitive lookup from string.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| value | str | Echelon name, e.g. "battalion" (case-insensitive). | required |
Returns:
| Type | Description |
|---|---|
| Echelon | The matching :class:Echelon member. |
Raises:
| Type | Description |
|---|---|
| ValueError | If value is not a valid echelon name. |
Source code in training/policy_registry.py
training.policy_registry.PolicyEntry
¶
Bases: NamedTuple
Metadata record for a single registered policy checkpoint.
Attributes:
| Name | Type | Description |
|---|---|---|
| echelon | str | The HRL echelon this checkpoint belongs to. |
| version | str | Caller-assigned version string. |
| path | str | File-system path to the checkpoint file. |
| run_id | Optional[str] | Optional W&B run ID associated with this checkpoint. |
Source code in training/policy_registry.py
Elo Ratings¶
training.elo.EloRegistry
¶
Persistent Elo rating registry backed by a JSON file.
Stores per-agent ratings and game counts. Scripted baseline opponents
("scripted_l1" … "scripted_l5", "random") have fixed seed
ratings defined in :data:BASELINE_RATINGS and are never modified.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | Union[str, Path, None] | Path to the JSON file used for persistence. The parent directory is created automatically on :meth:save. | 'checkpoints/elo_registry.json' |
Source code in training/elo.py
can_save
property
¶
True when the registry has a backing file and can be persisted.
all_ratings()
¶
get_game_count(name)
¶
get_rating(name)
¶
Return the current Elo rating for name.
Falls back to :data:BASELINE_RATINGS for scripted opponents, then
to :data:DEFAULT_RATING for completely unknown agents.
Source code in training/elo.py
save()
¶
Persist current ratings and game counts to the JSON file.
Raises:
| Type | Description |
|---|---|
| ValueError | If the registry was created without a file path (path=None). |
Source code in training/elo.py
update(agent, opponent, outcome, n_games=1)
¶
Update the Elo rating of agent after a batch of n_games.
The outcome is the average score per game:
- 1.0 — all wins
- 0.5 — all draws
- 0.0 — all losses
Scripted baseline opponents' ratings are never modified.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| agent | str | Identifier for the agent whose rating is updated. | required |
| opponent | str | Identifier of the opponent played against. | required |
| outcome | float | Average per-game score in [0.0, 1.0]. | required |
| n_games | int | Number of games in this batch (used to advance the game-count counter; the K-factor is evaluated at the pre-update count). | 1 |
Returns:
| Type | Description |
|---|---|
| float | Elo rating delta for agent (positive = rating increased). |
Raises:
| Type | Description |
|---|---|
| ValueError | If outcome is outside [0.0, 1.0]. |
Source code in training/elo.py
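The update documented above follows the standard Elo scheme: an expected score from the rating gap, then a K-weighted step toward the observed average outcome. The registry's K-factor schedule is not documented here, so a fixed K = 32 is assumed in this sketch, and `elo_update` is an illustrative helper rather than the registry method:

```python
def elo_update(r_agent, r_opponent, outcome, k=32.0):
    """Standard Elo: expected score via the logistic curve over the
    rating gap (scale 400), then delta = K * (outcome - expected)."""
    expected = 1.0 / (1.0 + 10.0 ** ((r_opponent - r_agent) / 400.0))
    delta = k * (outcome - expected)
    return r_agent + delta, delta

# Agent at 1200 averages outcome 0.6 against a 1200-rated opponent:
# expected score is 0.5, so the rating goes up.
new_rating, delta = elo_update(1200.0, 1200.0, outcome=0.6)
print(round(delta, 1))  # 3.2
```

In the registry, scripted baseline opponents keep their fixed seed ratings; only the agent's side of the update is applied.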
training.elo.TeamEloRegistry
¶
Bases: EloRegistry
Elo registry specialised for multi-agent team ratings.
Extends :class:EloRegistry with team-specific baseline ratings
(:data:TEAM_BASELINE_RATINGS). All base-class methods work
identically; team baselines are protected against modification just
like the single-agent :data:BASELINE_RATINGS.
Typical usage::

    from training.elo import TeamEloRegistry

    registry = TeamEloRegistry(path="checkpoints/team_elo.json")

    # After a self-play evaluation round:
    delta = registry.update(
        agent="mappo_blue",
        opponent="self_play_pool",
        outcome=0.6,
        n_games=20,
    )
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | Union[str, Path, None] | Path to the JSON persistence file. Pass None to disable persistence. | 'checkpoints/elo_registry.json' |
Source code in training/elo.py
get_rating(name)
¶
Return Elo rating for name, checking team baselines.
Look-up order:

1. Stored ratings (updated agents).
2. :data:TEAM_BASELINE_RATINGS (multi-agent team baselines).
3. :data:BASELINE_RATINGS (single-agent scripted baselines).
4. :data:DEFAULT_RATING fallback.
Source code in training/elo.py
update(agent, opponent, outcome, n_games=1)
¶
Update team Elo, protecting both BASELINE_RATINGS and TEAM_BASELINE_RATINGS.
See :meth:EloRegistry.update for parameter and return-value
documentation.
Raises:
| Type | Description |
|---|---|
| ValueError | If agent is in :data:BASELINE_RATINGS or :data:TEAM_BASELINE_RATINGS. |
Source code in training/elo.py
Artifacts¶
training.artifacts.CheckpointManifest
¶
Append-only JSONL checkpoint manifest for local artifact indexing.
Source code in training/artifacts.py
append(row)
¶
Append a single JSON object row to manifest storage.
Source code in training/artifacts.py
has_entry(artifact_path, *, artifact_type, step)
¶
Return whether an identical artifact event is already present.
Source code in training/artifacts.py
known_paths()
¶
latest_entry_for_path(artifact_path)
¶
Return the latest manifest row for a given path, if any.
Source code in training/artifacts.py
latest_periodic(checkpoint_dir, prefix)
¶
Return latest periodic checkpoint from manifest, if available.
Source code in training/artifacts.py
prune_periodic(checkpoint_dir, prefix, keep_last)
¶
Delete old periodic checkpoints on disk, keeping the keep_last newest.
Only files that are both present in the manifest and exist on disk are
considered. The keep_last most recently registered rows (by step,
then timestamp) are retained; all older ones are deleted.
Returns the list of paths that were deleted.
Source code in training/artifacts.py
prune_self_play_snapshots(pool_dir, keep_last)
¶
Delete old self-play snapshots on disk, keeping the keep_last newest.
Returns the list of paths that were deleted.
Source code in training/artifacts.py
register(artifact_path, *, artifact_type, seed, curriculum_level, run_id, config_hash, step)
¶
Register one artifact path if it is not already indexed.
Source code in training/artifacts.py
training.artifacts.checkpoint_name_prefix(*, seed, curriculum_level, enable_v2)
¶
Return periodic checkpoint prefix for the active naming mode.
Source code in training/artifacts.py
training.artifacts.checkpoint_final_stem(*, seed, curriculum_level, enable_v2)
¶
Return final checkpoint stem (without .zip) for the active naming mode.
Source code in training/artifacts.py
training.artifacts.checkpoint_best_filename(*, seed, curriculum_level, enable_v2)
¶
Return best checkpoint filename for the active naming mode.
Source code in training/artifacts.py
training.artifacts.parse_step_from_checkpoint_name(path)
¶
Extract timesteps from a periodic checkpoint file name.
Source code in training/artifacts.py
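Extracting timesteps from a periodic checkpoint name is a small regex job. The sketch below is an assumption-laden illustration, not the real `parse_step_from_checkpoint_name`: the `<prefix>_<steps>_steps.zip` pattern is borrowed from SB3's `CheckpointCallback` naming and may differ from this project's actual scheme.

```python
import re

def parse_step(name):
    """Toy parser: pull the timestep count out of a periodic
    checkpoint file name of the assumed form <prefix>_<N>_steps.zip;
    return None for names that do not match."""
    m = re.search(r"_(\d+)_steps\.zip$", name)
    return int(m.group(1)) if m else None

print(parse_step("ppo_seed0_l3_500000_steps.zip"))  # 500000
print(parse_step("final_model.zip"))                # None
```

A step parser like this is what lets `latest_periodic` and `prune_periodic` order checkpoints by step rather than by file mtime.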
Benchmarks¶
training.wfm1_benchmark.WFM1Benchmark
¶
Run WFM-1 zero-shot, fine-tuned, and specialist evaluations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | WFM1BenchmarkConfig | Benchmark configuration. | required |
Source code in training/wfm1_benchmark.py
run(wfm1_policy=None, specialist_policies=None)
¶
Run all evaluations and return a summary.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| wfm1_policy | Optional[Any] | Trained WFM-1 policy to evaluate. | None |
| specialist_policies | Optional[Dict[str, Any]] | Optional mapping from scenario name → pre-trained specialist policy. When None, specialist baselines are trained (or a scripted baseline is used, per specialist_train_steps). | None |
Returns:

| Type | Description |
|---|---|
| WFM1BenchmarkSummary | Aggregated results across all scenarios and conditions. |
Source code in training/wfm1_benchmark.py
training.wfm1_benchmark.WFM1BenchmarkConfig
dataclass
¶
Configuration for the WFM-1 benchmark.
Attributes:

| Name | Type | Description |
|---|---|---|
| n_eval_episodes | int | Number of evaluation episodes per scenario per condition. |
| n_scenarios | int | Number of held-out scenarios to evaluate (≤ 20). |
| finetune_steps | int | Maximum fine-tuning budget (adapter only) for the "WFM-1 fine-tuned" condition. |
| max_steps_per_episode | int | Episode step limit. |
| specialist_train_steps | int | How many steps to train each specialist baseline (set to 0 to use a scripted baseline instead; strongly recommended for CI). |
| zero_shot_win_rate_threshold | float | Minimum acceptable zero-shot win rate (acceptance criterion). |
| finetune_recovery_fraction | float | Fraction of specialist performance that fine-tuned WFM-1 must reach (acceptance criterion). |
Source code in training/wfm1_benchmark.py
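The documented fields can be pictured as a plain dataclass. This is an illustrative sketch only: the default values below are assumptions (the 0.55 and 0.8 thresholds are taken from the acceptance-criterion docstrings), not the library's actual defaults.

```python
from dataclasses import dataclass

# Sketch of the documented field list; defaults shown are assumptions.
@dataclass
class WFM1BenchmarkConfigSketch:
    n_eval_episodes: int = 20
    n_scenarios: int = 10
    finetune_steps: int = 100_000
    max_steps_per_episode: int = 500
    specialist_train_steps: int = 0  # 0 -> scripted baseline (CI-friendly)
    zero_shot_win_rate_threshold: float = 0.55
    finetune_recovery_fraction: float = 0.8

# Override only what differs from the defaults, as with the real config.
cfg = WFM1BenchmarkConfigSketch(n_scenarios=5, specialist_train_steps=0)
```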
training.wfm1_benchmark.WFM1BenchmarkSummary
dataclass
¶
Aggregated WFM-1 benchmark results across all scenarios.
Attributes:

| Name | Type | Description |
|---|---|---|
| results | List[WFM1BenchmarkResult] | All individual (scenario, condition) results. |
| config | WFM1BenchmarkConfig | The WFM1BenchmarkConfig used for the run. |
Source code in training/wfm1_benchmark.py
finetune_recovery
property
¶
Fine-tuned performance as a fraction of specialist performance.
Returns 0.0 when no specialist baseline is available.
mean_finetuned_win_rate
property
¶
Mean fine-tuned win rate across all scenarios.
mean_specialist_win_rate
property
¶
Mean specialist win rate across all scenarios.
mean_zero_shot_win_rate
property
¶
Mean zero-shot win rate across all scenarios.
meets_finetune_criterion
property
¶
Fine-tuned win rate ≥ 80% of the specialist win rate.
meets_zero_shot_criterion
property
¶
Zero-shot win rate ≥ threshold (default 55%).
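The two acceptance criteria and the finetune_recovery property reduce to simple arithmetic over mean win rates. A self-contained sketch (function names are illustrative, not the library's API):

```python
def finetune_recovery(mean_finetuned: float, mean_specialist: float) -> float:
    """Fine-tuned performance as a fraction of specialist performance."""
    if mean_specialist <= 0.0:
        return 0.0  # no specialist baseline available
    return mean_finetuned / mean_specialist

def meets_zero_shot(mean_zero_shot: float, threshold: float = 0.55) -> bool:
    """Zero-shot win rate must reach the configured threshold."""
    return mean_zero_shot >= threshold

def meets_finetune(mean_finetuned: float, mean_specialist: float,
                   recovery_fraction: float = 0.8) -> bool:
    """Fine-tuned WFM-1 must recover the configured fraction of specialist performance."""
    return finetune_recovery(mean_finetuned, mean_specialist) >= recovery_fraction
```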
write_markdown(path=None)
¶
Write a Markdown report to path.
Defaults to docs/wfm1_benchmark.md.
Source code in training/wfm1_benchmark.py
training.transfer_benchmark.TransferBenchmark
¶
Run procedural-baseline, zero-shot, and fine-tuned evaluations.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | TransferEvalConfig | Evaluation configuration. | required |
Notes
When policy is None (or a path to a non-existent file) a simple
scripted policy is used: blue units advance at full speed toward the
nearest red unit. This is sufficient for acceptance-criterion testing
in CI without requiring a trained checkpoint.
Source code in training/transfer_benchmark.py
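The scripted fallback described above can be pictured as follows. This is a self-contained sketch with an assumed signature, not the library's implementation:

```python
import math

def scripted_action(blue_pos, red_positions, max_speed=1.0):
    """Fallback scripted policy: head at full speed toward the nearest red unit.

    Positions are (x, y) tuples; the return value is a unit-direction velocity
    scaled by max_speed. The signature is illustrative only.
    """
    nearest = min(red_positions, key=lambda r: math.dist(blue_pos, r))
    dx = nearest[0] - blue_pos[0]
    dy = nearest[1] - blue_pos[1]
    norm = math.hypot(dx, dy) or 1.0  # avoid division by zero when co-located
    return (max_speed * dx / norm, max_speed * dy / norm)
```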
run(policy=None, gis_data_dir=None)
¶
Run all three evaluation conditions and return a summary.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| policy | Optional[Any] | Trained policy object, or a path to a saved checkpoint. When None (or the path does not exist), the scripted baseline is used. | None |
| gis_data_dir | Optional[str \| Path] | Optional directory containing the GIS terrain data for the site. | None |

Returns:

| Type | Description |
|---|---|
| TransferSummary | Aggregated results for all three evaluation conditions. |
Source code in training/transfer_benchmark.py
write_markdown(summary, path=None)
¶
Write the benchmark summary to a Markdown file.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| summary | TransferSummary | Results from run(). | required |
| path | Optional[str \| Path] | Output path. A default location is used when None. | None |

Returns:

| Type | Description |
|---|---|
| Path | Absolute path to the written file. |
Source code in training/transfer_benchmark.py
training.transfer_benchmark.TransferEvalConfig
dataclass
¶
Configuration for a single transfer evaluation run.
Attributes:

| Name | Type | Description |
|---|---|---|
| site | str | GIS battle-site identifier (one of the supported sites). |
| n_eval_episodes | int | Number of episodes to run for each evaluation condition. |
| max_steps_per_episode | int | Episode step budget. |
| finetune_steps | int | Maximum fine-tuning steps allowed. The acceptance criterion requires recovery within this budget. |
| rows, cols | | Terrain grid resolution used by the GIS importer. |
| procedural_seed | int | RNG seed used to generate the procedural baseline terrain. |
| n_procedural_hills, n_procedural_forests | | Procedural terrain complexity parameters. |
Source code in training/transfer_benchmark.py
training.transfer_benchmark.TransferSummary
dataclass
¶
Aggregated transfer benchmark results.
Attributes:

| Name | Type | Description |
|---|---|---|
| site | str | GIS battle-site that was evaluated. |
| procedural | TransferResult | Evaluation on procedural terrain (the training distribution). |
| zero_shot | TransferResult | Evaluation on GIS terrain with no adaptation. |
| finetuned | TransferResult | Evaluation on GIS terrain after fine-tuning. |
| config | TransferEvalConfig | The TransferEvalConfig used for the run. |
Source code in training/transfer_benchmark.py
all_criteria_met
property
¶
Both acceptance criteria are satisfied.
finetuned_drop
property
¶
Win-rate drop from procedural → fine-tuned (positive = drop).
meets_finetune_criterion
property
¶
Fine-tuned drop < 5 percentage points AND fine-tuning used ≤ 500k steps.
meets_zero_shot_criterion
property
¶
Zero-shot drop < 20 percentage points.
zero_shot_drop
property
¶
Win-rate drop from procedural → zero-shot (positive = drop).
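The drop metrics and acceptance criteria above amount to simple comparisons of win rates. A self-contained sketch (function names are illustrative, not the library's API):

```python
def win_rate_drop(procedural: float, other: float) -> float:
    """Win-rate drop from procedural terrain to another condition (positive = drop)."""
    return procedural - other

def meets_zero_shot(procedural: float, zero_shot: float) -> bool:
    """Zero-shot drop must stay below 20 percentage points."""
    return win_rate_drop(procedural, zero_shot) < 0.20

def meets_finetune(procedural: float, finetuned: float, steps_used: int) -> bool:
    """Fine-tuned drop below 5 pp, using at most 500k fine-tuning steps."""
    return win_rate_drop(procedural, finetuned) < 0.05 and steps_used <= 500_000
```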
training.historical_benchmark.HistoricalBenchmark
¶
Run all 50+ historical scenarios and collect comparison results.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| battles_path | str \| Path | Path to the JSON battle database. | _BATTLES_JSON |
| seed | int | Random seed passed to the underlying simulation. | 42 |
Source code in training/historical_benchmark.py
run()
¶
Run the full benchmark and return a BenchmarkSummary.
Each scenario is run as a 1v1 simulation using the first blue
battalion against the first red battalion (matching the existing
test pattern in tests/test_historical_scenarios.py).
Source code in training/historical_benchmark.py
write_markdown(summary, output_path=_BENCHMARK_MD)
¶
Write the benchmark results to a Markdown file.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| summary | BenchmarkSummary | The BenchmarkSummary to write out. | required |
| output_path | str \| Path | Destination file path. Parent directories are created if they do not already exist. | _BENCHMARK_MD |
Returns:

| Type | Description |
|---|---|
| Path | The path to the written file. |