Environments API

The envs package exposes the simulation environments and their supporting utilities through a single programmatic Python API.

Quick-start

from envs import BattalionEnv, RewardWeights, Formation

# Custom reward shaping
rw = RewardWeights(win_bonus=20.0, loss_penalty=-20.0, time_penalty=-0.005)
env = BattalionEnv(reward_weights=rw, curriculum_level=3)
obs, info = env.reset(seed=42)
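
BattalionEnv follows the standard Gymnasium five-tuple step protocol. The helper below is a sketch (not part of the package) that rolls out one episode against any environment with that interface; the lambda policy in the commented usage is a stand-in for a trained model:

```python
def run_episode(env, policy, seed=None):
    """Roll out one episode and return the cumulative reward.

    Works with any Gymnasium-style env whose reset() returns (obs, info)
    and whose step() returns (obs, reward, terminated, truncated, info).
    """
    obs, _info = env.reset(seed=seed)
    total_reward = 0.0
    terminated = truncated = False
    while not (terminated or truncated):
        obs, reward, terminated, truncated, _info = env.step(policy(obs))
        total_reward += reward
    return total_reward

# Example (assumes envs is importable):
# from envs import BattalionEnv
# env = BattalionEnv(curriculum_level=1)
# print(run_episode(env, lambda obs: env.action_space.sample(), seed=0))
```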

Environments

envs.battalion_env.BattalionEnv

Bases: Env

1v1 battalion RL environment.

The agent controls the Blue battalion; Red is driven either by a built-in scripted opponent (controlled by curriculum_level) or by an optional red_policy (any object with a predict method, e.g. a Stable-Baselines3 PPO model). When red_policy is provided it takes precedence over the scripted opponent and curriculum_level is ignored for movement and fire decisions.

Curriculum levels (scripted opponent only)

Level  Red opponent behaviour
-----  ----------------------
1      Stationary — Red does not move or fire.
2      Turning only — Red faces Blue but stays put.
3      Advance only — Red turns and advances; no fire.
4      Soft fire — Red turns, advances, fires at 50 % intensity.
5      Full combat — Red turns, advances, fires at 100 % intensity (default).
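A common way to use these levels is a schedule that raises curriculum_level as training progresses. The even five-band split below is an illustrative choice, not part of the package; in practice you might gate promotion on win rate instead:

```python
def curriculum_level_for(progress: float) -> int:
    """Map training progress in [0, 1] to a curriculum level 1-5.

    Illustrative schedule: five equal bands, [0, 0.2) -> 1 up to
    [0.8, 1.0] -> 5. Tune the thresholds for your own runs.
    """
    if not 0.0 <= progress <= 1.0:
        raise ValueError(f"progress must be in [0, 1], got {progress}")
    return min(5, int(progress * 5) + 1)

# env = BattalionEnv(curriculum_level=curriculum_level_for(0.5))  # level 3
```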

Observation space — Box(shape=(17,), dtype=float32) (formations disabled)

Index  Feature                      Range
-----  ---------------------------  -------
0      blue x / map_width           [0, 1]
1      blue y / map_height          [0, 1]
2      cos(blue θ)                  [-1, 1]
3      sin(blue θ)                  [-1, 1]
4      blue strength                [0, 1]
5      blue morale                  [0, 1]
6      distance to red / diagonal   [0, 1]
7      cos(bearing to red)          [-1, 1]
8      sin(bearing to red)          [-1, 1]
9      red strength                 [0, 1]
10     red morale                   [0, 1]
11     step / max_steps             [0, 1]
12     blue elevation (normalised)  [0, 1]
13     blue cover                   [0, 1]
14     red elevation (normalised)   [0, 1]
15     red cover                    [0, 1]
16     line-of-sight (0=blocked)    [0, 1]

When enable_formations is True, two extra dimensions are appended:

Index  Feature                              Range
-----  -----------------------------------  -------
17     blue formation / (NUM_FORMATIONS−1)  [0, 1]
18     blue transitioning (1=yes)           [0, 1]

When enable_logistics is True, three more dimensions follow the terrain / formations dims:

Index  Feature             Range
-----  ------------------  -------
N+0    blue ammo level     [0, 1]
N+1    blue food level     [0, 1]
N+2    blue fatigue level  [0, 1]

(where N = 17 with formations disabled or 19 with formations enabled).

When enable_weather is True, two more dimensions are appended after any logistics dims:

Index  Feature                       Range
-----  ----------------------------  -------
M+0    weather condition id          [0, 1]
M+1    combined visibility fraction  [0, 1]

(where M = N + 3 when logistics are enabled, or N when disabled).
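Because the observation size depends on which feature flags are enabled, it is easy to miscount. The small helper below (not part of the package) makes the layout above explicit and can be checked against env.observation_space.shape:

```python
def expected_obs_dim(formations: bool = False,
                     logistics: bool = False,
                     weather: bool = False) -> int:
    """Return the observation dimension implied by the feature flags."""
    dim = 17          # base terrain + combat features (indices 0-16)
    if formations:
        dim += 2      # formation index, transitioning flag
    if logistics:
        dim += 3      # ammo, food, fatigue
    if weather:
        dim += 2      # condition id, visibility fraction
    return dim

# env = BattalionEnv(enable_formations=True, enable_logistics=True)
# assert env.observation_space.shape == (expected_obs_dim(True, True),)  # (22,)
```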

Action space — Box(shape=(3,), dtype=float32) (formations disabled)

Index  Name    Range    Effect
-----  ------  -------  ---------------------------------------
0      move    [-1, 1]  Scale max_speed; positive = forward
1      rotate  [-1, 1]  Scale max_turn_rate; positive = CCW
2      fire    [0, 1]   Fire intensity this step

When enable_formations is True, a fourth dimension is added:

Index  Name       Range   Effect
-----  ---------  ------  ------------------------------------------
3      formation  [0, 3]  Desired formation (rounded to nearest int)
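Values outside the declared ranges are clipped by the environment, and the formation component is rounded to the nearest integer. The helper below mirrors that handling on the caller's side; it is a sketch, not part of the package (step() converts any array-like with np.asarray, so a plain list is accepted):

```python
def make_action(move, rotate, fire, formation=None):
    """Build a BattalionEnv action, pre-clipped to the declared ranges."""
    def clip(v, lo, hi):
        return max(lo, min(hi, float(v)))

    action = [
        clip(move,   -1.0, 1.0),   # scale of max_speed; positive = forward
        clip(rotate, -1.0, 1.0),   # scale of max_turn_rate; positive = CCW
        clip(fire,    0.0, 1.0),   # fire intensity this step
    ]
    if formation is not None:      # 4th dim only valid with enable_formations=True
        action.append(float(round(clip(formation, 0.0, 3.0))))
    return action

# obs, reward, terminated, truncated, info = env.step(make_action(0.5, -0.2, 1.0))
```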

Parameters:

map_width : float, default MAP_WIDTH
    Map width in metres (default 1 km).

map_height : float, default MAP_HEIGHT
    Map height in metres (default 1 km).

max_steps : int, default MAX_STEPS
    Episode length cap (default 500).

terrain : Optional[TerrainMap], default None
    Optional envs.sim.terrain.TerrainMap. When supplied, randomize_terrain is forced to False and this fixed map is used for every episode. Defaults to a flat open plain.

randomize_terrain : bool, default True
    When True (default) and no fixed terrain is supplied, a new procedural terrain is generated from the seeded RNG at the start of each episode. Set to False to keep a static flat plain.

hill_speed_factor : float, default 0.5
    Movement speed multiplier applied to units on maximum-elevation terrain. Must be in (0, 1]. A value of 0.5 means units on the highest hill travel at half their normal speed; 1.0 disables the hill penalty entirely.

curriculum_level : int, default 5
    Scripted Red opponent difficulty (1–5). Level 1 is the easiest (stationary target); level 5 is full combat. Ignored when red_policy is provided.

reward_weights : Optional[RewardWeights], default None
    envs.reward.RewardWeights instance with per-component multipliers. Defaults to RewardWeights() (standard shaped reward with the legacy coefficients).

red_policy : Optional[RedPolicy], default None
    Optional policy object for driving the Red battalion. Must expose a predict(obs, deterministic=False) -> (action, state) method (satisfied by any SB3 model or policy). When supplied, the scripted opponent is bypassed. Use set_red_policy() to swap the policy at runtime (e.g. from a training callback).

render_mode : Optional[str], default None
    Render mode. None disables rendering. "human" opens a pygame window and renders the simulation in real time; each call to render() displays the current frame.

enable_formations : bool, default False
    When True, the formation system is activated. The action space gains a 4th dimension (desired formation index 0–3) and the observation space gains two extra dimensions (current formation normalised and a transitioning flag). Formation modifiers are applied to firepower, movement speed, and morale resilience (the morale hit from casualties is divided by the unit's morale_resilience before being fed into the morale check). Defaults to False to preserve full backward compatibility.

enable_logistics : bool, default False
    When True, the supply, ammunition, and fatigue model is activated. The observation space gains three extra dimensions (blue ammo, food, fatigue — all normalised to [0, 1]). Ammunition is consumed when firing (the weapon jams at zero); fatigue accumulates from movement and combat, penalising speed and accuracy; and battalions can resupply by halting near a friendly supply wagon. Pass an envs.sim.logistics.LogisticsConfig via logistics_config to tune the rates; if None, a default config is used. Defaults to False to preserve full backward compatibility.

logistics_config : Optional[LogisticsConfig], default None
    Optional envs.sim.logistics.LogisticsConfig instance. Used when enable_logistics is True. Defaults to LogisticsConfig() if not supplied.

enable_weather : bool, default False
    When True, the weather and time-of-day model is activated. An envs.sim.weather.WeatherState is sampled at each reset() (or fixed via weather_config). The observation space gains two extra dimensions: the normalised weather condition id and the combined visibility fraction. Weather modifiers are applied to LOS range, fire accuracy, movement speed, and morale. Defaults to False to preserve full backward compatibility.

weather_config : Optional[WeatherConfig], default None
    Optional envs.sim.weather.WeatherConfig instance. Used when enable_weather is True. Defaults to WeatherConfig() if not supplied.
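Since red_policy only needs to satisfy the predict duck type, a hand-written opponent works as well as an SB3 model. The class below is an illustrative sketch of the minimal interface (the always-advance behaviour is made up for the example):

```python
class ChargePolicy:
    """Minimal red_policy: always advance at full speed and fire.

    Satisfies the required predict(obs, deterministic=False) ->
    (action, state) interface; the returned state is unused here.
    """

    def predict(self, obs, deterministic=False):
        action = [1.0, 0.0, 1.0]   # full move, no rotation, full fire
        return action, None

# env = BattalionEnv(red_policy=ChargePolicy())   # scripted opponent bypassed
# env.set_red_policy(trained_model)               # swap at runtime, e.g. from a callback
```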
Source code in envs/battalion_env.py
class BattalionEnv(gym.Env):
    """1v1 battalion RL environment.

    The agent controls the **Blue** battalion; **Red** is driven either by a
    built-in scripted opponent (controlled by *curriculum_level*) or by an
    optional *red_policy* (any object with a ``predict`` method, e.g. a
    Stable-Baselines3 ``PPO`` model).  When *red_policy* is provided it takes
    precedence over the scripted opponent and *curriculum_level* is ignored for
    movement and fire decisions.

    Curriculum levels (scripted opponent only)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    =====  =============================================
    Level  Red opponent behaviour
    =====  =============================================
    1      Stationary — Red does not move or fire.
    2      Turning only — Red faces Blue but stays put.
    3      Advance only — Red turns and advances; no fire.
    4      Soft fire — Red turns, advances, fires at 50 % intensity.
    5      Full combat — Red turns, advances, fires at 100 % intensity (default).
    =====  =============================================

    Observation space — ``Box(shape=(17,), dtype=float32)`` (formations disabled)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    =====  ===========================  =========
    Index  Feature                      Range
    =====  ===========================  =========
    0      blue x / map_width           [0, 1]
    1      blue y / map_height          [0, 1]
    2      cos(blue θ)                  [-1, 1]
    3      sin(blue θ)                  [-1, 1]
    4      blue strength                [0, 1]
    5      blue morale                  [0, 1]
    6      distance to red / diagonal   [0, 1]
    7      cos(bearing to red)          [-1, 1]
    8      sin(bearing to red)          [-1, 1]
    9      red strength                 [0, 1]
    10     red morale                   [0, 1]
    11     step / max_steps             [0, 1]
    12     blue elevation (normalised)  [0, 1]
    13     blue cover                   [0, 1]
    14     red elevation (normalised)   [0, 1]
    15     red cover                    [0, 1]
    16     line-of-sight (0=blocked)    [0, 1]
    =====  ===========================  =========

    When *enable_formations* is ``True``, two extra dimensions are appended:

    =====  =====================================  =========
    Index  Feature                                Range
    =====  =====================================  =========
    17     blue formation / (NUM_FORMATIONS−1)    [0, 1]
    18     blue transitioning (1=yes)             [0, 1]
    =====  =====================================  =========

    When *enable_logistics* is ``True``, three more dimensions follow the
    terrain / formations dims:

    =====  ===========================  =========
    Index  Feature                      Range
    =====  ===========================  =========
    N+0    blue ammo level              [0, 1]
    N+1    blue food level              [0, 1]
    N+2    blue fatigue level           [0, 1]
    =====  ===========================  =========

    (where N = 17 with formations disabled or 19 with formations enabled).

    When *enable_weather* is ``True``, two more dimensions are appended after
    any logistics dims:

    =====  ===========================  =========
    Index  Feature                      Range
    =====  ===========================  =========
    M+0    weather condition id         [0, 1]
    M+1    combined visibility fraction [0, 1]
    =====  ===========================  =========

    (where M = N + 3 when logistics are enabled, or N when disabled).

    Action space — ``Box(shape=(3,), dtype=float32)`` (formations disabled)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    =====  ===========  ========  ============================================
    Index  Name         Range     Effect
    =====  ===========  ========  ============================================
    0      move         [-1, 1]   Scale ``max_speed``; positive = forward
    1      rotate       [-1, 1]   Scale ``max_turn_rate``; positive = CCW
    2      fire         [0, 1]    Fire intensity this step
    =====  ===========  ========  ============================================

    When *enable_formations* is ``True``, a fourth dimension is added:

    =====  ===========  ========  ============================================
    Index  Name         Range     Effect
    =====  ===========  ========  ============================================
    3      formation    [0, 3]    Desired formation (rounded to nearest int)
    =====  ===========  ========  ============================================

    Parameters
    ----------
    map_width, map_height:
        Map dimensions in metres (default 1 km × 1 km).
    max_steps:
        Episode length cap (default 500).
    terrain:
        Optional :class:`~envs.sim.terrain.TerrainMap`.  When supplied,
        *randomize_terrain* is forced to ``False`` and this fixed map is
        used for every episode.  Defaults to a flat open plain.
    randomize_terrain:
        When ``True`` (default) and no fixed *terrain* is supplied, a new
        procedural terrain is generated from the seeded RNG at the start
        of each episode.  Set to ``False`` to keep a static flat plain.
    hill_speed_factor:
        Movement speed multiplier applied to units on maximum-elevation
        terrain.  Must be in ``(0, 1]``.  A value of ``0.5`` means units
        on the highest hill travel at half their normal speed; ``1.0``
        disables the hill penalty entirely.
    curriculum_level:
        Scripted Red opponent difficulty (1–5).  Level 1 is the easiest
        (stationary target); level 5 is full combat.  Defaults to ``5``.
        Ignored when *red_policy* is provided.
    reward_weights:
        :class:`~envs.reward.RewardWeights` instance with per-component
        multipliers.  Defaults to ``RewardWeights()`` (standard shaped
        reward with the legacy coefficients).
    red_policy:
        Optional policy object for driving the Red battalion.  Must expose
        a ``predict(obs, deterministic=False) -> (action, state)`` method
        (satisfied by any SB3 model or policy).  When supplied, the scripted
        opponent is bypassed.  Use :meth:`set_red_policy` to swap the policy
        at runtime (e.g. from a training callback).
    render_mode:
        Render mode.  ``None`` disables rendering.  ``"human"`` opens a
        pygame window and renders the simulation in real time; each call to
        :meth:`render` displays the current frame.
    enable_formations:
        When ``True``, the formation system is activated.  The action space
        gains a 4th dimension (desired formation index 0–3) and the
        observation space gains two extra dimensions (current formation
        normalised and a transitioning flag).  Formation modifiers are
        applied to firepower, movement speed, and morale resilience (the
        morale hit from casualties is divided by the unit's
        ``morale_resilience`` before being fed into the morale check).
        Defaults to ``False`` to preserve full backward compatibility.
    enable_logistics:
        When ``True``, the supply, ammunition, and fatigue model is
        activated.  The observation space gains three extra dimensions
        (blue ammo, food, fatigue — all normalised to ``[0, 1]``).
        Ammunition is consumed when firing (weapon jams at zero);
        fatigue accumulates from movement and combat, penalising speed and
        accuracy; and battalions can resupply by halting near a friendly
        supply wagon.  Pass a :class:`~envs.sim.logistics.LogisticsConfig`
        via *logistics_config* to tune the rates; if ``None`` a default
        config is used.  Defaults to ``False`` to preserve full backward
        compatibility.
    logistics_config:
        Optional :class:`~envs.sim.logistics.LogisticsConfig` instance.
        Used when *enable_logistics* is ``True``.  Defaults to
        ``LogisticsConfig()`` if not supplied.
    enable_weather:
        When ``True``, the weather and time-of-day model is activated.
        A :class:`~envs.sim.weather.WeatherState` is sampled at each
        ``reset()`` (or fixed via *weather_config*).  The observation
        space gains two extra dimensions: the normalised weather condition
        id and the combined visibility fraction.  Weather modifiers are
        applied to LOS range, fire accuracy, movement speed, and morale.
        Defaults to ``False`` to preserve full backward compatibility.
    weather_config:
        Optional :class:`~envs.sim.weather.WeatherConfig` instance.
        Used when *enable_weather* is ``True``.  Defaults to
        ``WeatherConfig()`` if not supplied.
    """

    metadata: dict = {"render_modes": ["human"]}

    def __init__(
        self,
        map_width: float = MAP_WIDTH,
        map_height: float = MAP_HEIGHT,
        max_steps: int = MAX_STEPS,
        terrain: Optional[TerrainMap] = None,
        randomize_terrain: bool = True,
        hill_speed_factor: float = 0.5,
        curriculum_level: int = 5,
        reward_weights: Optional[RewardWeights] = None,
        red_policy: Optional[RedPolicy] = None,
        render_mode: Optional[str] = None,
        morale_config: Optional[MoraleConfig] = None,
        enable_formations: bool = False,
        enable_logistics: bool = False,
        logistics_config: Optional[LogisticsConfig] = None,
        enable_weather: bool = False,
        weather_config: Optional[WeatherConfig] = None,
    ) -> None:
        super().__init__()

        # ------------------------------------------------------------------
        # Argument validation
        # ------------------------------------------------------------------
        if float(map_width) <= 0:
            raise ValueError(f"map_width must be positive, got {map_width}")
        if float(map_height) <= 0:
            raise ValueError(f"map_height must be positive, got {map_height}")
        if int(max_steps) < 1:
            raise ValueError(f"max_steps must be >= 1, got {max_steps}")
        if not (0.0 < float(hill_speed_factor) <= 1.0):
            raise ValueError(
                f"hill_speed_factor must be in (0, 1], got {hill_speed_factor}"
            )
        _curriculum_level = int(curriculum_level)
        if _curriculum_level not in range(1, NUM_CURRICULUM_LEVELS + 1):
            raise ValueError(
                f"curriculum_level must be in 1–{NUM_CURRICULUM_LEVELS}, "
                f"got {curriculum_level}"
            )
        if render_mode is not None and render_mode not in self.metadata["render_modes"]:
            raise ValueError(
                f"Unsupported render_mode {render_mode!r}. "
                f"Supported modes: {self.metadata['render_modes']}"
            )

        self.map_width = float(map_width)
        self.map_height = float(map_height)
        self.map_diagonal = math.sqrt(self.map_width ** 2 + self.map_height ** 2)
        self.max_steps = int(max_steps)
        self.hill_speed_factor = float(hill_speed_factor)
        self.curriculum_level = _curriculum_level
        self.reward_weights: RewardWeights = (
            reward_weights if reward_weights is not None else RewardWeights()
        )
        self.morale_config: Optional[MoraleConfig] = morale_config
        self.enable_formations: bool = bool(enable_formations)
        self.enable_logistics: bool = bool(enable_logistics)
        self.logistics_config: LogisticsConfig = (
            logistics_config if logistics_config is not None else LogisticsConfig()
        )
        self.enable_weather: bool = bool(enable_weather)
        self.weather_config: WeatherConfig = (
            weather_config if weather_config is not None else WeatherConfig()
        )
        # When an explicit terrain is supplied, terrain randomisation is
        # disabled so the caller's fixed map is used every episode.
        self.randomize_terrain: bool = bool(randomize_terrain) and (terrain is None)
        self._supplied_terrain: Optional[TerrainMap] = terrain
        self.terrain: TerrainEngine = (
            TerrainEngine.from_terrain_map(terrain)
            if terrain is not None
            else TerrainEngine.flat(map_width, map_height)
        )
        self.render_mode = render_mode
        # Policy-based Red opponent (overrides scripted behaviour when set).
        self.red_policy: Optional[RedPolicy] = red_policy

        # Renderer — created lazily when render_mode="human".
        self._renderer: Optional[Any] = None

        # ------------------------------------------------------------------
        # Observation space
        # ------------------------------------------------------------------
        # Base 17-dimensional observation (terrain + combat features).
        # When formations are enabled two extra dims are appended:
        #   [17] blue formation normalised in [0, 1]
        #   [18] transitioning flag in {0, 1}
        # When logistics are enabled three more dims follow:
        #   [N+0] blue ammo level
        #   [N+1] blue food level
        #   [N+2] blue fatigue level
        # When weather is enabled two further dims are appended:
        #   [M+0] weather condition id (normalised to [0, 1])
        #   [M+1] combined visibility fraction [0, 1]
        obs_low = np.array(
            [0.0, 0.0, -1.0, -1.0, 0.0, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0, 0.0,
             0.0, 0.0, 0.0, 0.0, 0.0],
            dtype=np.float32,
        )
        obs_high = np.array(
            [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
             1.0, 1.0, 1.0, 1.0, 1.0],
            dtype=np.float32,
        )
        if self.enable_formations:
            obs_low = np.append(obs_low, [0.0, 0.0]).astype(np.float32)
            obs_high = np.append(obs_high, [1.0, 1.0]).astype(np.float32)
        if self.enable_logistics:
            obs_low = np.append(obs_low, [0.0, 0.0, 0.0]).astype(np.float32)
            obs_high = np.append(obs_high, [1.0, 1.0, 1.0]).astype(np.float32)
        if self.enable_weather:
            obs_low = np.append(obs_low, [0.0, 0.0]).astype(np.float32)
            obs_high = np.append(obs_high, [1.0, 1.0]).astype(np.float32)
        self.observation_space = spaces.Box(
            low=obs_low, high=obs_high, dtype=np.float32
        )

        # ------------------------------------------------------------------
        # Action space — (move, rotate, fire) + optional formation choice
        # ------------------------------------------------------------------
        if self.enable_formations:
            self.action_space = spaces.Box(
                low=np.array([-1.0, -1.0, 0.0, 0.0], dtype=np.float32),
                high=np.array([1.0, 1.0, 1.0, float(NUM_FORMATIONS - 1)], dtype=np.float32),
                dtype=np.float32,
            )
        else:
            self.action_space = spaces.Box(
                low=np.array([-1.0, -1.0, 0.0], dtype=np.float32),
                high=np.array([1.0, 1.0, 1.0], dtype=np.float32),
                dtype=np.float32,
            )

        # Internal state — populated by reset()
        self.blue: Battalion | None = None
        self.red: Battalion | None = None
        self.blue_state: CombatState | None = None
        self.red_state: CombatState | None = None
        self._step_count: int = 0

        # Logistics state — populated by reset() when enable_logistics=True
        self.blue_logistics: LogisticsState | None = None
        self.red_logistics: LogisticsState | None = None
        self.blue_wagon: SupplyWagon | None = None
        self.red_wagon: SupplyWagon | None = None

        # Weather state — populated by reset() when enable_weather=True
        self.weather_state: WeatherState | None = None

    # ------------------------------------------------------------------
    # Gymnasium API
    # ------------------------------------------------------------------

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> tuple[np.ndarray, dict]:
        """Reset the environment and return the initial observation.

        When *randomize_terrain* is ``True`` (the default when no fixed
        terrain map is passed to ``__init__``), a new procedural terrain
        is generated from the seeded RNG on every call.  Passing the same
        *seed* therefore always produces the same terrain layout and unit
        positions.

        Blue spawns in the western half of the map facing roughly east;
        Red spawns in the eastern half facing roughly west.
        """
        super().reset(seed=seed)
        rng = self.np_random

        # Generate a fresh terrain map from the seeded RNG each episode.
        if self.randomize_terrain:
            self.terrain = TerrainEngine.generate_random(
                rng=rng,
                width=self.map_width,
                height=self.map_height,
            )
        elif self._supplied_terrain is not None:
            self.terrain = TerrainEngine.from_terrain_map(self._supplied_terrain)

        # Blue: western quarter, roughly eastward
        bx = float(rng.uniform(0.1 * self.map_width, 0.4 * self.map_width))
        by = float(rng.uniform(0.1 * self.map_height, 0.9 * self.map_height))
        b_theta = float(rng.uniform(-math.pi / 4, math.pi / 4))

        # Red: eastern quarter, roughly westward
        rx = float(rng.uniform(0.6 * self.map_width, 0.9 * self.map_width))
        ry = float(rng.uniform(0.1 * self.map_height, 0.9 * self.map_height))
        r_theta = float(math.pi + rng.uniform(-math.pi / 4, math.pi / 4))

        self.blue = Battalion(x=bx, y=by, theta=b_theta, strength=1.0, team=0)
        self.red = Battalion(x=rx, y=ry, theta=r_theta, strength=1.0, team=1)
        self.blue_state = CombatState()
        self.red_state = CombatState()
        self._step_count = 0

        # --- Logistics initialisation ---
        if self.enable_logistics:
            cfg = self.logistics_config
            self.blue_logistics = LogisticsState(
                ammo=cfg.initial_ammo,
                food=cfg.initial_food,
                fatigue=0.0,
            )
            self.red_logistics = LogisticsState(
                ammo=cfg.initial_ammo,
                food=cfg.initial_food,
                fatigue=0.0,
            )
            # Blue supply wagon: behind Blue (further west), same y as Blue
            self.blue_wagon = SupplyWagon(
                x=max(0.0, bx - 150.0),
                y=by,
                team=0,
                strength=cfg.wagon_max_strength,
            )
            # Red supply wagon: behind Red (further east), same y as Red
            self.red_wagon = SupplyWagon(
                x=min(self.map_width, rx + 150.0),
                y=ry,
                team=1,
                strength=cfg.wagon_max_strength,
            )
        else:
            self.blue_logistics = None
            self.red_logistics = None
            self.blue_wagon = None
            self.red_wagon = None

        # --- Weather initialisation ---
        if self.enable_weather:
            self.weather_state = sample_weather(rng, self.weather_config)
        else:
            self.weather_state = None

        return self._get_obs(), {}
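The spawn geometry above can be sketched in isolation. This is a standalone illustration, not the environment's code: `spawn` and the 1000x1000 map size are assumptions made for the example; it demonstrates the same-seed-same-layout property the docstring describes.

```python
import math
import numpy as np

def spawn(rng, width, height):
    """Sketch of the spawn logic: Blue in the west facing roughly east,
    Red in the east facing roughly west (mirrored)."""
    bx = float(rng.uniform(0.1 * width, 0.4 * width))
    by = float(rng.uniform(0.1 * height, 0.9 * height))
    b_theta = float(rng.uniform(-math.pi / 4, math.pi / 4))
    rx = float(rng.uniform(0.6 * width, 0.9 * width))
    ry = float(rng.uniform(0.1 * height, 0.9 * height))
    r_theta = float(math.pi + rng.uniform(-math.pi / 4, math.pi / 4))
    return (bx, by, b_theta), (rx, ry, r_theta)

# The same seed always yields the same layout:
blue_a, red_a = spawn(np.random.default_rng(42), 1000.0, 1000.0)
blue_b, red_b = spawn(np.random.default_rng(42), 1000.0, 1000.0)
```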

    def step(
        self, action: np.ndarray
    ) -> tuple[np.ndarray, float, bool, bool, dict]:
        """Advance the environment by one step.

        Parameters
        ----------
        action:
            Array of shape ``(3,)``: ``[move, rotate, fire]``.

        Returns
        -------
        observation, reward, terminated, truncated, info
        """
        if self.blue is None or self.red is None:
            raise RuntimeError("Call reset() before step().")

        action = np.asarray(action, dtype=np.float32)
        move_cmd   = float(np.clip(action[0], -1.0, 1.0))
        rotate_cmd = float(np.clip(action[1], -1.0, 1.0))
        fire_cmd   = float(np.clip(action[2],  0.0, 1.0))

        # --- Formation action (optional 4th dimension) ---
        if self.enable_formations and len(action) >= 4:
            desired_formation_idx = int(np.clip(np.round(action[3]), 0, NUM_FORMATIONS - 1))
            self._request_formation_change(self.blue, desired_formation_idx)

        # --- Advance formation transitions for both units ---
        if self.enable_formations:
            self._advance_formation_transition(self.blue)
            self._advance_formation_transition(self.red)

        # --- Apply agent action to Blue ---
        # Normal movement when not routing, OR when morale_config is not set
        # (legacy path: agent retains control even while CombatState.is_routing).
        # When morale_config is set and the unit is already routing, movement is
        # suppressed here and overridden by rout_velocity() after morale update.
        if not (self.morale_config is not None and self.blue_state.is_routing):
            # Rotation (Battalion.rotate clamps to max_turn_rate internally)
            self.blue.rotate(rotate_cmd * self.blue.max_turn_rate)
            # Forward/backward movement along current heading, slowed on hills
            speed_mod = self.terrain.get_speed_modifier(
                self.blue.x, self.blue.y, self.hill_speed_factor
            )
            # Formation speed modifier applied when formations are active
            if self.enable_formations:
                speed_mod *= get_attributes(Formation(self.blue.formation)).speed_modifier
            # Fatigue speed penalty applied when logistics are active
            if self.enable_logistics and self.blue_logistics is not None:
                speed_mod *= get_fatigue_speed_modifier(
                    self.blue_logistics, self.logistics_config
                )
            # Weather speed penalty applied when weather is active
            if self.enable_weather and self.weather_state is not None:
                speed_mod *= get_speed_modifier(self.weather_state)
            vx = math.cos(self.blue.theta) * move_cmd * self.blue.max_speed * speed_mod
            vy = math.sin(self.blue.theta) * move_cmd * self.blue.max_speed * speed_mod
            self.blue.move(vx, vy, dt=DT)
        # Clamp to map bounds
        self.blue.x = float(np.clip(self.blue.x, 0.0, self.map_width))
        self.blue.y = float(np.clip(self.blue.y, 0.0, self.map_height))

        # --- Red opponent (scripted or policy-based) ---
        # When morale_config is set and Red is already routing, skip scripted/policy
        # movement; rout_velocity() is applied after morale update instead.
        red_skip_normal_movement = self.morale_config is not None and self.red_state.is_routing
        # Capture Red's pre-movement position to track actual displacement for
        # fatigue calculation (avoids hard-coded assumptions about behaviour).
        _red_x_before = self.red.x
        _red_y_before = self.red.y
        if self.red_policy is not None:
            red_obs = self._get_red_obs()
            red_action, _ = self.red_policy.predict(red_obs, deterministic=False)
            red_action = np.asarray(red_action, dtype=np.float32)
            red_fire_cmd = float(np.clip(red_action[2], 0.0, 1.0))
            if not red_skip_normal_movement:
                self._step_red_policy(red_action)
        else:
            red_fire_cmd = self._red_fire_intensity()
            if not red_skip_normal_movement:
                self._step_red()

        # --- Combat resolution (simultaneous) ---
        self.blue_state.reset_step_accumulators()
        self.red_state.reset_step_accumulators()

        # Logistics — apply ammo consumption and modifiers before damage
        # computation.  effective_*_cmd is what is actually fed into the
        # damage formula; the raw requested intensities are preserved for
        # fatigue tracking.
        effective_fire_cmd = fire_cmd
        effective_red_fire_cmd = red_fire_cmd
        if self.enable_logistics:
            lc = self.logistics_config
            if self.blue_logistics is not None:
                effective_fire_cmd = consume_ammo(
                    self.blue_logistics, fire_cmd, lc
                )
                effective_fire_cmd *= get_ammo_modifier(self.blue_logistics, lc)
                effective_fire_cmd *= get_fatigue_accuracy_modifier(
                    self.blue_logistics, lc
                )
            if self.red_logistics is not None:
                effective_red_fire_cmd = consume_ammo(
                    self.red_logistics, red_fire_cmd, lc
                )
                effective_red_fire_cmd *= get_ammo_modifier(self.red_logistics, lc)
                effective_red_fire_cmd *= get_fatigue_accuracy_modifier(
                    self.red_logistics, lc
                )

        # Weather accuracy penalty — applied after logistics modifiers.
        # Affects both sides equally (rain/fog/night degrades all musketry).
        if self.enable_weather and self.weather_state is not None:
            weather_acc = get_accuracy_modifier(self.weather_state)
            effective_fire_cmd *= weather_acc
            effective_red_fire_cmd *= weather_acc

        raw_b2r = compute_fire_damage(self.blue, self.red, intensity=effective_fire_cmd)
        raw_r2b = compute_fire_damage(self.red, self.blue, intensity=effective_red_fire_cmd)

        # Apply terrain cover at each target's position
        raw_b2r = self.terrain.apply_cover_modifier(self.red.x, self.red.y, raw_b2r)
        raw_r2b = self.terrain.apply_cover_modifier(self.blue.x, self.blue.y, raw_r2b)

        # Apply casualties
        dmg_b2r = apply_casualties(self.red, self.red_state, raw_b2r)
        dmg_r2b = apply_casualties(self.blue, self.blue_state, raw_r2b)

        # Wagon targeting — when logistics are active, each side's fire can
        # also damage the enemy supply wagon if it is within range and in the
        # fire arc.  This makes wagons emergent high-priority targets.
        if self.enable_logistics:
            self._apply_wagon_damage(
                shooter=self.red,
                wagon=self.blue_wagon,
                intensity=effective_red_fire_cmd,
            )
            self._apply_wagon_damage(
                shooter=self.blue,
                wagon=self.red_wagon,
                intensity=effective_fire_cmd,
            )

        # Formation morale resilience — scale down the accumulated damage that
        # feeds into the morale check.  A higher morale_resilience means the
        # unit absorbs the same strength loss with less morale penalty (SQUARE
        # = 1.5×, i.e. morale hit is 2/3 of what unformed troops would suffer).
        # This only affects morale propagation; strength damage is unchanged.
        if self.enable_formations:
            blue_resilience = get_attributes(Formation(self.blue.formation)).morale_resilience
            red_resilience = get_attributes(Formation(self.red.formation)).morale_resilience
            self.blue_state.accumulated_damage /= blue_resilience
            self.red_state.accumulated_damage /= red_resilience

        # Compute enemy distance for distance-based morale recovery
        dx = self.blue.x - self.red.x
        dy = self.blue.y - self.red.y
        enemy_dist = float(math.sqrt(dx * dx + dy * dy))

        # Morale checks — use enhanced update_morale when a MoraleConfig is
        # provided, otherwise fall back to the basic morale_check.
        if self.morale_config is not None:
            mc = self.morale_config
            blue_flank = compute_flank_stressor(
                self.red.x, self.red.y,
                self.blue.x, self.blue.y, self.blue.theta,
                dmg_r2b,
            )
            red_flank = compute_flank_stressor(
                self.blue.x, self.blue.y,
                self.red.x, self.red.y, self.red.theta,
                dmg_b2r,
            )
            # Weather morale stressor — added to flanking penalty so that
            # update_morale applies it through the same deduction path.
            if self.enable_weather and self.weather_state is not None:
                weather_ms = get_morale_stressor(self.weather_state)
                blue_flank += weather_ms
                red_flank += weather_ms
            update_morale(
                self.blue_state,
                enemy_dist=enemy_dist,
                config=mc,
                flank_penalty=blue_flank,
                rng=self.np_random,
            )
            update_morale(
                self.red_state,
                enemy_dist=enemy_dist,
                config=mc,
                flank_penalty=red_flank,
                rng=self.np_random,
            )
        else:
            # Weather morale stressor — added to accumulated_damage before morale_check
            # so that it enters the same deduction pipeline (recovery + clamping)
            # as the MoraleConfig path.
            if self.enable_weather and self.weather_state is not None:
                weather_ms = get_morale_stressor(self.weather_state)
                if weather_ms > 0.0:
                    self.blue_state.accumulated_damage += weather_ms
                    self.red_state.accumulated_damage += weather_ms
            morale_check(self.blue_state, rng=self.np_random)
            morale_check(self.red_state, rng=self.np_random)

        # Sync Battalion flags from CombatState
        self.blue.morale = self.blue_state.morale
        self.red.morale  = self.red_state.morale
        self.blue.routed = self.blue_state.is_routing
        self.red.routed  = self.red_state.is_routing

        # --- Logistics update (fatigue, food, resupply) ---
        if self.enable_logistics:
            lc = self.logistics_config
            # Determine movement activity from raw move commands
            blue_moved = abs(move_cmd) > 0.01
            # Determine Red movement from actual position change (works for both
            # scripted and policy-driven opponents at all curriculum levels).
            red_moved = (
                abs(self.red.x - _red_x_before) > 1e-4
                or abs(self.red.y - _red_y_before) > 1e-4
            )
            blue_fired = effective_fire_cmd > 0.0
            red_fired = effective_red_fire_cmd > 0.0
            if self.blue_logistics is not None:
                update_fatigue(self.blue_logistics, blue_moved, blue_fired, lc)
                consume_food(self.blue_logistics, lc)
                # Resupply only available when halted near a friendly wagon
                # (not moving and not firing) — models historical supply practice.
                if self.blue_wagon is not None and not blue_moved and not blue_fired:
                    check_resupply(
                        self.blue_logistics,
                        self.blue.x, self.blue.y,
                        self.blue_wagon, lc,
                    )
            if self.red_logistics is not None:
                update_fatigue(self.red_logistics, red_moved, red_fired, lc)
                consume_food(self.red_logistics, lc)
                # Resupply only available when halted near a friendly wagon
                if self.red_wagon is not None and not red_moved and not red_fired:
                    check_resupply(
                        self.red_logistics,
                        self.red.x, self.red.y,
                        self.red_wagon, lc,
                    )

        # --- Post-morale rout movement (morale_config mode only) ---
        # Applied *after* the morale update so that units rout on the same
        # step that routing is triggered, not just on subsequent steps.  Also
        # handles already-routing units whose pre-step movement was suppressed
        # above.
        if self.morale_config is not None:
            if self.blue_state.is_routing:
                vx, vy = rout_velocity(
                    self.blue.x, self.blue.y,
                    self.red.x, self.red.y,
                    self.blue.max_speed,
                    self.morale_config,
                )
                self.blue.move(vx, vy, dt=DT)
                self.blue.x = float(np.clip(self.blue.x, 0.0, self.map_width))
                self.blue.y = float(np.clip(self.blue.y, 0.0, self.map_height))
            if self.red_state.is_routing:
                self._step_routing_red()

        self._step_count += 1

        # --- Weather progression (time-of-day advancement) ---
        if self.enable_weather and self.weather_state is not None:
            step_weather(self.weather_state, self.weather_config)

        # --- Termination ---
        blue_done = (
            self.blue_state.is_routing or self.blue.strength <= DESTROYED_THRESHOLD
        )
        red_done = (
            self.red_state.is_routing or self.red.strength <= DESTROYED_THRESHOLD
        )
        terminated = blue_done or red_done
        truncated  = (not terminated) and (self._step_count >= self.max_steps)

        # --- Reward ---
        blue_won = red_done and not blue_done
        blue_lost = blue_done and not red_done
        reward_comps = compute_reward(
            dmg_b2r=dmg_b2r,
            dmg_r2b=dmg_r2b,
            blue_strength=float(self.blue.strength),
            blue_won=blue_won,
            blue_lost=blue_lost,
            weights=self.reward_weights,
            enemy_routed=self.red_state.is_routing,
            own_routing=self.blue_state.is_routing,
        )

        info: dict = {
            "blue_damage_dealt": float(dmg_b2r),
            "red_damage_dealt":  float(dmg_r2b),
            "blue_routed":       self.blue_state.is_routing,
            "red_routed":        self.red_state.is_routing,
            "step_count":        self._step_count,
            **reward_comps.as_dict(),
        }

        # Add logistics info when the system is active
        if self.enable_logistics:
            if self.blue_logistics is not None:
                info["blue_ammo"]    = float(self.blue_logistics.ammo)
                info["blue_food"]    = float(self.blue_logistics.food)
                info["blue_fatigue"] = float(self.blue_logistics.fatigue)
            if self.red_logistics is not None:
                info["red_ammo"]    = float(self.red_logistics.ammo)
                info["red_food"]    = float(self.red_logistics.food)
                info["red_fatigue"] = float(self.red_logistics.fatigue)

        # Add weather info when the system is active
        if self.enable_weather and self.weather_state is not None:
            info["weather_condition"] = int(self.weather_state.condition)
            info["time_of_day"]       = int(self.weather_state.time_of_day)
            info["visibility_fraction"] = float(
                get_visibility_fraction(self.weather_state)
            )

        return self._get_obs(), reward_comps.total, terminated, truncated, info
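The action convention can be checked in isolation. `clip_action` below is a name invented for this sketch; it mirrors the clipping applied at the top of `step`: move and rotate are bipolar, fire intensity is non-negative, and out-of-range values are silently clamped rather than rejected.

```python
import numpy as np

def clip_action(action):
    # Mirrors step(): move and rotate are clamped to [-1, 1],
    # fire intensity to [0, 1].
    a = np.asarray(action, dtype=np.float32)
    return (
        float(np.clip(a[0], -1.0, 1.0)),
        float(np.clip(a[1], -1.0, 1.0)),
        float(np.clip(a[2], 0.0, 1.0)),
    )
```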

    def render(self) -> None:
        """Render the current environment state.

        When ``render_mode="human"`` a pygame window is opened on the first
        call and kept alive for subsequent calls.  The window is closed by
        :meth:`close`.  When ``render_mode`` is ``None`` this is a no-op.
        """
        if self.render_mode != "human":
            return
        if self.blue is None or self.red is None:
            return  # nothing to render before the first reset()

        if self._renderer is None:
            from envs.rendering.renderer import BattalionRenderer  # noqa: PLC0415
            self._renderer = BattalionRenderer(self.map_width, self.map_height)

        self._renderer.render_frame(
            self.blue,
            self.red,
            terrain=self.terrain,
            step=self._step_count,
        )

    def close(self) -> None:
        """Clean up resources, including the pygame window if open."""
        if self._renderer is not None:
            self._renderer.close()
            self._renderer = None

    def set_red_policy(self, policy: Optional[RedPolicy]) -> None:
        """Swap the Red opponent policy at runtime.

        Parameters
        ----------
        policy:
            New policy to use for Red, or ``None`` to revert to the
            scripted opponent.
        """
        self.red_policy = policy
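Any object with a `predict(obs, deterministic=...)` method qualifies as a Red policy; it does not have to be a Stable-Baselines3 model. A minimal random policy (a hypothetical class, written for this example) might look like:

```python
import numpy as np

class RandomRedPolicy:
    """Duck-typed opponent: uniform random [move, rotate, fire]."""

    def __init__(self, seed=0):
        self.rng = np.random.default_rng(seed)

    def predict(self, obs, deterministic=False):
        action = np.array(
            [
                self.rng.uniform(-1.0, 1.0),  # move
                self.rng.uniform(-1.0, 1.0),  # rotate
                self.rng.uniform(0.0, 1.0),   # fire
            ],
            dtype=np.float32,
        )
        return action, None  # (action, state), matching the SB3 signature

# env.set_red_policy(RandomRedPolicy())  # swap in at runtime
# env.set_red_policy(None)               # revert to the scripted opponent
```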

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _norm_elevation(self, x: float, y: float) -> float:
        """Return terrain elevation at ``(x, y)`` normalised to ``[0, 1]``.

        Uses :attr:`~envs.sim.terrain_engine.TerrainEngine.max_elevation` as
        the normalisation factor (the same value used internally by
        :meth:`~envs.sim.terrain.TerrainMap.get_speed_modifier`).
        Returns ``0.0`` on flat terrain.
        """
        max_e = self.terrain.max_elevation
        if max_e <= 0.0:
            return 0.0
        elev = self.terrain.get_elevation(x, y)
        return float(np.clip(elev / max_e, 0.0, 1.0))
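Restated as a free function for illustration (`norm_elevation` is a name invented here, not part of the API):

```python
import numpy as np

def norm_elevation(elev, max_elev):
    # Flat terrain (max_elev <= 0) maps to 0.0; otherwise scale and clip.
    if max_elev <= 0.0:
        return 0.0
    return float(np.clip(elev / max_elev, 0.0, 1.0))
```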

    def _get_obs(self) -> np.ndarray:
        """Build and return the normalised observation.

        Returns a 17-dimensional array when formations, logistics, and weather
        are all disabled.  Two extra dims are added when *enable_formations* is
        ``True``; three more when *enable_logistics* is ``True``; two more when
        *enable_weather* is ``True``.
        """
        b = self.blue
        r = self.red

        dx = r.x - b.x
        dy = r.y - b.y
        dist = math.sqrt(dx ** 2 + dy ** 2)
        bearing = math.atan2(dy, dx)

        # Terrain features
        blue_elev = self._norm_elevation(b.x, b.y)
        blue_cover = self.terrain.get_cover(b.x, b.y)
        red_elev = self._norm_elevation(r.x, r.y)
        red_cover = self.terrain.get_cover(r.x, r.y)

        # LOS — additionally blocked when the enemy is beyond the effective
        # weather visibility range.
        terrain_los = self.terrain.bresenham_los(b.x, b.y, r.x, r.y)
        if self.enable_weather and self.weather_state is not None:
            vis_range = get_effective_visibility_range(
                self.weather_state, self.weather_config
            )
            los = 1.0 if (terrain_los and dist <= vis_range) else 0.0
        else:
            los = 1.0 if terrain_los else 0.0

        base = [
            b.x / self.map_width,                    # [0] blue x norm
            b.y / self.map_height,                   # [1] blue y norm
            math.cos(b.theta),                       # [2] cos(blue θ)
            math.sin(b.theta),                       # [3] sin(blue θ)
            b.strength,                              # [4] blue strength
            b.morale,                                # [5] blue morale
            min(dist / self.map_diagonal, 1.0),      # [6] dist norm
            math.cos(bearing),                       # [7] cos(bearing)
            math.sin(bearing),                       # [8] sin(bearing)
            r.strength,                              # [9] red strength
            r.morale,                                # [10] red morale
            min(self._step_count / self.max_steps, 1.0),  # [11] step norm
            blue_elev,                               # [12] blue elevation
            blue_cover,                              # [13] blue cover
            red_elev,                                # [14] red elevation
            red_cover,                               # [15] red cover
            los,                                     # [16] line-of-sight
        ]
        if self.enable_formations:
            base.append(b.formation / float(NUM_FORMATIONS - 1))  # [17] formation norm
            base.append(1.0 if b.target_formation is not None else 0.0)  # [18] transitioning
        if self.enable_logistics:
            ls = self.blue_logistics
            base.append(ls.ammo if ls is not None else 1.0)      # [N+0] ammo
            base.append(ls.food if ls is not None else 1.0)      # [N+1] food
            base.append(ls.fatigue if ls is not None else 0.0)   # [N+2] fatigue
        if self.enable_weather and self.weather_state is not None:
            base.append(
                self.weather_state.condition.value / float(NUM_CONDITIONS - 1)
            )  # [M+0] weather condition id
            base.append(get_visibility_fraction(self.weather_state))  # [M+1] visibility
        obs = np.array(base, dtype=np.float32)
        # Clip to declared bounds to guard against floating-point drift
        return np.clip(obs, self.observation_space.low, self.observation_space.high)
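The relative-position features (indices 6-8) encode distance and bearing; the cos/sin pair avoids the discontinuity a raw angle would have at &plusmn;&pi;. A standalone sketch of that encoding (the function name is invented here):

```python
import math

def relative_features(bx, by, rx, ry, diagonal):
    # Distance normalised by the map diagonal, bearing as a cos/sin pair.
    dx, dy = rx - bx, ry - by
    dist = math.hypot(dx, dy)
    bearing = math.atan2(dy, dx)
    return (min(dist / diagonal, 1.0), math.cos(bearing), math.sin(bearing))
```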

    def _get_red_obs(self) -> np.ndarray:
        """Build the observation from **Red's** perspective.

        The observation is symmetric to Blue's: Red sees itself where Blue
        normally sits and Blue where Red normally sits.  This means a frozen
        Red policy trained as Blue sees the same observation schema.
        When formations are enabled the two extra formation dims reflect Red's
        formation state.  When weather is enabled Red also receives weather
        features.
        """
        r = self.red
        b = self.blue

        dx = b.x - r.x
        dy = b.y - r.y
        dist = math.sqrt(dx ** 2 + dy ** 2)
        bearing = math.atan2(dy, dx)

        # Terrain features from Red's perspective
        red_elev = self._norm_elevation(r.x, r.y)
        red_cover = self.terrain.get_cover(r.x, r.y)
        blue_elev = self._norm_elevation(b.x, b.y)
        blue_cover = self.terrain.get_cover(b.x, b.y)

        terrain_los = self.terrain.bresenham_los(r.x, r.y, b.x, b.y)
        if self.enable_weather and self.weather_state is not None:
            vis_range = get_effective_visibility_range(
                self.weather_state, self.weather_config
            )
            los = 1.0 if (terrain_los and dist <= vis_range) else 0.0
        else:
            los = 1.0 if terrain_los else 0.0

        base = [
            r.x / self.map_width,                    # [0] red x norm
            r.y / self.map_height,                   # [1] red y norm
            math.cos(r.theta),                       # [2] cos(red θ)
            math.sin(r.theta),                       # [3] sin(red θ)
            r.strength,                              # [4] red strength
            r.morale,                                # [5] red morale
            min(dist / self.map_diagonal, 1.0),      # [6] dist norm
            math.cos(bearing),                       # [7] cos(bearing to blue)
            math.sin(bearing),                       # [8] sin(bearing to blue)
            b.strength,                              # [9] blue strength
            b.morale,                                # [10] blue morale
            min(self._step_count / self.max_steps, 1.0),  # [11] step norm
            red_elev,                                # [12] red elevation
            red_cover,                               # [13] red cover
            blue_elev,                               # [14] blue elevation
            blue_cover,                              # [15] blue cover
            los,                                     # [16] line-of-sight
        ]
        if self.enable_formations:
            base.append(r.formation / float(NUM_FORMATIONS - 1))  # [17] formation norm
            base.append(1.0 if r.target_formation is not None else 0.0)  # [18] transitioning
        if self.enable_logistics:
            ls = self.red_logistics
            base.append(ls.ammo if ls is not None else 1.0)      # [N+0] ammo
            base.append(ls.food if ls is not None else 1.0)      # [N+1] food
            base.append(ls.fatigue if ls is not None else 0.0)   # [N+2] fatigue
        if self.enable_weather and self.weather_state is not None:
            base.append(
                self.weather_state.condition.value / float(NUM_CONDITIONS - 1)
            )  # [M+0] weather condition id
            base.append(get_visibility_fraction(self.weather_state))  # [M+1] visibility
        obs = np.array(base, dtype=np.float32)
        return np.clip(obs, self.observation_space.low, self.observation_space.high)

    def _step_red_policy(self, action: np.ndarray) -> None:
        """Apply a policy action to the Red battalion (movement only).

        Movement is applied using the same physics as Blue; fire intensity
        is extracted from ``action[2]`` and used in :meth:`step` directly.

        Parameters
        ----------
        action:
            Array of shape ``(3,)`` or ``(4,)``: ``[move, rotate, fire, ...]``.
        """
        r = self.red
        move_cmd   = float(np.clip(action[0], -1.0, 1.0))
        rotate_cmd = float(np.clip(action[1], -1.0, 1.0))

        r.rotate(rotate_cmd * r.max_turn_rate)
        speed_mod = self.terrain.get_speed_modifier(r.x, r.y, self.hill_speed_factor)
        if self.enable_formations:
            speed_mod *= get_attributes(Formation(r.formation)).speed_modifier
        if self.enable_logistics and self.red_logistics is not None:
            speed_mod *= get_fatigue_speed_modifier(self.red_logistics, self.logistics_config)
        if self.enable_weather and self.weather_state is not None:
            speed_mod *= get_speed_modifier(self.weather_state)
        vx = math.cos(r.theta) * move_cmd * r.max_speed * speed_mod
        vy = math.sin(r.theta) * move_cmd * r.max_speed * speed_mod
        r.move(vx, vy, dt=DT)
        r.x = float(np.clip(r.x, 0.0, self.map_width))
        r.y = float(np.clip(r.y, 0.0, self.map_height))

    def _step_red(self) -> None:
        """Scripted Red opponent: behaviour depends on *curriculum_level*.

        This method handles Red's **movement only**.  Red's fire is resolved
        in :meth:`step` using :meth:`_red_fire_intensity` so that damage
        computation remains centralised (simultaneous with Blue's fire).

        ========  ==========================================
        Level     Movement behaviour
        ========  ==========================================
        1         Stationary — Red does not move.
        2         Turning only — Red faces Blue; no advance.
        3–5       Red turns and advances to within 80 % of fire range.
        ========  ==========================================
        """
        level = self.curriculum_level

        # Level 1: Red stands completely still.
        if level == 1:
            return

        r = self.red
        b = self.blue

        dx = b.x - r.x
        dy = b.y - r.y
        target_angle = math.atan2(dy, dx)

        # Rotate toward Blue via the shortest arc (levels 2–5).
        delta = (target_angle - r.theta + math.pi) % (2 * math.pi) - math.pi
        r.rotate(delta)

        # Level 2: turn only, no advance.
        if level == 2:
            return

        # Advance if outside 80 % of fire range (levels 3–5).
        dist = math.sqrt(dx ** 2 + dy ** 2)
        if dist > r.fire_range * 0.8:
            speed_mod = self.terrain.get_speed_modifier(
                r.x, r.y, self.hill_speed_factor
            )
            if self.enable_formations:
                speed_mod *= get_attributes(Formation(r.formation)).speed_modifier
            if self.enable_logistics and self.red_logistics is not None:
                speed_mod *= get_fatigue_speed_modifier(
                    self.red_logistics, self.logistics_config
                )
            if self.enable_weather and self.weather_state is not None:
                speed_mod *= get_speed_modifier(self.weather_state)
            vx = math.cos(r.theta) * r.max_speed * speed_mod
            vy = math.sin(r.theta) * r.max_speed * speed_mod
            r.move(vx, vy, dt=DT)
            r.x = float(np.clip(r.x, 0.0, self.map_width))
            r.y = float(np.clip(r.y, 0.0, self.map_height))
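The shortest-arc turn used above is worth isolating, since naive subtraction of headings can rotate the long way round. A minimal sketch:

```python
import math

def shortest_arc(target, theta):
    # Wrap the signed difference into [-pi, pi) so rotation always
    # takes the shorter direction, as in the scripted opponent above.
    return (target - theta + math.pi) % (2 * math.pi) - math.pi
```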

    def _red_fire_intensity(self) -> float:
        """Return the fire intensity Red uses this step, based on curriculum level.

        ========  ==========================
        Level     Red fire intensity
        ========  ==========================
        1–3       0.0  (Red does not fire)
        4         0.5  (50 % intensity)
        5         1.0  (full intensity)
        ========  ==========================
        """
        level = self.curriculum_level
        if level <= 3:
            return 0.0
        if level == 4:
            return 0.5
        return 1.0

    def _step_routing_red(self) -> None:
        """Apply forced rout movement to the Red battalion.

        Called when Red is in a routing state instead of the normal scripted
        or policy movement.  Red flees directly away from Blue.  Rout
        movement is only applied when a :class:`MoraleConfig` is set on the
        environment; otherwise this method is a no-op so that the episode
        terminates normally.
        """
        if self.morale_config is None:
            return
        r = self.red
        b = self.blue
        vx, vy = rout_velocity(
            r.x, r.y,
            b.x, b.y,
            r.max_speed,
            self.morale_config,
        )
        r.move(vx, vy, dt=DT)
        r.x = float(np.clip(r.x, 0.0, self.map_width))
        r.y = float(np.clip(r.y, 0.0, self.map_height))
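The actual `rout_velocity` lives in the morale module and also honours `MoraleConfig` (e.g. any rout-speed scaling it defines); a simplified stand-in that captures only the flee-directly-away geometry:

```python
import math

def flee_velocity(x, y, enemy_x, enemy_y, max_speed):
    # Unit vector pointing away from the enemy, scaled to max_speed.
    away = math.atan2(y - enemy_y, x - enemy_x)
    return max_speed * math.cos(away), max_speed * math.sin(away)
```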

    def _apply_wagon_damage(
        self,
        shooter: Battalion,
        wagon: Optional[SupplyWagon],
        intensity: float,
    ) -> None:
        """Apply fire damage from *shooter* to an enemy *wagon* if in range.

        Uses a simple linear range-factor scaled by ``BASE_FIRE_DAMAGE`` and
        shooter strength, without formation modifiers (wagons are unformed
        logistics assets, not fighting units).  The calculation mirrors the
        legacy path of :func:`~envs.sim.combat.compute_fire_damage`.

        Parameters
        ----------
        shooter:
            The firing battalion.
        wagon:
            The target supply wagon.  If ``None`` or already destroyed, this
            method is a no-op.
        intensity:
            Effective fire intensity (after ammo and fatigue modifiers).
        """
        if wagon is None or not wagon.is_alive:
            return
        if intensity <= 0.0:
            return

        dx = wagon.x - shooter.x
        dy = wagon.y - shooter.y
        dist = math.sqrt(dx * dx + dy * dy)

        # Range check
        if dist > shooter.fire_range or shooter.fire_range <= 0.0:
            return

        # Fire-arc check (same formula as in_fire_arc / can_fire_at)
        angle_to_wagon = math.atan2(dy, dx)
        angle_diff = abs(
            (angle_to_wagon - shooter.theta + math.pi) % (2 * math.pi) - math.pi
        )
        if angle_diff >= shooter.fire_arc:
            return

        # Linear range falloff
        rf = 1.0 - (dist / shooter.fire_range)
        damage = BASE_FIRE_DAMAGE * float(intensity) * rf * max(0.0, shooter.strength)
        wagon.take_damage(damage)

    # ------------------------------------------------------------------
    # Formation helpers (only called when enable_formations is True)
    # ------------------------------------------------------------------

    def _request_formation_change(self, unit: Battalion, desired_idx: int) -> None:
        """Request a formation change for *unit*.

        Ignored if the unit is already in or transitioning to *desired_idx*.
        Starts a transition by setting ``target_formation`` and
        ``formation_transition_steps`` on the unit.

        Parameters
        ----------
        unit:
            The :class:`~envs.sim.battalion.Battalion` whose formation to change.
        desired_idx:
            Target :class:`~envs.sim.formations.Formation` value (integer).
        """
        desired_idx = int(np.clip(desired_idx, 0, NUM_FORMATIONS - 1))
        # Already there or already heading there — no-op
        if desired_idx == unit.formation:
            return
        if unit.target_formation is not None and unit.target_formation == desired_idx:
            return
        # Start transition
        steps = get_transition_steps(
            Formation(unit.formation), Formation(desired_idx)
        )
        unit.target_formation = desired_idx
        unit.formation_transition_steps = steps

    @staticmethod
    def _advance_formation_transition(unit: Battalion) -> None:
        """Advance *unit*'s formation transition by one step.

        Calls :func:`~envs.sim.formations.compute_transition_state` and
        writes the result back to the unit's formation fields.

        Parameters
        ----------
        unit:
            The :class:`~envs.sim.battalion.Battalion` to advance.
        """
        new_formation, new_target, new_steps = compute_transition_state(
            Formation(unit.formation),
            Formation(unit.target_formation) if unit.target_formation is not None else None,
            unit.formation_transition_steps,
        )
        unit.formation = int(new_formation)
        unit.target_formation = int(new_target) if new_target is not None else None
        unit.formation_transition_steps = new_steps
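The fire-arc test used by `_apply_wagon_damage` (and, per the comment there, by `in_fire_arc` / `can_fire_at`) boils down to a wrapped angular difference. A standalone sketch of that formula — the real helpers live in `envs.sim` and may differ in detail:

```python
import math

def in_fire_arc(shooter_x, shooter_y, shooter_theta, fire_arc,
                target_x, target_y):
    """Wrapped angular-difference test (sketch of the wagon-damage check).

    Returns True when the target lies within +/- fire_arc radians of the
    shooter's heading, handling angle wrap-around correctly.
    """
    angle_to_target = math.atan2(target_y - shooter_y, target_x - shooter_x)
    # Fold the difference into [-pi, pi) before taking its magnitude
    diff = abs((angle_to_target - shooter_theta + math.pi) % (2 * math.pi)
               - math.pi)
    return diff < fire_arc
```

The modulo fold is what makes the check robust near ±π: a shooter heading 3.0 rad still sees a target at bearing −3.0 rad as only ~0.28 rad off-axis.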

close()

Clean up resources, including the pygame window if open.

Source code in envs/battalion_env.py
def close(self) -> None:
    """Clean up resources, including the pygame window if open."""
    if self._renderer is not None:
        self._renderer.close()
        self._renderer = None

render()

Render the current environment state.

When render_mode="human" a pygame window is opened on the first call and kept alive for subsequent calls. The window is closed by close(). When render_mode is None this is a no-op.

Source code in envs/battalion_env.py
def render(self) -> None:
    """Render the current environment state.

    When ``render_mode="human"`` a pygame window is opened on the first
    call and kept alive for subsequent calls.  The window is closed by
    :meth:`close`.  When ``render_mode`` is ``None`` this is a no-op.
    """
    if self.render_mode != "human":
        return
    if self.blue is None or self.red is None:
        return  # nothing to render before the first reset()

    if self._renderer is None:
        from envs.rendering.renderer import BattalionRenderer  # noqa: PLC0415
        self._renderer = BattalionRenderer(self.map_width, self.map_height)

    self._renderer.render_frame(
        self.blue,
        self.red,
        terrain=self.terrain,
        step=self._step_count,
    )

reset(*, seed=None, options=None)

Reset the environment and return the initial observation.

When randomize_terrain is True (the default when no fixed terrain map is passed to __init__), a new procedural terrain is generated from the seeded RNG on every call. Passing the same seed therefore always produces the same terrain layout and unit positions.

Blue spawns in the western half of the map facing roughly east; Red spawns in the eastern half facing roughly west.

Source code in envs/battalion_env.py
def reset(
    self,
    *,
    seed: Optional[int] = None,
    options: Optional[dict] = None,
) -> tuple[np.ndarray, dict]:
    """Reset the environment and return the initial observation.

    When *randomize_terrain* is ``True`` (the default when no fixed
    terrain map is passed to ``__init__``), a new procedural terrain
    is generated from the seeded RNG on every call.  Passing the same
    *seed* therefore always produces the same terrain layout and unit
    positions.

    Blue spawns in the western half of the map facing roughly east;
    Red spawns in the eastern half facing roughly west.
    """
    super().reset(seed=seed)
    rng = self.np_random

    # Generate a fresh terrain map from the seeded RNG each episode.
    if self.randomize_terrain:
        self.terrain = TerrainEngine.generate_random(
            rng=rng,
            width=self.map_width,
            height=self.map_height,
        )
    elif self._supplied_terrain is not None:
        self.terrain = TerrainEngine.from_terrain_map(self._supplied_terrain)

    # Blue: western quarter, roughly eastward
    bx = float(rng.uniform(0.1 * self.map_width, 0.4 * self.map_width))
    by = float(rng.uniform(0.1 * self.map_height, 0.9 * self.map_height))
    b_theta = float(rng.uniform(-math.pi / 4, math.pi / 4))

    # Red: eastern quarter, roughly westward
    rx = float(rng.uniform(0.6 * self.map_width, 0.9 * self.map_width))
    ry = float(rng.uniform(0.1 * self.map_height, 0.9 * self.map_height))
    r_theta = float(math.pi + rng.uniform(-math.pi / 4, math.pi / 4))

    self.blue = Battalion(x=bx, y=by, theta=b_theta, strength=1.0, team=0)
    self.red = Battalion(x=rx, y=ry, theta=r_theta, strength=1.0, team=1)
    self.blue_state = CombatState()
    self.red_state = CombatState()
    self._step_count = 0

    # --- Logistics initialisation ---
    if self.enable_logistics:
        cfg = self.logistics_config
        self.blue_logistics = LogisticsState(
            ammo=cfg.initial_ammo,
            food=cfg.initial_food,
            fatigue=0.0,
        )
        self.red_logistics = LogisticsState(
            ammo=cfg.initial_ammo,
            food=cfg.initial_food,
            fatigue=0.0,
        )
        # Blue supply wagon: behind Blue (further west), same y as Blue
        self.blue_wagon = SupplyWagon(
            x=max(0.0, bx - 150.0),
            y=by,
            team=0,
            strength=cfg.wagon_max_strength,
        )
        # Red supply wagon: behind Red (further east), same y as Red
        self.red_wagon = SupplyWagon(
            x=min(self.map_width, rx + 150.0),
            y=ry,
            team=1,
            strength=cfg.wagon_max_strength,
        )
    else:
        self.blue_logistics = None
        self.red_logistics = None
        self.blue_wagon = None
        self.red_wagon = None

    # --- Weather initialisation ---
    if self.enable_weather:
        self.weather_state = sample_weather(rng, self.weather_config)
    else:
        self.weather_state = None

    return self._get_obs(), {}
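The determinism promise above can be illustrated without the environment itself: the spawn placement in reset() reduces to draws from the seeded Generator, so reseeding reproduces the layout. A standalone sketch — the map size and helper name here are illustrative, not the package defaults:

```python
import numpy as np

MAP_W, MAP_H = 1000.0, 1000.0  # illustrative map size only

def sample_blue_spawn(rng: np.random.Generator) -> tuple[float, float, float]:
    """Mirror of the western-quarter Blue spawn sampling in reset()."""
    x = float(rng.uniform(0.1 * MAP_W, 0.4 * MAP_W))
    y = float(rng.uniform(0.1 * MAP_H, 0.9 * MAP_H))
    theta = float(rng.uniform(-np.pi / 4, np.pi / 4))  # roughly eastward
    return x, y, theta

# Same seed, same spawn — as with env.reset(seed=42)
a = sample_blue_spawn(np.random.default_rng(42))
b = sample_blue_spawn(np.random.default_rng(42))
assert a == b
```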

set_red_policy(policy)

Swap the Red opponent policy at runtime.

Parameters:

Name Type Description Default
policy Optional[RedPolicy]

New policy to use for Red, or None to revert to the scripted opponent.

required
Source code in envs/battalion_env.py
def set_red_policy(self, policy: Optional[RedPolicy]) -> None:
    """Swap the Red opponent policy at runtime.

    Parameters
    ----------
    policy:
        New policy to use for Red, or ``None`` to revert to the
        scripted opponent.
    """
    self.red_policy = policy
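Any object exposing a Stable-Baselines3-style predict method works as a Red policy — it need not be a trained model. A minimal sketch (class name hypothetical):

```python
import numpy as np

class HoldFirePolicy:
    """Minimal object satisfying the RedPolicy protocol (sketch).

    predict() mirrors the Stable-Baselines3 signature, returning an
    (action, state) pair; here Red never moves, turns, or fires.
    """

    def predict(self, obs, deterministic=False):
        # [move, rotate, fire] — all zeros keeps Red stationary
        return np.zeros(3, dtype=np.float32), None
```

Usage would then be `env.set_red_policy(HoldFirePolicy())` to swap in at runtime, and `env.set_red_policy(None)` to revert to the scripted opponent.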

step(action)

Advance the environment by one step.

Parameters:

Name Type Description Default
action ndarray

Array of shape (3,): [move, rotate, fire].

required

Returns:

Type Description
(observation, reward, terminated, truncated, info)
Source code in envs/battalion_env.py
def step(
    self, action: np.ndarray
) -> tuple[np.ndarray, float, bool, bool, dict]:
    """Advance the environment by one step.

    Parameters
    ----------
    action:
        Array of shape ``(3,)``: ``[move, rotate, fire]``.

    Returns
    -------
    observation, reward, terminated, truncated, info
    """
    if self.blue is None or self.red is None:
        raise RuntimeError("Call reset() before step().")

    action = np.asarray(action, dtype=np.float32)
    move_cmd   = float(np.clip(action[0], -1.0, 1.0))
    rotate_cmd = float(np.clip(action[1], -1.0, 1.0))
    fire_cmd   = float(np.clip(action[2],  0.0, 1.0))

    # --- Formation action (optional 4th dimension) ---
    if self.enable_formations and len(action) >= 4:
        desired_formation_idx = int(np.clip(np.round(action[3]), 0, NUM_FORMATIONS - 1))
        self._request_formation_change(self.blue, desired_formation_idx)

    # --- Advance formation transitions for both units ---
    if self.enable_formations:
        self._advance_formation_transition(self.blue)
        self._advance_formation_transition(self.red)

    # --- Apply agent action to Blue ---
    # Normal movement when not routing, OR when morale_config is not set
    # (legacy path: agent retains control even while CombatState.is_routing).
    # When morale_config is set and the unit is already routing, movement is
    # suppressed here and overridden by rout_velocity() after morale update.
    if not (self.morale_config is not None and self.blue_state.is_routing):
        # Rotation (Battalion.rotate clamps to max_turn_rate internally)
        self.blue.rotate(rotate_cmd * self.blue.max_turn_rate)
        # Forward/backward movement along current heading, slowed on hills
        speed_mod = self.terrain.get_speed_modifier(
            self.blue.x, self.blue.y, self.hill_speed_factor
        )
        # Formation speed modifier applied when formations are active
        if self.enable_formations:
            speed_mod *= get_attributes(Formation(self.blue.formation)).speed_modifier
        # Fatigue speed penalty applied when logistics are active
        if self.enable_logistics and self.blue_logistics is not None:
            speed_mod *= get_fatigue_speed_modifier(
                self.blue_logistics, self.logistics_config
            )
        # Weather speed penalty applied when weather is active
        if self.enable_weather and self.weather_state is not None:
            speed_mod *= get_speed_modifier(self.weather_state)
        vx = math.cos(self.blue.theta) * move_cmd * self.blue.max_speed * speed_mod
        vy = math.sin(self.blue.theta) * move_cmd * self.blue.max_speed * speed_mod
        self.blue.move(vx, vy, dt=DT)
    # Clamp to map bounds
    self.blue.x = float(np.clip(self.blue.x, 0.0, self.map_width))
    self.blue.y = float(np.clip(self.blue.y, 0.0, self.map_height))

    # --- Red opponent (scripted or policy-based) ---
    # When morale_config is set and Red is already routing, skip scripted/policy
    # movement; rout_velocity() is applied after morale update instead.
    red_skip_normal_movement = self.morale_config is not None and self.red_state.is_routing
    # Capture Red's pre-movement position to track actual displacement for
    # fatigue calculation (avoids hard-coded assumptions about behaviour).
    _red_x_before = self.red.x
    _red_y_before = self.red.y
    if self.red_policy is not None:
        red_obs = self._get_red_obs()
        red_action, _ = self.red_policy.predict(red_obs, deterministic=False)
        red_action = np.asarray(red_action, dtype=np.float32)
        red_fire_cmd = float(np.clip(red_action[2], 0.0, 1.0))
        if not red_skip_normal_movement:
            self._step_red_policy(red_action)
    else:
        red_fire_cmd = self._red_fire_intensity()
        if not red_skip_normal_movement:
            self._step_red()

    # --- Combat resolution (simultaneous) ---
    self.blue_state.reset_step_accumulators()
    self.red_state.reset_step_accumulators()

    # Logistics — apply ammo consumption and modifiers before damage
    # computation.  effective_*_cmd is what is actually fed into the
    # damage formula; the raw requested intensities are preserved for
    # fatigue tracking.
    effective_fire_cmd = fire_cmd
    effective_red_fire_cmd = red_fire_cmd
    if self.enable_logistics:
        lc = self.logistics_config
        if self.blue_logistics is not None:
            effective_fire_cmd = consume_ammo(
                self.blue_logistics, fire_cmd, lc
            )
            effective_fire_cmd *= get_ammo_modifier(self.blue_logistics, lc)
            effective_fire_cmd *= get_fatigue_accuracy_modifier(
                self.blue_logistics, lc
            )
        if self.red_logistics is not None:
            effective_red_fire_cmd = consume_ammo(
                self.red_logistics, red_fire_cmd, lc
            )
            effective_red_fire_cmd *= get_ammo_modifier(self.red_logistics, lc)
            effective_red_fire_cmd *= get_fatigue_accuracy_modifier(
                self.red_logistics, lc
            )

    # Weather accuracy penalty — applied after logistics modifiers.
    # Affects both sides equally (rain/fog/night degrades all musketry).
    if self.enable_weather and self.weather_state is not None:
        weather_acc = get_accuracy_modifier(self.weather_state)
        effective_fire_cmd *= weather_acc
        effective_red_fire_cmd *= weather_acc

    raw_b2r = compute_fire_damage(self.blue, self.red, intensity=effective_fire_cmd)
    raw_r2b = compute_fire_damage(self.red, self.blue, intensity=effective_red_fire_cmd)

    # Apply terrain cover at each target's position
    raw_b2r = self.terrain.apply_cover_modifier(self.red.x, self.red.y, raw_b2r)
    raw_r2b = self.terrain.apply_cover_modifier(self.blue.x, self.blue.y, raw_r2b)

    # Apply casualties
    dmg_b2r = apply_casualties(self.red, self.red_state, raw_b2r)
    dmg_r2b = apply_casualties(self.blue, self.blue_state, raw_r2b)

    # Wagon targeting — when logistics are active, each side's fire can
    # also damage the enemy supply wagon if it is within range and in the
    # fire arc.  This makes wagons emergent high-priority targets.
    if self.enable_logistics:
        self._apply_wagon_damage(
            shooter=self.red,
            wagon=self.blue_wagon,
            intensity=effective_red_fire_cmd,
        )
        self._apply_wagon_damage(
            shooter=self.blue,
            wagon=self.red_wagon,
            intensity=effective_fire_cmd,
        )

    # Formation morale resilience — scale down the accumulated damage that
    # feeds into the morale check.  A higher morale_resilience means the
    # unit absorbs the same strength loss with less morale penalty (SQUARE
    # = 1.5×, i.e. morale hit is 2/3 of what unformed troops would suffer).
    # This only affects morale propagation; strength damage is unchanged.
    if self.enable_formations:
        blue_resilience = get_attributes(Formation(self.blue.formation)).morale_resilience
        red_resilience = get_attributes(Formation(self.red.formation)).morale_resilience
        self.blue_state.accumulated_damage /= blue_resilience
        self.red_state.accumulated_damage /= red_resilience

    # Compute enemy distance for distance-based morale recovery
    dx = self.blue.x - self.red.x
    dy = self.blue.y - self.red.y
    enemy_dist = float(math.sqrt(dx * dx + dy * dy))

    # Morale checks — use enhanced update_morale when a MoraleConfig is
    # provided, otherwise fall back to the basic morale_check.
    if self.morale_config is not None:
        mc = self.morale_config
        blue_flank = compute_flank_stressor(
            self.red.x, self.red.y,
            self.blue.x, self.blue.y, self.blue.theta,
            dmg_r2b,
        )
        red_flank = compute_flank_stressor(
            self.blue.x, self.blue.y,
            self.red.x, self.red.y, self.red.theta,
            dmg_b2r,
        )
        # Weather morale stressor — added to flanking penalty so that
        # update_morale applies it through the same deduction path.
        if self.enable_weather and self.weather_state is not None:
            weather_ms = get_morale_stressor(self.weather_state)
            blue_flank += weather_ms
            red_flank += weather_ms
        update_morale(
            self.blue_state,
            enemy_dist=enemy_dist,
            config=mc,
            flank_penalty=blue_flank,
            rng=self.np_random,
        )
        update_morale(
            self.red_state,
            enemy_dist=enemy_dist,
            config=mc,
            flank_penalty=red_flank,
            rng=self.np_random,
        )
    else:
        # Weather morale stressor — added to accumulated_damage before morale_check
        # so that it enters the same deduction pipeline (recovery + clamping)
        # as the MoraleConfig path.
        if self.enable_weather and self.weather_state is not None:
            weather_ms = get_morale_stressor(self.weather_state)
            if weather_ms > 0.0:
                self.blue_state.accumulated_damage += weather_ms
                self.red_state.accumulated_damage += weather_ms
        morale_check(self.blue_state, rng=self.np_random)
        morale_check(self.red_state, rng=self.np_random)

    # Sync Battalion flags from CombatState
    self.blue.morale = self.blue_state.morale
    self.red.morale  = self.red_state.morale
    self.blue.routed = self.blue_state.is_routing
    self.red.routed  = self.red_state.is_routing

    # --- Logistics update (fatigue, food, resupply) ---
    if self.enable_logistics:
        lc = self.logistics_config
        # Determine movement activity from raw move commands
        blue_moved = abs(move_cmd) > 0.01
        # Determine Red movement from actual position change (works for both
        # scripted and policy-driven opponents at all curriculum levels).
        red_moved = (
            abs(self.red.x - _red_x_before) > 1e-4
            or abs(self.red.y - _red_y_before) > 1e-4
        )
        blue_fired = effective_fire_cmd > 0.0
        red_fired = effective_red_fire_cmd > 0.0
        if self.blue_logistics is not None:
            update_fatigue(self.blue_logistics, blue_moved, blue_fired, lc)
            consume_food(self.blue_logistics, lc)
            # Resupply only available when halted near a friendly wagon
            # (not moving and not firing) — models historical supply practice.
            if self.blue_wagon is not None and not blue_moved and not blue_fired:
                check_resupply(
                    self.blue_logistics,
                    self.blue.x, self.blue.y,
                    self.blue_wagon, lc,
                )
        if self.red_logistics is not None:
            update_fatigue(self.red_logistics, red_moved, red_fired, lc)
            consume_food(self.red_logistics, lc)
            # Resupply only available when halted near a friendly wagon
            if self.red_wagon is not None and not red_moved and not red_fired:
                check_resupply(
                    self.red_logistics,
                    self.red.x, self.red.y,
                    self.red_wagon, lc,
                )

    # --- Post-morale rout movement (morale_config mode only) ---
    # Applied *after* morale update so that units route on the same step
    # routing is triggered, not just on subsequent steps.  Also handles
    # already-routing units whose pre-step movement was suppressed above.
    if self.morale_config is not None:
        if self.blue_state.is_routing:
            vx, vy = rout_velocity(
                self.blue.x, self.blue.y,
                self.red.x, self.red.y,
                self.blue.max_speed,
                self.morale_config,
            )
            self.blue.move(vx, vy, dt=DT)
            self.blue.x = float(np.clip(self.blue.x, 0.0, self.map_width))
            self.blue.y = float(np.clip(self.blue.y, 0.0, self.map_height))
        if self.red_state.is_routing:
            self._step_routing_red()

    self._step_count += 1

    # --- Weather progression (time-of-day advancement) ---
    if self.enable_weather and self.weather_state is not None:
        step_weather(self.weather_state, self.weather_config)

    # --- Termination ---
    blue_done = (
        self.blue_state.is_routing or self.blue.strength <= DESTROYED_THRESHOLD
    )
    red_done = (
        self.red_state.is_routing or self.red.strength <= DESTROYED_THRESHOLD
    )
    terminated = blue_done or red_done
    truncated  = (not terminated) and (self._step_count >= self.max_steps)

    # --- Reward ---
    blue_won = red_done and not blue_done
    blue_lost = blue_done and not red_done
    reward_comps = compute_reward(
        dmg_b2r=dmg_b2r,
        dmg_r2b=dmg_r2b,
        blue_strength=float(self.blue.strength),
        blue_won=blue_won,
        blue_lost=blue_lost,
        weights=self.reward_weights,
        enemy_routed=self.red_state.is_routing,
        own_routing=self.blue_state.is_routing,
    )

    info: dict = {
        "blue_damage_dealt": float(dmg_b2r),
        "red_damage_dealt":  float(dmg_r2b),
        "blue_routed":       self.blue_state.is_routing,
        "red_routed":        self.red_state.is_routing,
        "step_count":        self._step_count,
        **reward_comps.as_dict(),
    }

    # Add logistics info when the system is active
    if self.enable_logistics:
        if self.blue_logistics is not None:
            info["blue_ammo"]    = float(self.blue_logistics.ammo)
            info["blue_food"]    = float(self.blue_logistics.food)
            info["blue_fatigue"] = float(self.blue_logistics.fatigue)
        if self.red_logistics is not None:
            info["red_ammo"]    = float(self.red_logistics.ammo)
            info["red_food"]    = float(self.red_logistics.food)
            info["red_fatigue"] = float(self.red_logistics.fatigue)

    # Add weather info when the system is active
    if self.enable_weather and self.weather_state is not None:
        info["weather_condition"] = int(self.weather_state.condition)
        info["time_of_day"]       = int(self.weather_state.time_of_day)
        info["visibility_fraction"] = float(
            get_visibility_fraction(self.weather_state)
        )

    return self._get_obs(), reward_comps.total, terminated, truncated, info
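The terminated/truncated split at the end of step() condenses to a pure function. A sketch — DESTROYED_THRESHOLD is set to an illustrative value here; the real constant lives in the package:

```python
DESTROYED_THRESHOLD = 0.05  # illustrative value only

def episode_status(blue_routing: bool, blue_strength: float,
                   red_routing: bool, red_strength: float,
                   step_count: int, max_steps: int) -> tuple[bool, bool]:
    """Reproduce step()'s terminated/truncated logic in isolation."""
    blue_done = blue_routing or blue_strength <= DESTROYED_THRESHOLD
    red_done = red_routing or red_strength <= DESTROYED_THRESHOLD
    terminated = blue_done or red_done
    # Truncation only fires when the episode did not already terminate
    truncated = (not terminated) and (step_count >= max_steps)
    return terminated, truncated
```

Note that terminated and truncated are mutually exclusive here: Gymnasium permits both to be true, but this environment never sets both on the same step.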

envs.brigade_env.BrigadeEnv

Bases: Env

Gymnasium environment for a brigade-level HRL commander.

Parameters:

Name Type Description Default
n_blue int

Number of Blue battalions controlled by the brigade.

2
n_red int

Number of Red opponent battalions.

2
map_width float

Map width in metres.

MAP_WIDTH
map_height float

Map height in metres.

MAP_HEIGHT
max_steps int

Maximum primitive-step episode length.

MAX_STEPS
options Optional[list[Option]]

Option vocabulary. None uses make_default_options(). When None, the vocabulary is built using temporal_ratio as the option max_steps cap.

None
temporal_ratio int

Number of primitive battalion steps per brigade macro-step (option duration cap). Ignored when an explicit options list is supplied. Must be >= 1. Corresponds to the hyperparameter swept in E3.5.

10
battalion_policy

Optional frozen MAPPOPolicy used to drive Red agents. All parameters are detached (requires_grad=False). When None, Red agents are stationary (zero primitive actions).

None
red_random bool

When True and no battalion_policy is set, Red agents take random primitive actions. Ignored when battalion_policy is set.

False
randomize_terrain bool

Pass-through to MultiBattalionEnv.

True
visibility_radius float

Pass-through to MultiBattalionEnv.

600.0
render_mode Optional[str]

None or "human" — delegated to the inner env.

None
Source code in envs/brigade_env.py
class BrigadeEnv(gym.Env):
    """Gymnasium environment for a brigade-level HRL commander.

    Parameters
    ----------
    n_blue:
        Number of Blue battalions controlled by the brigade.
    n_red:
        Number of Red opponent battalions.
    map_width:
        Map width in metres.
    map_height:
        Map height in metres.
    max_steps:
        Maximum primitive-step episode length.
    options:
        Option vocabulary.  ``None`` uses :func:`~envs.options.make_default_options`.
        When ``None``, the vocabulary is built using ``temporal_ratio`` as the
        option ``max_steps`` cap.
    temporal_ratio:
        Number of primitive battalion steps per brigade macro-step (option
        duration cap).  Ignored when an explicit ``options`` list is supplied.
        Must be ``>= 1``.  Corresponds to the hyperparameter swept in E3.5.
    battalion_policy:
        Optional frozen :class:`~models.mappo_policy.MAPPOPolicy` used to
        drive Red agents.  All parameters are detached (``requires_grad=False``).
        When ``None``, Red agents are stationary (zero primitive actions).
    red_random:
        When ``True`` and no ``battalion_policy`` is set, Red agents take
        random primitive actions.  Ignored when ``battalion_policy`` is set.
    randomize_terrain:
        Pass-through to :class:`~envs.multi_battalion_env.MultiBattalionEnv`.
    visibility_radius:
        Pass-through to :class:`~envs.multi_battalion_env.MultiBattalionEnv`.
    render_mode:
        ``None`` or ``"human"`` — delegated to the inner env.
    """

    metadata: dict = {"render_modes": ["human"], "name": "brigade_v0"}

    def __init__(
        self,
        n_blue: int = 2,
        n_red: int = 2,
        map_width: float = MAP_WIDTH,
        map_height: float = MAP_HEIGHT,
        max_steps: int = MAX_STEPS,
        options: Optional[list[Option]] = None,
        temporal_ratio: int = 10,
        battalion_policy=None,
        red_random: bool = False,
        randomize_terrain: bool = True,
        visibility_radius: float = 600.0,
        render_mode: Optional[str] = None,
    ) -> None:
        if int(n_blue) < 1:
            raise ValueError(f"n_blue must be >= 1, got {n_blue}")
        if int(n_red) < 1:
            raise ValueError(f"n_red must be >= 1, got {n_red}")
        if options is None and int(temporal_ratio) < 1:
            raise ValueError(f"temporal_ratio must be >= 1, got {temporal_ratio}")

        self.n_blue = int(n_blue)
        self.n_red = int(n_red)
        self.map_width = float(map_width)
        self.map_height = float(map_height)
        self.map_diagonal = math.hypot(self.map_width, self.map_height)
        self.max_steps = int(max_steps)
        self.red_random = bool(red_random)
        self.render_mode = render_mode
        self.temporal_ratio: int = int(temporal_ratio)

        # Option vocabulary — use temporal_ratio as max_steps when no custom
        # options are provided so the hyperparameter takes effect.
        self._options: list[Option] = (
            list(options) if options is not None
            else make_default_options(max_steps=self.temporal_ratio)
        )
        if len(self._options) == 0:
            raise ValueError("options must contain at least one Option.")
        self.n_options: int = len(self._options)

        # ── Action space ──────────────────────────────────────────────────
        # One option index per blue battalion
        self.action_space = spaces.MultiDiscrete(
            [self.n_options] * self.n_blue, dtype=np.int64
        )

        # ── Observation space ─────────────────────────────────────────────
        self._obs_dim: int = _brigade_obs_dim(self.n_blue)
        obs_low, obs_high = self._build_obs_bounds()
        self.observation_space = spaces.Box(
            low=obs_low, high=obs_high, dtype=np.float32
        )

        # ── Inner environment ─────────────────────────────────────────────
        self._inner = MultiBattalionEnv(
            n_blue=self.n_blue,
            n_red=self.n_red,
            map_width=self.map_width,
            map_height=self.map_height,
            max_steps=self.max_steps,
            randomize_terrain=randomize_terrain,
            visibility_radius=visibility_radius,
            render_mode=render_mode,
        )

        # ── Frozen battalion policy (optional) ────────────────────────────
        self._battalion_policy = None
        self._policy_device: str = "cpu"
        if battalion_policy is not None:
            self.set_battalion_policy(battalion_policy)

        # ── Episode state (populated by reset()) ─────────────────────────
        self._last_obs: dict[str, np.ndarray] = {}
        self._prim_steps: int = 0
        self._macro_steps: int = 0

        # ── Red option overrides (set externally by DivisionEnv) ──────────
        # Maps red agent_id → option index.  When non-empty, _get_red_action
        # executes the corresponding Option primitive policy instead of the
        # default battalion-policy / random / zero behaviour.
        self._forced_red_options: dict[str, int] = {}

    # ------------------------------------------------------------------
    # Observation bounds
    # ------------------------------------------------------------------

    def _build_obs_bounds(self) -> tuple[np.ndarray, np.ndarray]:
        """Return ``(obs_low, obs_high)`` arrays for the observation space."""
        lows: list[float] = []
        highs: list[float] = []

        # Sector control: [0, 1] × 3
        lows.extend([0.0] * N_SECTORS)
        highs.extend([1.0] * N_SECTORS)

        # Per-blue battalion strength + morale: [0, 1] each
        for _ in range(self.n_blue):
            lows.extend([0.0, 0.0])
            highs.extend([1.0, 1.0])

        # Per-blue threat vector: [dist, cos, sin, e_str, e_mor]
        for _ in range(self.n_blue):
            # dist / map_diagonal in [0, 1]
            lows.append(0.0)
            highs.append(1.0)
            # cos(bearing) in [-1, 1]
            lows.append(-1.0)
            highs.append(1.0)
            # sin(bearing) in [-1, 1]
            lows.append(-1.0)
            highs.append(1.0)
            # enemy strength in [0, 1]
            lows.append(0.0)
            highs.append(1.0)
            # enemy morale in [0, 1]
            lows.append(0.0)
            highs.append(1.0)

        # Step progress: [0, 1]
        lows.append(0.0)
        highs.append(1.0)

        return np.array(lows, dtype=np.float32), np.array(highs, dtype=np.float32)

    # ------------------------------------------------------------------
    # Frozen battalion policy
    # ------------------------------------------------------------------

    def set_battalion_policy(self, policy) -> None:
        """Set (or clear) the frozen policy used to drive Red agents.

        When a policy is supplied its parameters are frozen
        (``requires_grad=False``) and placed in evaluation mode so
        no gradients flow through it during brigade training.

        Parameters
        ----------
        policy:
            A :class:`~models.mappo_policy.MAPPOPolicy` instance, or
            ``None`` to revert to the default stationary / random Red
            behaviour.
        """
        if policy is None:
            self._battalion_policy = None
            self._policy_device = "cpu"
            return

        # Freeze all parameters
        for param in policy.parameters():
            param.requires_grad_(False)
        policy.eval()
        self._battalion_policy = policy
        # Store the device so _get_red_action can move obs tensors to it
        try:
            self._policy_device = next(policy.parameters()).device.type
        except StopIteration:
            self._policy_device = "cpu"

    # ------------------------------------------------------------------
    # Gymnasium API: reset
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> tuple[np.ndarray, dict]:
        """Reset the environment and return the initial brigade observation.

        Parameters
        ----------
        seed:
            RNG seed forwarded to the inner :class:`~envs.multi_battalion_env.MultiBattalionEnv`.
        options:
            Unused; present for Gymnasium API compatibility.

        Returns
        -------
        obs : np.ndarray of shape ``(obs_dim,)``
        info : dict
        """
        if seed is not None:
            super().reset(seed=seed)

        inner_obs, _ = self._inner.reset(seed=seed, options=options)
        self._last_obs = dict(inner_obs)
        self._prim_steps = 0
        self._macro_steps = 0
        # Cleared here (between episodes) and also after each step by DivisionEnv
        # to prevent stale commands leaking into subsequent macro-steps.
        self._forced_red_options = {}
        return self._get_brigade_obs(), {}

    # ------------------------------------------------------------------
    # Gymnasium API: step
    # ------------------------------------------------------------------

    def step(
        self,
        brigade_action: np.ndarray,
    ) -> tuple[np.ndarray, float, bool, bool, dict]:
        """Execute one macro-step: dispatch options to Blue battalions.

        For each alive Blue battalion, the selected option runs for multiple
        primitive steps until the option terminates or the underlying
        episode ends.

        Parameters
        ----------
        brigade_action:
            Array of shape ``(n_blue,)`` with option indices.  Indices for
            dead battalions are ignored.

        Returns
        -------
        obs : np.ndarray — brigade observation after the macro-step
        reward : float — mean reward across all Blue battalions
        terminated : bool — True when the episode ended naturally
        truncated : bool — True when the episode was cut short
        info : dict — metadata including primitive step count and option names
        """
        brigade_action = np.asarray(brigade_action, dtype=np.int64)

        # Validate action shape
        if brigade_action.shape != (self.n_blue,):
            raise ValueError(
                f"brigade_action has shape {brigade_action.shape!r}, "
                f"expected ({self.n_blue},)."
            )

        # Active blue agents at the start of this macro-step
        current_blue = [
            f"blue_{i}" for i in range(self.n_blue)
            if f"blue_{i}" in self._inner.agents
        ]

        if not current_blue:
            # All Blue battalions already dead — episode over
            obs = self._get_brigade_obs()
            return obs, 0.0, True, False, {}

        # ── Dispatch options ──────────────────────────────────────────
        selected_options: dict[str, Option] = {}
        option_steps: dict[str, int] = {}
        option_names: dict[str, str] = {}
        for i in range(self.n_blue):
            agent_id = f"blue_{i}"
            if agent_id in current_blue:
                idx = int(brigade_action[i])
                if idx < 0 or idx >= self.n_options:
                    raise ValueError(
                        f"Invalid macro-action index {idx!r} for battalion {agent_id!r}; "
                        f"expected integer in [0, {self.n_options - 1}]."
                    )
                selected_options[agent_id] = self._options[idx]
                option_steps[agent_id] = 0
                option_names[agent_id] = self._options[idx].name

        option_done: dict[str, bool] = {a: False for a in current_blue}
        agg_rewards: dict[str, float] = {a: 0.0 for a in current_blue}
        ep_terminated: dict[str, bool] = {a: False for a in current_blue}
        ep_truncated: dict[str, bool] = {a: False for a in current_blue}
        # Track whether the inner env issued any truncation during this macro-step
        any_inner_truncated: bool = False

        # ── Inner primitive-step loop ─────────────────────────────────
        while any(not option_done[a] for a in current_blue):
            if not self._inner.agents:
                for a in current_blue:
                    option_done[a] = True
                break

            # Build primitive actions for all alive agents
            prim_actions: dict[str, np.ndarray] = {}
            for agent in self._inner.agents:
                if agent.startswith("blue_"):
                    if agent in current_blue and not option_done[agent]:
                        prim_actions[agent] = selected_options[agent].get_action(
                            self._last_obs[agent]
                        )
                    else:
                        prim_actions[agent] = np.zeros(_PRIM_ACT_DIM, dtype=np.float32)
                else:
                    # Red agent: driven by battalion policy, random, or zero
                    prim_actions[agent] = self._get_red_action(agent)

            # Primitive step
            obs, rewards, terminated, truncated, _ = self._inner.step(prim_actions)
            self._prim_steps += 1

            # Update latest observations
            for agent, ob in obs.items():
                self._last_obs[agent] = ob

            # Update option tracking — always record env-level
            # termination/truncation regardless of option_done state
            for agent in current_blue:
                env_term = bool(terminated.get(agent, False))
                env_trunc = bool(truncated.get(agent, False))

                if env_trunc:
                    any_inner_truncated = True
                # Always update episode-level flags
                if env_term:
                    ep_terminated[agent] = True
                if env_trunc:
                    ep_truncated[agent] = True

                # Skip option bookkeeping once the option is done
                if option_done[agent]:
                    continue

                agg_rewards[agent] += float(rewards.get(agent, 0.0))

                if env_term or env_trunc:
                    option_done[agent] = True
                else:
                    option_steps[agent] += 1
                    if selected_options[agent].should_terminate(
                        self._last_obs[agent], option_steps[agent]
                    ):
                        option_done[agent] = True

        self._macro_steps += 1

        # ── Episode termination ───────────────────────────────────────
        blue_alive = [
            f"blue_{i}" for i in range(self.n_blue)
            if f"blue_{i}" in self._inner._alive
        ]
        red_alive = [
            f"red_{i}" for i in range(self.n_red)
            if f"red_{i}" in self._inner._alive
        ]
        blue_wiped = len(blue_alive) == 0
        red_wiped = len(red_alive) == 0

        # Decisive combat outcome → terminated; time limit without decisive outcome → truncated
        if blue_wiped or red_wiped:
            episode_terminated = True
            episode_truncated = False
        elif any_inner_truncated:
            episode_terminated = False
            episode_truncated = True
        else:
            episode_terminated = False
            episode_truncated = False

        # ── Brigade reward ────────────────────────────────────────────
        reward_vals = [agg_rewards[a] for a in current_blue]
        brigade_reward = float(np.mean(reward_vals)) if reward_vals else 0.0

        # ── Info dict ─────────────────────────────────────────────────
        info: dict = {
            "macro_steps": self._macro_steps,
            "primitive_steps": self._prim_steps,
            "option_names": option_names,
            "option_steps": {a: option_steps.get(a, 0) for a in current_blue},
            "blue_rewards": {a: agg_rewards[a] for a in current_blue},
        }
        if episode_terminated or episode_truncated:
            if red_wiped and not blue_wiped:
                info["winner"] = "blue"
            elif blue_wiped and not red_wiped:
                info["winner"] = "red"
            else:
                info["winner"] = "draw"

        return (
            self._get_brigade_obs(),
            brigade_reward,
            episode_terminated,
            episode_truncated,
            info,
        )

    # ------------------------------------------------------------------
    # Brigade observation construction
    # ------------------------------------------------------------------

    def _get_brigade_obs(self) -> np.ndarray:
        """Build and return the normalised brigade observation vector."""
        parts: list[float] = []

        # ── 1. Sector control (3 vertical strips) ────────────────────
        sector_width = self.map_width / N_SECTORS
        for s in range(N_SECTORS):
            x_lo = s * sector_width
            x_hi = (s + 1) * sector_width
            blue_str = 0.0
            red_str = 0.0
            for agent_id, b in self._inner._battalions.items():
                if agent_id not in self._inner._alive:
                    continue
                if x_lo <= b.x < x_hi or (s == N_SECTORS - 1 and b.x == self.map_width):
                    if agent_id.startswith("blue_"):
                        blue_str += float(b.strength)
                    else:
                        red_str += float(b.strength)
            total = blue_str + red_str
            parts.append(blue_str / total if total > 0.0 else 0.5)

        # ── 2. Per-blue battalion strength + morale ───────────────────
        for i in range(self.n_blue):
            agent_id = f"blue_{i}"
            if agent_id in self._inner._battalions and agent_id in self._inner._alive:
                b = self._inner._battalions[agent_id]
                parts.append(float(b.strength))
                parts.append(float(b.morale))
            else:
                parts.extend([0.0, 0.0])

        # ── 3. Per-blue enemy threat vector ───────────────────────────
        # Alive red battalions
        alive_red = [
            (r_id, self._inner._battalions[r_id])
            for r_id in self._inner._alive
            if r_id.startswith("red_") and r_id in self._inner._battalions
        ]

        for i in range(self.n_blue):
            agent_id = f"blue_{i}"
            if agent_id not in self._inner._alive or agent_id not in self._inner._battalions:
                # Dead battalion — sentinel threat
                parts.extend([1.0, 0.0, 0.0, 0.0, 0.0])
                continue

            b = self._inner._battalions[agent_id]

            if not alive_red:
                # No enemies — maximum distance sentinel
                parts.extend([1.0, 0.0, 0.0, 0.0, 0.0])
                continue

            # Find nearest alive red battalion
            best_dist = float("inf")
            best_r = None
            for r_id, r_bat in alive_red:
                dx = r_bat.x - b.x
                dy = r_bat.y - b.y
                d = math.sqrt(dx * dx + dy * dy)
                if d < best_dist:
                    best_dist = d
                    best_r = r_bat

            assert best_r is not None
            dx = best_r.x - b.x
            dy = best_r.y - b.y
            bearing = math.atan2(dy, dx)

            parts.append(min(best_dist / self.map_diagonal, 1.0))
            parts.append(math.cos(bearing))
            parts.append(math.sin(bearing))
            parts.append(float(best_r.strength))
            parts.append(float(best_r.morale))

        # ── 4. Step progress ─────────────────────────────────────────
        parts.append(min(self._inner._step_count / self.max_steps, 1.0))

        obs = np.array(parts, dtype=np.float32)
        return np.clip(obs, self.observation_space.low, self.observation_space.high)

    # ------------------------------------------------------------------
    # Red action helper
    # ------------------------------------------------------------------

    def _get_red_action(self, agent_id: str) -> np.ndarray:
        """Return a primitive action for a Red agent.

        Priority:
        1. :attr:`_forced_red_options` (set by DivisionEnv) — execute the option's
           primitive policy directly.
        2. :attr:`_battalion_policy` (frozen MAPPOPolicy) — if set.
        3. Random primitive action — when ``red_random=True``.
        4. Zero (stationary) action — default.
        """
        if self._forced_red_options and agent_id in self._forced_red_options:
            opt_idx = int(self._forced_red_options[agent_id])
            if opt_idx < 0 or opt_idx >= self.n_options:
                raise ValueError(
                    f"Invalid forced option index {opt_idx!r} for Red agent {agent_id!r}; "
                    f"expected integer in [0, {self.n_options - 1}]."
                )
            obs = self._last_obs.get(
                agent_id, np.zeros(self._inner._obs_dim, dtype=np.float32)
            )
            return self._options[opt_idx].get_action(obs)

        if self._battalion_policy is not None:
            obs = self._last_obs.get(
                agent_id, np.zeros(self._inner._obs_dim, dtype=np.float32)
            )
            obs_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0).to(self._policy_device)
            # Infer agent index from the agent_id
            try:
                agent_idx = int(agent_id.split("_")[1]) % self._battalion_policy.n_agents
            except (IndexError, ValueError):
                agent_idx = 0
            with torch.no_grad():
                acts_t, _ = self._battalion_policy.act(
                    obs_t, agent_idx=agent_idx, deterministic=False
                )
            act = acts_t[0].cpu().numpy()
            act_low = self._inner._act_space.low
            act_high = self._inner._act_space.high
            return np.clip(act, act_low, act_high).astype(np.float32)

        if self.red_random:
            return self._inner.action_space(agent_id).sample()

        return np.zeros(_PRIM_ACT_DIM, dtype=np.float32)

    # ------------------------------------------------------------------
    # Gymnasium API: render / close
    # ------------------------------------------------------------------

    def render(self):
        """Delegate rendering to the inner environment."""
        return self._inner.render()

    def close(self) -> None:
        """Delegate cleanup to the inner environment."""
        self._inner.close()
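The brigade observation assembled by `_build_obs_bounds` above is a flat vector: `N_SECTORS` sector-control fractions, two values (strength, morale) per Blue battalion, a five-value threat vector per Blue battalion, and one step-progress scalar. A minimal standalone sketch of the layout (assuming `N_SECTORS = 3`; not the library's actual helper):

```python
# Standalone sketch of the brigade observation bounds, assuming N_SECTORS = 3.
N_SECTORS = 3

def brigade_obs_bounds(n_blue: int) -> tuple[list[float], list[float]]:
    lows: list[float] = []
    highs: list[float] = []
    # Sector-control fractions in [0, 1]
    lows += [0.0] * N_SECTORS
    highs += [1.0] * N_SECTORS
    # Per-battalion strength + morale in [0, 1]
    lows += [0.0, 0.0] * n_blue
    highs += [1.0, 1.0] * n_blue
    # Per-battalion threat vector: dist [0, 1], cos/sin bearing [-1, 1],
    # enemy strength and morale [0, 1]
    for _ in range(n_blue):
        lows += [0.0, -1.0, -1.0, 0.0, 0.0]
        highs += [1.0, 1.0, 1.0, 1.0, 1.0]
    # Step progress in [0, 1]
    lows.append(0.0)
    highs.append(1.0)
    return lows, highs

lo, hi = brigade_obs_bounds(n_blue=2)
assert len(lo) == len(hi) == 3 + 7 * 2 + 1  # 18 features for two battalions
```

The total dimension works out to `3 + 7 * n_blue + 1`, which matches the per-feature breakdown in the source above.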

close()

Delegate cleanup to the inner environment.

Source code in envs/brigade_env.py
def close(self) -> None:
    """Delegate cleanup to the inner environment."""
    self._inner.close()

render()

Delegate rendering to the inner environment.

Source code in envs/brigade_env.py
def render(self):
    """Delegate rendering to the inner environment."""
    return self._inner.render()

reset(seed=None, options=None)

Reset the environment and return the initial brigade observation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `seed` | `Optional[int]` | RNG seed forwarded to the inner `MultiBattalionEnv`. | `None` |
| `options` | `Optional[dict]` | Unused; present for Gymnasium API compatibility. | `None` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `obs` | `np.ndarray` | Initial brigade observation of shape `(obs_dim,)`. |
| `info` | `dict` | Info dict. |
Source code in envs/brigade_env.py
def reset(
    self,
    seed: Optional[int] = None,
    options: Optional[dict] = None,
) -> tuple[np.ndarray, dict]:
    """Reset the environment and return the initial brigade observation.

    Parameters
    ----------
    seed:
        RNG seed forwarded to the inner :class:`~envs.multi_battalion_env.MultiBattalionEnv`.
    options:
        Unused; present for Gymnasium API compatibility.

    Returns
    -------
    obs : np.ndarray of shape ``(obs_dim,)``
    info : dict
    """
    if seed is not None:
        super().reset(seed=seed)

    inner_obs, _ = self._inner.reset(seed=seed, options=options)
    self._last_obs = dict(inner_obs)
    self._prim_steps = 0
    self._macro_steps = 0
    # Cleared here (between episodes) and also after each step by DivisionEnv
    # to prevent stale commands leaking into subsequent macro-steps.
    self._forced_red_options = {}
    return self._get_brigade_obs(), {}
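The first three entries of the brigade observation returned by `reset()` are sector-control fractions, one per vertical map strip. The computation can be sketched standalone (the `units` tuples here are a hypothetical stand-in for the env's internal `_battalions` / `_alive` state):

```python
# Sector control: Blue's share of total strength per vertical map strip,
# with 0.5 reported for an empty strip.
N_SECTORS = 3

def sector_control(units, map_width: float) -> list[float]:
    """units: iterable of (x, strength, is_blue) tuples for alive battalions."""
    sector_w = map_width / N_SECTORS
    fractions = []
    for s in range(N_SECTORS):
        x_lo, x_hi = s * sector_w, (s + 1) * sector_w
        blue = red = 0.0
        for x, strength, is_blue in units:
            # The last sector also claims the x == map_width boundary exactly.
            if x_lo <= x < x_hi or (s == N_SECTORS - 1 and x == map_width):
                if is_blue:
                    blue += strength
                else:
                    red += strength
        total = blue + red
        fractions.append(blue / total if total > 0.0 else 0.5)
    return fractions

# One Blue unit in the left strip, one Red unit in the right strip:
frac = sector_control([(100.0, 1.0, True), (2900.0, 1.0, False)], map_width=3000.0)
assert frac == [1.0, 0.5, 0.0]
```

Reporting 0.5 for an empty strip keeps "no information" distinct from "enemy-controlled" in the observation.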

set_battalion_policy(policy)

Set (or clear) the frozen policy used to drive Red agents.

When a policy is supplied, its parameters are frozen (`requires_grad=False`) and the module is placed in evaluation mode, so no gradients flow through it during brigade training.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `policy` | | A `MAPPOPolicy` instance, or `None` to revert to the default stationary / random Red behaviour. | *required* |
Source code in envs/brigade_env.py
def set_battalion_policy(self, policy) -> None:
    """Set (or clear) the frozen policy used to drive Red agents.

    When a policy is supplied its parameters are frozen
    (``requires_grad=False``) and placed in evaluation mode so
    no gradients flow through it during brigade training.

    Parameters
    ----------
    policy:
        A :class:`~models.mappo_policy.MAPPOPolicy` instance, or
        ``None`` to revert to the default stationary / random Red
        behaviour.
    """
    if policy is None:
        self._battalion_policy = None
        self._policy_device = "cpu"
        return

    # Freeze all parameters
    for param in policy.parameters():
        param.requires_grad_(False)
    policy.eval()
    self._battalion_policy = policy
    # Store the device so _get_red_action can move obs tensors to it
    try:
        self._policy_device = next(policy.parameters()).device.type
    except StopIteration:
        self._policy_device = "cpu"
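The freezing pattern here is generic PyTorch and can be demonstrated with any `nn.Module`; the sketch below uses a plain `nn.Linear` as a stand-in for a `MAPPOPolicy`:

```python
import torch
import torch.nn as nn

# Stand-in for a MAPPOPolicy: any nn.Module exercises the same pattern.
policy = nn.Linear(4, 2)

# Freeze all parameters and switch to eval mode, as set_battalion_policy does.
for param in policy.parameters():
    param.requires_grad_(False)
policy.eval()

# Infer the device from the first parameter, falling back to CPU for
# parameterless modules (the StopIteration branch in the source above).
try:
    device = next(policy.parameters()).device.type
except StopIteration:
    device = "cpu"

assert all(not p.requires_grad for p in policy.parameters())
assert not policy.training  # eval() clears the training flag
```

Freezing via `requires_grad_(False)` (rather than wrapping every forward pass in `torch.no_grad()`) also keeps the frozen parameters out of any optimizer that filters on `requires_grad`.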

step(brigade_action)

Execute one macro-step: dispatch options to Blue battalions.

For each alive Blue battalion, the selected option runs for multiple primitive steps until the option terminates or the underlying episode ends.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `brigade_action` | `ndarray` | Array of shape `(n_blue,)` with option indices. Indices for dead battalions are ignored. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `obs` | `np.ndarray` | Brigade observation after the macro-step. |
| `reward` | `float` | Mean reward across all Blue battalions. |
| `terminated` | `bool` | `True` when the episode ended naturally. |
| `truncated` | `bool` | `True` when the episode was cut short. |
| `info` | `dict` | Metadata including primitive step count and option names. |
Source code in envs/brigade_env.py
def step(
    self,
    brigade_action: np.ndarray,
) -> tuple[np.ndarray, float, bool, bool, dict]:
    """Execute one macro-step: dispatch options to Blue battalions.

    For each alive Blue battalion, the selected option runs for multiple
    primitive steps until the option terminates or the underlying
    episode ends.

    Parameters
    ----------
    brigade_action:
        Array of shape ``(n_blue,)`` with option indices.  Indices for
        dead battalions are ignored.

    Returns
    -------
    obs : np.ndarray — brigade observation after the macro-step
    reward : float — mean reward across all Blue battalions
    terminated : bool — True when the episode ended naturally
    truncated : bool — True when the episode was cut short
    info : dict — metadata including primitive step count and option names
    """
    brigade_action = np.asarray(brigade_action, dtype=np.int64)

    # Validate action shape
    if brigade_action.shape != (self.n_blue,):
        raise ValueError(
            f"brigade_action has shape {brigade_action.shape!r}, "
            f"expected ({self.n_blue},)."
        )

    # Active blue agents at the start of this macro-step
    current_blue = [
        f"blue_{i}" for i in range(self.n_blue)
        if f"blue_{i}" in self._inner.agents
    ]

    if not current_blue:
        # All Blue battalions already dead — episode over
        obs = self._get_brigade_obs()
        return obs, 0.0, True, False, {}

    # ── Dispatch options ──────────────────────────────────────────
    selected_options: dict[str, Option] = {}
    option_steps: dict[str, int] = {}
    option_names: dict[str, str] = {}
    for i in range(self.n_blue):
        agent_id = f"blue_{i}"
        if agent_id in current_blue:
            idx = int(brigade_action[i])
            if idx < 0 or idx >= self.n_options:
                raise ValueError(
                    f"Invalid macro-action index {idx!r} for battalion {agent_id!r}; "
                    f"expected integer in [0, {self.n_options - 1}]."
                )
            selected_options[agent_id] = self._options[idx]
            option_steps[agent_id] = 0
            option_names[agent_id] = self._options[idx].name

    option_done: dict[str, bool] = {a: False for a in current_blue}
    agg_rewards: dict[str, float] = {a: 0.0 for a in current_blue}
    ep_terminated: dict[str, bool] = {a: False for a in current_blue}
    ep_truncated: dict[str, bool] = {a: False for a in current_blue}
    # Track whether the inner env issued any truncation during this macro-step
    any_inner_truncated: bool = False

    # ── Inner primitive-step loop ─────────────────────────────────
    while any(not option_done[a] for a in current_blue):
        if not self._inner.agents:
            for a in current_blue:
                option_done[a] = True
            break

        # Build primitive actions for all alive agents
        prim_actions: dict[str, np.ndarray] = {}
        for agent in self._inner.agents:
            if agent.startswith("blue_"):
                if agent in current_blue and not option_done[agent]:
                    prim_actions[agent] = selected_options[agent].get_action(
                        self._last_obs[agent]
                    )
                else:
                    prim_actions[agent] = np.zeros(_PRIM_ACT_DIM, dtype=np.float32)
            else:
                # Red agent: driven by battalion policy, random, or zero
                prim_actions[agent] = self._get_red_action(agent)

        # Primitive step
        obs, rewards, terminated, truncated, _ = self._inner.step(prim_actions)
        self._prim_steps += 1

        # Update latest observations
        for agent, ob in obs.items():
            self._last_obs[agent] = ob

        # Update option tracking — always record env-level
        # termination/truncation regardless of option_done state
        for agent in current_blue:
            env_term = bool(terminated.get(agent, False))
            env_trunc = bool(truncated.get(agent, False))

            if env_trunc:
                any_inner_truncated = True
            # Always update episode-level flags
            if env_term:
                ep_terminated[agent] = True
            if env_trunc:
                ep_truncated[agent] = True

            # Skip option bookkeeping once the option is done
            if option_done[agent]:
                continue

            agg_rewards[agent] += float(rewards.get(agent, 0.0))

            if env_term or env_trunc:
                option_done[agent] = True
            else:
                option_steps[agent] += 1
                if selected_options[agent].should_terminate(
                    self._last_obs[agent], option_steps[agent]
                ):
                    option_done[agent] = True

    self._macro_steps += 1

    # ── Episode termination ───────────────────────────────────────
    blue_alive = [
        f"blue_{i}" for i in range(self.n_blue)
        if f"blue_{i}" in self._inner._alive
    ]
    red_alive = [
        f"red_{i}" for i in range(self.n_red)
        if f"red_{i}" in self._inner._alive
    ]
    blue_wiped = len(blue_alive) == 0
    red_wiped = len(red_alive) == 0

    # Decisive combat outcome → terminated; time limit without decisive outcome → truncated
    if blue_wiped or red_wiped:
        episode_terminated = True
        episode_truncated = False
    elif any_inner_truncated:
        episode_terminated = False
        episode_truncated = True
    else:
        episode_terminated = False
        episode_truncated = False

    # ── Brigade reward ────────────────────────────────────────────
    reward_vals = [agg_rewards[a] for a in current_blue]
    brigade_reward = float(np.mean(reward_vals)) if reward_vals else 0.0

    # ── Info dict ─────────────────────────────────────────────────
    info: dict = {
        "macro_steps": self._macro_steps,
        "primitive_steps": self._prim_steps,
        "option_names": option_names,
        "option_steps": {a: option_steps.get(a, 0) for a in current_blue},
        "blue_rewards": {a: agg_rewards[a] for a in current_blue},
    }
    if episode_terminated or episode_truncated:
        if red_wiped and not blue_wiped:
            info["winner"] = "blue"
        elif blue_wiped and not red_wiped:
            info["winner"] = "red"
        else:
            info["winner"] = "draw"

    return (
        self._get_brigade_obs(),
        brigade_reward,
        episode_terminated,
        episode_truncated,
        info,
    )
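The terminated/truncated split at the end of `step` follows the Gymnasium convention: a wipe-out is a natural termination, while hitting the time limit without a decisive outcome is a truncation. The resolution and winner bookkeeping can be sketched as a standalone function:

```python
def resolve_outcome(blue_wiped: bool, red_wiped: bool, inner_truncated: bool):
    """Return (terminated, truncated, winner) per the macro-step rules above."""
    # Decisive combat outcome -> terminated; time limit without one -> truncated.
    if blue_wiped or red_wiped:
        terminated, truncated = True, False
    elif inner_truncated:
        terminated, truncated = False, True
    else:
        terminated, truncated = False, False

    # Winner is only reported on episode end; mutual wipe-out or a
    # time-limit expiry with both sides alive counts as a draw.
    winner = None
    if terminated or truncated:
        if red_wiped and not blue_wiped:
            winner = "blue"
        elif blue_wiped and not red_wiped:
            winner = "red"
        else:
            winner = "draw"
    return terminated, truncated, winner

assert resolve_outcome(False, True, False) == (True, False, "blue")
assert resolve_outcome(True, True, False) == (True, False, "draw")
assert resolve_outcome(False, False, True) == (False, True, "draw")
assert resolve_outcome(False, False, False) == (False, False, None)
```

Keeping the two flags mutually exclusive matters for downstream training code: value bootstrapping is typically applied on truncation but not on termination.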

envs.division_env.DivisionEnv

Bases: Env

Gymnasium environment for a division-level HRL commander.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n_brigades` | `int` | Number of Blue brigades. Each brigade is a group of `n_blue_per_brigade` battalions. | `2` |
| `n_blue_per_brigade` | `int` | Number of Blue battalions per brigade. | `2` |
| `n_red_brigades` | `Optional[int]` | Number of Red brigades. Defaults to `n_brigades`. | `None` |
| `n_red_per_brigade` | `Optional[int]` | Number of Red battalions per Red brigade. Defaults to `n_blue_per_brigade`. | `None` |
| `map_width` | `float` | Map width in metres (passed through to the inner env). | `MAP_WIDTH` |
| `map_height` | `float` | Map height in metres (passed through to the inner env). | `MAP_HEIGHT` |
| `max_steps` | `int` | Maximum primitive-step episode length. | `MAX_STEPS` |
| `brigade_policy` | | Optional frozen brigade-level policy for Red brigades. Must expose `predict(obs, deterministic) -> (action, state)`. When provided, all its parameters should have `requires_grad=False`. | `None` |
| `red_random` | `bool` | When `True` and no `brigade_policy` is set, Red battalions take random primitive actions. Ignored when `brigade_policy` is set. | `False` |
| `randomize_terrain` | `bool` | Pass-through to `BrigadeEnv`. | `True` |
| `visibility_radius` | `float` | Fog-of-war visibility radius in metres. | `600.0` |
| `render_mode` | `Optional[str]` | `None` or `"human"`; delegated to the inner env. | `None` |
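Red force sizing defaults mirror the Blue configuration. A standalone sketch of the defaulting and validation logic from `__init__` (hypothetical helper name; the real checks live inline in the constructor):

```python
from typing import Optional

def resolve_division_config(
    n_brigades: int = 2,
    n_blue_per_brigade: int = 2,
    n_red_brigades: Optional[int] = None,
    n_red_per_brigade: Optional[int] = None,
) -> tuple[int, int, int, int]:
    """Validate brigade counts and fill Red defaults from the Blue side."""
    if int(n_brigades) < 1:
        raise ValueError(f"n_brigades must be >= 1, got {n_brigades}")
    if int(n_blue_per_brigade) < 1:
        raise ValueError(f"n_blue_per_brigade must be >= 1, got {n_blue_per_brigade}")
    # None means "mirror Blue": symmetric forces unless overridden.
    n_red_brigades = int(n_brigades if n_red_brigades is None else n_red_brigades)
    n_red_per_brigade = int(
        n_blue_per_brigade if n_red_per_brigade is None else n_red_per_brigade
    )
    return int(n_brigades), int(n_blue_per_brigade), n_red_brigades, n_red_per_brigade

assert resolve_division_config() == (2, 2, 2, 2)
assert resolve_division_config(3, 2, n_red_brigades=1) == (3, 2, 1, 2)
```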
Source code in envs/division_env.py
class DivisionEnv(gym.Env):
    """Gymnasium environment for a division-level HRL commander.

    Parameters
    ----------
    n_brigades:
        Number of Blue brigades.  Each brigade is a group of
        *n_blue_per_brigade* battalions.
    n_blue_per_brigade:
        Number of Blue battalions per brigade.
    n_red_brigades:
        Number of Red brigades.  Defaults to *n_brigades*.
    n_red_per_brigade:
        Number of Red battalions per Red brigade.
        Defaults to *n_blue_per_brigade*.
    map_width:
        Map width in metres (passed through to inner env).
    map_height:
        Map height in metres (passed through to inner env).
    max_steps:
        Maximum primitive-step episode length.
    brigade_policy:
        Optional frozen brigade-level policy for Red brigades.
        Must expose ``predict(obs, deterministic) -> (action, state)``.
        When provided all its parameters should have ``requires_grad=False``.
    red_random:
        When ``True`` and no *brigade_policy* is set, Red battalions take
        random primitive actions.  Ignored when *brigade_policy* is set.
    randomize_terrain:
        Pass-through to :class:`~envs.brigade_env.BrigadeEnv`.
    visibility_radius:
        Fog-of-war visibility radius in metres.
    render_mode:
        ``None`` or ``"human"`` — delegated to the inner env.
    """

    metadata: dict = {"render_modes": ["human"], "name": "division_v0"}

    def __init__(
        self,
        n_brigades: int = 2,
        n_blue_per_brigade: int = 2,
        n_red_brigades: Optional[int] = None,
        n_red_per_brigade: Optional[int] = None,
        map_width: float = MAP_WIDTH,
        map_height: float = MAP_HEIGHT,
        max_steps: int = MAX_STEPS,
        brigade_policy=None,
        red_random: bool = False,
        randomize_terrain: bool = True,
        visibility_radius: float = 600.0,
        render_mode: Optional[str] = None,
    ) -> None:
        if int(n_brigades) < 1:
            raise ValueError(f"n_brigades must be >= 1, got {n_brigades}")
        if int(n_blue_per_brigade) < 1:
            raise ValueError(f"n_blue_per_brigade must be >= 1, got {n_blue_per_brigade}")

        self.n_brigades: int = int(n_brigades)
        self.n_blue_per_brigade: int = int(n_blue_per_brigade)
        self.n_red_brigades: int = int(n_brigades if n_red_brigades is None else n_red_brigades)
        self.n_red_per_brigade: int = int(
            n_blue_per_brigade if n_red_per_brigade is None else n_red_per_brigade
        )

        if self.n_red_brigades < 1:
            raise ValueError(f"n_red_brigades must be >= 1, got {self.n_red_brigades}")
        if self.n_red_per_brigade < 1:
            raise ValueError(f"n_red_per_brigade must be >= 1, got {self.n_red_per_brigade}")

        # Total battalion counts
        self.n_blue: int = self.n_brigades * self.n_blue_per_brigade
        self.n_red: int = self.n_red_brigades * self.n_red_per_brigade

        self.map_width = float(map_width)
        self.map_height = float(map_height)
        self.map_diagonal = math.hypot(self.map_width, self.map_height)
        self.max_steps = int(max_steps)
        self.red_random = bool(red_random)
        self.render_mode = render_mode

        # ── Inner BrigadeEnv ─────────────────────────────────────────────
        self._brigade = BrigadeEnv(
            n_blue=self.n_blue,
            n_red=self.n_red,
            map_width=self.map_width,
            map_height=self.map_height,
            max_steps=self.max_steps,
            red_random=red_random,
            randomize_terrain=randomize_terrain,
            visibility_radius=visibility_radius,
            render_mode=render_mode,
        )

        # n_div_options matches the brigade option count
        self.n_div_options: int = self._brigade.n_options

        # ── Action space ────────────────────────────────────────────────
        # One operational command per Blue brigade
        self.action_space = spaces.MultiDiscrete(
            [self.n_div_options] * self.n_brigades, dtype=np.int64
        )

        # ── Observation space ───────────────────────────────────────────
        self._obs_dim: int = _division_obs_dim(self.n_brigades)
        obs_low, obs_high = self._build_obs_bounds()
        self.observation_space = spaces.Box(
            low=obs_low, high=obs_high, dtype=np.float32
        )

        # ── Frozen brigade policy for Red (optional) ────────────────────
        self._red_brigade_policy = None
        if brigade_policy is not None:
            self.set_brigade_policy(brigade_policy)

        # ── Episode state ────────────────────────────────────────────────
        self._div_steps: int = 0

    # ------------------------------------------------------------------
    # Observation bounds
    # ------------------------------------------------------------------

    def _build_obs_bounds(self) -> tuple[np.ndarray, np.ndarray]:
        """Return ``(obs_low, obs_high)`` arrays for the observation space."""
        lows: list[float] = []
        highs: list[float] = []

        # Theatre sector control: [0, 1] × N_THEATRE_SECTORS
        lows.extend([0.0] * N_THEATRE_SECTORS)
        highs.extend([1.0] * N_THEATRE_SECTORS)

        # Per-brigade status: [avg_strength, avg_morale, alive_ratio]
        for _ in range(self.n_brigades):
            lows.extend([0.0, 0.0, 0.0])
            highs.extend([1.0, 1.0, 1.0])

        # Per-brigade threat vector: [dist, cos, sin, e_str, e_mor]
        for _ in range(self.n_brigades):
            lows.append(0.0)    # dist / diagonal
            highs.append(1.0)
            lows.append(-1.0)   # cos(bearing)
            highs.append(1.0)
            lows.append(-1.0)   # sin(bearing)
            highs.append(1.0)
            lows.append(0.0)    # enemy avg_strength
            highs.append(1.0)
            lows.append(0.0)    # enemy avg_morale
            highs.append(1.0)

        # Step progress: [0, 1]
        lows.append(0.0)
        highs.append(1.0)

        return np.array(lows, dtype=np.float32), np.array(highs, dtype=np.float32)

    # ------------------------------------------------------------------
    # Frozen brigade policy
    # ------------------------------------------------------------------

    def set_brigade_policy(self, policy) -> None:
        """Set (or clear) the frozen brigade policy for Red.

        When *policy* is supplied any PyTorch parameters are frozen
        (``requires_grad=False``) and placed in evaluation mode.

        Parameters
        ----------
        policy:
            An object with a ``predict(obs, deterministic)`` method
            (e.g. an SB3 :class:`~stable_baselines3.PPO` model), or
            ``None`` to revert to the default Red behaviour.
        """
        if policy is None:
            self._red_brigade_policy = None
            return

        # Freeze parameters if this is a PyTorch module
        if hasattr(policy, "parameters"):
            for param in policy.parameters():
                param.requires_grad_(False)
        if hasattr(policy, "eval"):
            policy.eval()

        # Freeze SB3 policy networks if accessible
        if hasattr(policy, "policy") and hasattr(policy.policy, "parameters"):
            for param in policy.policy.parameters():
                param.requires_grad_(False)

        self._red_brigade_policy = policy

    # ------------------------------------------------------------------
    # Gymnasium API: reset
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> tuple[np.ndarray, dict]:
        """Reset the environment and return the initial division observation.

        Parameters
        ----------
        seed:
            RNG seed forwarded to the inner :class:`~envs.brigade_env.BrigadeEnv`.
        options:
            Unused; present for Gymnasium API compatibility.

        Returns
        -------
        obs : np.ndarray of shape ``(obs_dim,)``
        info : dict
        """
        if seed is not None:
            super().reset(seed=seed)

        self._brigade.reset(seed=seed, options=options)
        self._div_steps = 0
        return self._get_division_obs(), {}

    # ------------------------------------------------------------------
    # Gymnasium API: step
    # ------------------------------------------------------------------

    def step(
        self,
        div_action: np.ndarray,
    ) -> tuple[np.ndarray, float, bool, bool, dict]:
        """Execute one division macro-step.

        Translates the division operational command for each brigade into
        brigade-level actions (option indices for all battalions in the
        brigade) and delegates to :meth:`~envs.brigade_env.BrigadeEnv.step`.

        Parameters
        ----------
        div_action:
            Array of shape ``(n_brigades,)`` with operational command indices.
            Indices for brigades whose battalions are all dead are ignored.

        Returns
        -------
        obs : np.ndarray — division observation after the macro-step
        reward : float — brigade reward (passed through from BrigadeEnv)
        terminated : bool
        truncated : bool
        info : dict — includes ``div_steps``, ``brigade_action``, and
            fields from :class:`~envs.brigade_env.BrigadeEnv`
        """
        div_action = np.asarray(div_action, dtype=np.int64)

        if div_action.shape != (self.n_brigades,):
            raise ValueError(
                f"div_action has shape {div_action.shape!r}, "
                f"expected ({self.n_brigades},)."
            )

        for i, cmd in enumerate(div_action):
            if int(cmd) < 0 or int(cmd) >= self.n_div_options:
                raise ValueError(
                    f"Invalid operational command {int(cmd)!r} for brigade {i}; "
                    f"expected integer in [0, {self.n_div_options - 1}]."
                )

        # ── Inject Red brigade commands (if frozen policy set) ────────
        if self._red_brigade_policy is not None:
            self._update_red_brigade_options()

        # ── Translate division commands → BrigadeEnv action ──────────
        # brigade_action[i*n_per + j] = div_action[i]  for j in range(n_per)
        brigade_action = self._translate_division_action(div_action)

        # ── Delegate to BrigadeEnv ────────────────────────────────────
        _brigade_obs, reward, terminated, truncated, brigade_info = (
            self._brigade.step(brigade_action)
        )

        # Clear forced Red options after the step
        self._brigade._forced_red_options = {}

        self._div_steps += 1

        info: dict = {
            "div_steps": self._div_steps,
            "brigade_action": brigade_action.tolist(),
        }
        info.update(brigade_info)

        return self._get_division_obs(), float(reward), terminated, truncated, info

    # ------------------------------------------------------------------
    # Command translation
    # ------------------------------------------------------------------

    def _translate_division_action(self, div_action: np.ndarray) -> np.ndarray:
        """Expand a per-brigade action to a flat per-battalion brigade action.

        Parameters
        ----------
        div_action:
            Integer array of shape ``(n_brigades,)`` with operational command
            indices.

        Returns
        -------
        np.ndarray of shape ``(n_blue,)`` — option index for every battalion.
        """
        brigade_action = np.empty(self.n_blue, dtype=np.int64)
        for i in range(self.n_brigades):
            start = i * self.n_blue_per_brigade
            end = start + self.n_blue_per_brigade
            brigade_action[start:end] = int(div_action[i])
        return brigade_action

    # ------------------------------------------------------------------
    # Red brigade policy injection
    # ------------------------------------------------------------------

    def _update_red_brigade_options(self) -> None:
        """Compute Red brigade obs, query the frozen policy, and inject options.

        The frozen brigade policy receives a :class:`~envs.brigade_env.BrigadeEnv`-
        compatible observation for the Red side (shape ``3 + 7 * n_red + 1``,
        treating Red battalions as the "blue" side) and returns a per-battalion
        action of shape ``(n_red,)`` with option indices in ``[0, n_options)``.
        Each option index is injected directly into
        :attr:`~envs.brigade_env.BrigadeEnv._forced_red_options` for the
        corresponding Red battalion, bypassing the default Red action logic.
        """
        red_obs = self._get_red_brigade_obs()
        red_action, _ = self._red_brigade_policy.predict(
            red_obs, deterministic=False
        )
        red_action = np.asarray(red_action, dtype=np.int64).flatten()

        if len(red_action) < self.n_red:
            raise ValueError(
                f"Frozen brigade policy returned action of length {len(red_action)}, "
                f"but expected at least {self.n_red} (one per Red battalion)."
            )

        forced: dict[str, int] = {}
        for idx in range(self.n_red):
            cmd = int(np.clip(red_action[idx], 0, self.n_div_options - 1))
            forced[f"red_{idx}"] = cmd
        self._brigade._forced_red_options = forced

    # ------------------------------------------------------------------
    # Division observation construction
    # ------------------------------------------------------------------

    def _battalion_in_sector(self, b, s: int, sector_width: float) -> bool:
        """Return True if battalion *b* occupies theatre sector *s*."""
        x_lo = s * sector_width
        x_hi = (s + 1) * sector_width
        return b.x >= x_lo and (
            b.x < x_hi or (s == N_THEATRE_SECTORS - 1 and b.x == self.map_width)
        )

    def _get_theatre_sector_strengths(self, inner) -> list[tuple[float, float]]:
        """Return ``[(blue_str, red_str), ...]`` for each theatre sector."""
        sector_width = self.map_width / N_THEATRE_SECTORS
        result: list[tuple[float, float]] = []
        for s in range(N_THEATRE_SECTORS):
            blue_str = 0.0
            red_str = 0.0
            for agent_id, b in inner._battalions.items():
                if agent_id not in inner._alive:
                    continue
                if self._battalion_in_sector(b, s, sector_width):
                    if agent_id.startswith("blue_"):
                        blue_str += float(b.strength)
                    else:
                        red_str += float(b.strength)
            result.append((blue_str, red_str))
        return result

    def _get_division_obs(self) -> np.ndarray:
        """Build and return the normalised division observation vector."""
        parts: list[float] = []
        inner = self._brigade._inner

        # ── 1. Theatre sector control (5 vertical strips) ─────────────
        for blue_str, red_str in self._get_theatre_sector_strengths(inner):
            total = blue_str + red_str
            parts.append(blue_str / total if total > 0.0 else 0.5)

        # ── 2. Per-brigade status [avg_strength, avg_morale, alive_ratio] ─
        for i in range(self.n_brigades):
            strengths = []
            morales = []
            alive_count = 0
            for j in range(self.n_blue_per_brigade):
                agent_id = f"blue_{i * self.n_blue_per_brigade + j}"
                if agent_id in inner._battalions and agent_id in inner._alive:
                    b = inner._battalions[agent_id]
                    strengths.append(float(b.strength))
                    morales.append(float(b.morale))
                    alive_count += 1
            avg_str = float(np.mean(strengths)) if strengths else 0.0
            avg_mor = float(np.mean(morales)) if morales else 0.0
            alive_ratio = alive_count / self.n_blue_per_brigade
            parts.extend([avg_str, avg_mor, alive_ratio])

        # ── 3. Per-brigade threat vector ───────────────────────────────
        # Build list of alive Red brigade centroids
        red_brigade_centroids = self._get_red_brigade_centroids(inner)

        for i in range(self.n_brigades):
            # Compute centroid of this Blue brigade
            bx_list, by_list = [], []
            for j in range(self.n_blue_per_brigade):
                agent_id = f"blue_{i * self.n_blue_per_brigade + j}"
                if agent_id in inner._battalions and agent_id in inner._alive:
                    b = inner._battalions[agent_id]
                    bx_list.append(b.x)
                    by_list.append(b.y)

            if not bx_list or not red_brigade_centroids:
                # This brigade is dead or no Red brigades alive — sentinel
                parts.extend([1.0, 0.0, 0.0, 0.0, 0.0])
                continue

            cx = float(np.mean(bx_list))
            cy = float(np.mean(by_list))

            # Find nearest Red brigade by centroid
            best_dist = float("inf")
            best_centroid = None
            best_e_str = 0.0
            best_e_mor = 0.0
            for (rx, ry, e_str, e_mor) in red_brigade_centroids:
                dx = rx - cx
                dy = ry - cy
                d = math.sqrt(dx * dx + dy * dy)
                if d < best_dist:
                    best_dist = d
                    best_centroid = (rx, ry)
                    best_e_str = e_str
                    best_e_mor = e_mor

            assert best_centroid is not None
            dx = best_centroid[0] - cx
            dy = best_centroid[1] - cy
            bearing = math.atan2(dy, dx)

            parts.append(min(best_dist / self.map_diagonal, 1.0))
            parts.append(math.cos(bearing))
            parts.append(math.sin(bearing))
            parts.append(best_e_str)
            parts.append(best_e_mor)

        # ── 4. Step progress ───────────────────────────────────────────
        parts.append(min(inner._step_count / self.max_steps, 1.0))

        obs = np.array(parts, dtype=np.float32)
        return np.clip(obs, self.observation_space.low, self.observation_space.high)

    def _get_red_brigade_centroids(
        self, inner
    ) -> list[tuple[float, float, float, float]]:
        """Return ``(cx, cy, avg_strength, avg_morale)`` for each alive Red brigade."""
        centroids = []
        for i in range(self.n_red_brigades):
            xs, ys, strs, mors = [], [], [], []
            for j in range(self.n_red_per_brigade):
                agent_id = f"red_{i * self.n_red_per_brigade + j}"
                if agent_id in inner._battalions and agent_id in inner._alive:
                    b = inner._battalions[agent_id]
                    xs.append(b.x)
                    ys.append(b.y)
                    strs.append(float(b.strength))
                    mors.append(float(b.morale))
            if xs:
                centroids.append((
                    float(np.mean(xs)),
                    float(np.mean(ys)),
                    float(np.mean(strs)),
                    float(np.mean(mors)),
                ))
        return centroids

    # ------------------------------------------------------------------
    # Red brigade observation (for frozen Red brigade policy)
    # ------------------------------------------------------------------

    def _get_red_brigade_obs(self) -> np.ndarray:
        """Build a :class:`~envs.brigade_env.BrigadeEnv`-compatible observation for Red.

        The observation mirrors the format that a brigade-level PPO policy
        was trained on, treating Red battalions as the "blue" side:

        * ``_BRIGADE_N_SECTORS`` (= 3) sector-control values — Red's strength
          share in each of 3 equal vertical strips.
        * Per-Red-battalion ``[strength, morale]`` — zeros for dead battalions.
        * Per-Red-battalion enemy threat ``[dist/diag, cos_bear, sin_bear,
          e_str, e_mor]`` — nearest alive Blue *battalion* (not centroid).
          Sentinel ``[1, 0, 0, 0, 0]`` when no Blue battalion is alive.
        * Step progress.

        The returned array has shape ``(_BRIGADE_N_SECTORS + 7 * n_red + 1,)``
        and is clipped using per-element bounds, independent of
        ``self.observation_space`` (which is sized for the Blue division obs).
        """
        parts: list[float] = []
        inner = self._brigade._inner

        # ── 1. Sector control — 3 strips, Red's share ─────────────────
        sector_width = self.map_width / _BRIGADE_N_SECTORS
        for s in range(_BRIGADE_N_SECTORS):
            x_lo = s * sector_width
            x_hi = (s + 1) * sector_width
            blue_str = 0.0
            red_str = 0.0
            for agent_id, b in inner._battalions.items():
                if agent_id not in inner._alive:
                    continue
                in_sector = (x_lo <= b.x < x_hi) or (
                    s == _BRIGADE_N_SECTORS - 1 and b.x == self.map_width
                )
                if in_sector:
                    if agent_id.startswith("blue_"):
                        blue_str += float(b.strength)
                    else:
                        red_str += float(b.strength)
            total = blue_str + red_str
            parts.append(red_str / total if total > 0.0 else 0.5)

        # ── 2. Per-Red-battalion strength + morale ─────────────────────
        for idx in range(self.n_red):
            agent_id = f"red_{idx}"
            if agent_id in inner._battalions and agent_id in inner._alive:
                b = inner._battalions[agent_id]
                parts.append(float(b.strength))
                parts.append(float(b.morale))
            else:
                parts.extend([0.0, 0.0])

        # ── 3. Per-Red-battalion enemy threat → nearest alive Blue ─────
        alive_blue = [
            (b_id, inner._battalions[b_id])
            for b_id in inner._alive
            if b_id.startswith("blue_") and b_id in inner._battalions
        ]

        for idx in range(self.n_red):
            agent_id = f"red_{idx}"
            if agent_id not in inner._alive or agent_id not in inner._battalions:
                parts.extend([1.0, 0.0, 0.0, 0.0, 0.0])
                continue

            rb = inner._battalions[agent_id]

            if not alive_blue:
                parts.extend([1.0, 0.0, 0.0, 0.0, 0.0])
                continue

            best_dist = float("inf")
            best_b = None
            for _b_id, bb in alive_blue:
                dx = bb.x - rb.x
                dy = bb.y - rb.y
                d = math.sqrt(dx * dx + dy * dy)
                if d < best_dist:
                    best_dist = d
                    best_b = bb

            assert best_b is not None
            dx = best_b.x - rb.x
            dy = best_b.y - rb.y
            bearing = math.atan2(dy, dx)

            parts.append(min(best_dist / self.map_diagonal, 1.0))
            parts.append(math.cos(bearing))
            parts.append(math.sin(bearing))
            parts.append(float(best_b.strength))
            parts.append(float(best_b.morale))

        # ── 4. Step progress ───────────────────────────────────────────
        parts.append(min(inner._step_count / self.max_steps, 1.0))

        obs = np.array(parts, dtype=np.float32)

        # Build per-element bounds for the BrigadeEnv-compatible obs layout.
        # This is independent of self.observation_space (Blue division obs).
        lows: list[float] = [0.0] * _BRIGADE_N_SECTORS
        highs: list[float] = [1.0] * _BRIGADE_N_SECTORS
        for _ in range(self.n_red):   # strength, morale per battalion
            lows.extend([0.0, 0.0])
            highs.extend([1.0, 1.0])
        for _ in range(self.n_red):   # dist, cos, sin, e_str, e_mor per battalion
            lows.extend([0.0, -1.0, -1.0, 0.0, 0.0])
            highs.extend([1.0,  1.0,  1.0, 1.0, 1.0])
        lows.append(0.0)
        highs.append(1.0)
        return np.clip(obs, np.array(lows, dtype=np.float32), np.array(highs, dtype=np.float32))

    # ------------------------------------------------------------------
    # Gymnasium API: render / close
    # ------------------------------------------------------------------

    def render(self):
        """Delegate rendering to the inner brigade environment."""
        return self._brigade.render()

    def close(self) -> None:
        """Delegate cleanup to the inner brigade environment."""
        self._brigade.close()
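The observation layout built by `_build_obs_bounds` and `_get_division_obs` (theatre sector control, per-brigade status, per-brigade threat vector, step progress) implies a fixed observation size. A minimal sketch of that size calculation, assuming `N_THEATRE_SECTORS = 5` (the "5 vertical strips" noted in `_get_division_obs`); the helper names below are illustrative, not the package's `_division_obs_dim`:

```python
# Division observation dimension, mirroring _build_obs_bounds:
#   N_THEATRE_SECTORS sector-control values
# + 3 status values per Blue brigade  (avg_strength, avg_morale, alive_ratio)
# + 5 threat values per Blue brigade  (dist, cos, sin, e_str, e_mor)
# + 1 step-progress value
N_THEATRE_SECTORS = 5  # assumption: matches the "5 vertical strips" comment

def division_obs_dim(n_brigades: int) -> int:
    return N_THEATRE_SECTORS + 8 * n_brigades + 1

# Red-side brigade observation consumed by the frozen policy,
# documented as shape (_BRIGADE_N_SECTORS + 7 * n_red + 1,) with 3 sectors:
def red_brigade_obs_dim(n_red: int) -> int:
    return 3 + 7 * n_red + 1

print(division_obs_dim(2))     # 22 for the default two brigades
print(red_brigade_obs_dim(4))  # 32
```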

close()

Delegate cleanup to the inner brigade environment.

Source code in envs/division_env.py
def close(self) -> None:
    """Delegate cleanup to the inner brigade environment."""
    self._brigade.close()

render()

Delegate rendering to the inner brigade environment.

Source code in envs/division_env.py
def render(self):
    """Delegate rendering to the inner brigade environment."""
    return self._brigade.render()

reset(seed=None, options=None)

Reset the environment and return the initial division observation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `seed` | `Optional[int]` | RNG seed forwarded to the inner `envs.brigade_env.BrigadeEnv`. | `None` |
| `options` | `Optional[dict]` | Unused; present for Gymnasium API compatibility. | `None` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `obs` | `np.ndarray` | Initial division observation, shape `(obs_dim,)`. |
| `info` | `dict` | Auxiliary info (empty on reset). |
Source code in envs/division_env.py
def reset(
    self,
    seed: Optional[int] = None,
    options: Optional[dict] = None,
) -> tuple[np.ndarray, dict]:
    """Reset the environment and return the initial division observation.

    Parameters
    ----------
    seed:
        RNG seed forwarded to the inner :class:`~envs.brigade_env.BrigadeEnv`.
    options:
        Unused; present for Gymnasium API compatibility.

    Returns
    -------
    obs : np.ndarray of shape ``(obs_dim,)``
    info : dict
    """
    if seed is not None:
        super().reset(seed=seed)

    self._brigade.reset(seed=seed, options=options)
    self._div_steps = 0
    return self._get_division_obs(), {}

set_brigade_policy(policy)

Set (or clear) the frozen brigade policy for Red.

When *policy* is supplied, any PyTorch parameters are frozen (`requires_grad=False`) and the policy is placed in evaluation mode.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `policy` | | An object with a `predict(obs, deterministic)` method (e.g. an SB3 `stable_baselines3.PPO` model), or `None` to revert to the default Red behaviour. | *required* |
Source code in envs/division_env.py
def set_brigade_policy(self, policy) -> None:
    """Set (or clear) the frozen brigade policy for Red.

    When *policy* is supplied any PyTorch parameters are frozen
    (``requires_grad=False``) and placed in evaluation mode.

    Parameters
    ----------
    policy:
        An object with a ``predict(obs, deterministic)`` method
        (e.g. an SB3 :class:`~stable_baselines3.PPO` model), or
        ``None`` to revert to the default Red behaviour.
    """
    if policy is None:
        self._red_brigade_policy = None
        return

    # Freeze parameters if this is a PyTorch module
    if hasattr(policy, "parameters"):
        for param in policy.parameters():
            param.requires_grad_(False)
    if hasattr(policy, "eval"):
        policy.eval()

    # Freeze SB3 policy networks if accessible
    if hasattr(policy, "policy") and hasattr(policy.policy, "parameters"):
        for param in policy.policy.parameters():
            param.requires_grad_(False)

    self._red_brigade_policy = policy
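`set_brigade_policy` only duck-types the policy: anything exposing `parameters()` and `eval()` is frozen, and a nested SB3-style `.policy` attribute is frozen the same way. A minimal sketch of that freezing pattern with stand-in objects (`StubParam` and `StubPolicy` are hypothetical names for illustration; no PyTorch dependency):

```python
class StubParam:
    """Mimics a torch Parameter's requires_grad_ toggle."""
    def __init__(self):
        self.requires_grad = True
    def requires_grad_(self, flag):
        self.requires_grad = flag
        return self

class StubPolicy:
    """Mimics the surface DivisionEnv.set_brigade_policy relies on."""
    def __init__(self):
        self._params = [StubParam(), StubParam()]
        self.training = True
    def parameters(self):
        return self._params
    def eval(self):
        self.training = False
    def predict(self, obs, deterministic=False):
        return [0] * 4, None  # one option index per Red battalion

policy = StubPolicy()
# The same freezing logic set_brigade_policy applies:
for p in policy.parameters():
    p.requires_grad_(False)
policy.eval()

print(all(not p.requires_grad for p in policy.parameters()))  # True
print(policy.training)                                        # False
```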

step(div_action)

Execute one division macro-step.

Translates the division operational command for each brigade into brigade-level actions (option indices for all battalions in the brigade) and delegates to `envs.brigade_env.BrigadeEnv.step`.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `div_action` | `ndarray` | Array of shape `(n_brigades,)` with operational command indices. Indices for brigades whose battalions are all dead are ignored. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `obs` | `np.ndarray` | Division observation after the macro-step. |
| `reward` | `float` | Brigade reward (passed through from `BrigadeEnv`). |
| `terminated` | `bool` | |
| `truncated` | `bool` | |
| `info` | `dict` | Includes `div_steps`, `brigade_action`, and fields from `envs.brigade_env.BrigadeEnv`. |

Source code in envs/division_env.py
def step(
    self,
    div_action: np.ndarray,
) -> tuple[np.ndarray, float, bool, bool, dict]:
    """Execute one division macro-step.

    Translates the division operational command for each brigade into
    brigade-level actions (option indices for all battalions in the
    brigade) and delegates to :meth:`~envs.brigade_env.BrigadeEnv.step`.

    Parameters
    ----------
    div_action:
        Array of shape ``(n_brigades,)`` with operational command indices.
        Indices for brigades whose battalions are all dead are ignored.

    Returns
    -------
    obs : np.ndarray — division observation after the macro-step
    reward : float — brigade reward (passed through from BrigadeEnv)
    terminated : bool
    truncated : bool
    info : dict — includes ``div_steps``, ``brigade_action``, and
        fields from :class:`~envs.brigade_env.BrigadeEnv`
    """
    div_action = np.asarray(div_action, dtype=np.int64)

    if div_action.shape != (self.n_brigades,):
        raise ValueError(
            f"div_action has shape {div_action.shape!r}, "
            f"expected ({self.n_brigades},)."
        )

    for i, cmd in enumerate(div_action):
        if int(cmd) < 0 or int(cmd) >= self.n_div_options:
            raise ValueError(
                f"Invalid operational command {int(cmd)!r} for brigade {i}; "
                f"expected integer in [0, {self.n_div_options - 1}]."
            )

    # ── Inject Red brigade commands (if frozen policy set) ────────
    if self._red_brigade_policy is not None:
        self._update_red_brigade_options()

    # ── Translate division commands → BrigadeEnv action ──────────
    # brigade_action[i*n_per + j] = div_action[i]  for j in range(n_per)
    brigade_action = self._translate_division_action(div_action)

    # ── Delegate to BrigadeEnv ────────────────────────────────────
    _brigade_obs, reward, terminated, truncated, brigade_info = (
        self._brigade.step(brigade_action)
    )

    # Clear forced Red options after the step
    self._brigade._forced_red_options = {}

    self._div_steps += 1

    info: dict = {
        "div_steps": self._div_steps,
        "brigade_action": brigade_action.tolist(),
    }
    info.update(brigade_info)

    return self._get_division_obs(), float(reward), terminated, truncated, info
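The per-brigade command expansion performed inside `step` (via `_translate_division_action`) is a straight repeat: each brigade's option index is copied to every battalion in that brigade. A self-contained sketch of the same mapping with NumPy (`translate_division_action` is an illustrative standalone name):

```python
import numpy as np

def translate_division_action(div_action, n_blue_per_brigade):
    """Expand shape-(n_brigades,) commands to shape-(n_blue,) battalion options."""
    div_action = np.asarray(div_action, dtype=np.int64)
    # Equivalent to the documented loop:
    #   brigade_action[i * n_per + j] = div_action[i]  for j in range(n_per)
    return np.repeat(div_action, n_blue_per_brigade)

# Two brigades, two battalions each: brigade 0 -> option 3, brigade 1 -> option 1
print(translate_division_action([3, 1], 2))  # [3 3 1 1]
```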

envs.corps_env.CorpsEnv

Bases: Env

Gymnasium environment for a corps-level HRL commander.

Parameters:

n_divisions (int, default 3): Number of Blue divisions. Must be ≥ 1.
n_brigades_per_division (int, default 3): Number of Blue brigades per division. Must be ≥ 1.
n_blue_per_brigade (int, default 4): Number of Blue battalions per brigade. Must be ≥ 1.
n_red_divisions (Optional[int], default None): Number of Red divisions. Defaults to n_divisions.
n_red_brigades_per_division (Optional[int], default None): Red brigades per division. Defaults to n_brigades_per_division.
n_red_per_brigade (Optional[int], default None): Red battalions per brigade. Defaults to n_blue_per_brigade.
map_width (float, default CORPS_MAP_WIDTH): Map width in metres (the default map is 10 km × 5 km = 50 km²).
map_height (float, default CORPS_MAP_HEIGHT): Map height in metres.
max_steps (int, default MAX_STEPS): Maximum primitive-step episode length.
road_network (Optional[RoadNetwork], default None): Optional RoadNetwork. When None (the default) a standard network is generated via RoadNetwork.generate_default.
supply_network (Optional[SupplyNetwork], default None): Optional SupplyNetwork. When None (the default) a standard bilateral supply network is generated via SupplyNetwork.generate_default.
objectives (Optional[List[OperationalObjective]], default None): Optional list of OperationalObjective. When None, three default objectives are placed automatically.
comm_radius (float, default DEFAULT_COMM_RADIUS): Inter-division communication radius in metres. Enemy threat vectors beyond this distance are replaced with sentinels.
red_random (bool, default True): When True, Red takes random brigade actions.
randomize_terrain (bool, default True): Pass-through to the inner env.
visibility_radius (float, default DEFAULT_VISIBILITY_RADIUS): Fog-of-war visibility radius in metres.
render_mode (Optional[str], default None): None or "human".
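A corps action is one operational option index per Blue division, and step() broadcasts each division's command to every brigade in that division. A minimal sketch of this fan-out, mirroring the logic of _translate_corps_action (the standalone helper name expand_corps_action is illustrative, not part of the API):

```python
import numpy as np

def expand_corps_action(corps_action, n_brigades_per_division):
    """Broadcast one command per division to all brigades in that division."""
    corps_action = np.asarray(corps_action, dtype=np.int64)
    # np.repeat assigns division i's command to brigades
    # [i*k : (i+1)*k], matching the slice assignment in the source.
    return np.repeat(corps_action, n_brigades_per_division)

# 3 divisions x 3 brigades each: division commands [2, 0, 5]
print(expand_corps_action([2, 0, 5], 3))
# [2 2 2 0 0 0 5 5 5]
```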
Source code in envs/corps_env.py
class CorpsEnv(gym.Env):
    """Gymnasium environment for a corps-level HRL commander.

    Parameters
    ----------
    n_divisions:
        Number of Blue divisions.  Must be ≥ 1.
    n_brigades_per_division:
        Number of Blue brigades per division.  Must be ≥ 1.
    n_blue_per_brigade:
        Number of Blue battalions per brigade.  Must be ≥ 1.
    n_red_divisions:
        Number of Red divisions.  Defaults to *n_divisions*.
    n_red_brigades_per_division:
        Red brigades per division.  Defaults to *n_brigades_per_division*.
    n_red_per_brigade:
        Red battalions per brigade.  Defaults to *n_blue_per_brigade*.
    map_width, map_height:
        Map dimensions in metres (default 10 km × 5 km = 50 km²).
    max_steps:
        Maximum primitive-step episode length.
    road_network:
        Optional :class:`~envs.sim.road_network.RoadNetwork`.  When
        ``None`` (the default) a standard network is generated via
        :meth:`~envs.sim.road_network.RoadNetwork.generate_default`.
    supply_network:
        Optional :class:`~envs.sim.supply_network.SupplyNetwork`.  When
        ``None`` (the default) a standard bilateral supply network is
        generated via
        :meth:`~envs.sim.supply_network.SupplyNetwork.generate_default`.
    objectives:
        Optional list of :class:`OperationalObjective`.  When ``None``,
        three default objectives are placed automatically.
    comm_radius:
        Inter-division communication radius in metres.  Enemy threat
        vectors beyond this distance are replaced with sentinels.
    red_random:
        When ``True`` Red takes random brigade actions.
    randomize_terrain:
        Pass-through to the inner env.
    visibility_radius:
        Fog-of-war visibility radius in metres.
    render_mode:
        ``None`` or ``"human"``.
    """

    metadata: dict = {"render_modes": ["human"], "name": "corps_v0"}

    def __init__(
        self,
        n_divisions: int = 3,
        n_brigades_per_division: int = 3,
        n_blue_per_brigade: int = 4,
        n_red_divisions: Optional[int] = None,
        n_red_brigades_per_division: Optional[int] = None,
        n_red_per_brigade: Optional[int] = None,
        map_width: float = CORPS_MAP_WIDTH,
        map_height: float = CORPS_MAP_HEIGHT,
        max_steps: int = MAX_STEPS,
        road_network: Optional[RoadNetwork] = None,
        supply_network: Optional[SupplyNetwork] = None,
        objectives: Optional[List[OperationalObjective]] = None,
        comm_radius: float = DEFAULT_COMM_RADIUS,
        red_random: bool = True,
        randomize_terrain: bool = True,
        visibility_radius: float = DEFAULT_VISIBILITY_RADIUS,
        render_mode: Optional[str] = None,
    ) -> None:
        # ── Validation ───────────────────────────────────────────────
        if int(n_divisions) < 1:
            raise ValueError(f"n_divisions must be >= 1, got {n_divisions}")
        if int(n_brigades_per_division) < 1:
            raise ValueError(
                f"n_brigades_per_division must be >= 1, got {n_brigades_per_division}"
            )
        if int(n_blue_per_brigade) < 1:
            raise ValueError(
                f"n_blue_per_brigade must be >= 1, got {n_blue_per_brigade}"
            )

        self.n_divisions: int = int(n_divisions)
        self.n_brigades_per_division: int = int(n_brigades_per_division)
        self.n_blue_per_brigade: int = int(n_blue_per_brigade)

        self.n_red_divisions: int = int(
            n_divisions if n_red_divisions is None else n_red_divisions
        )
        self.n_red_brigades_per_division: int = int(
            n_brigades_per_division
            if n_red_brigades_per_division is None
            else n_red_brigades_per_division
        )
        self.n_red_per_brigade: int = int(
            n_blue_per_brigade if n_red_per_brigade is None else n_red_per_brigade
        )

        if self.n_red_divisions < 1:
            raise ValueError(
                f"n_red_divisions must be >= 1, got {self.n_red_divisions}"
            )
        if self.n_red_brigades_per_division < 1:
            raise ValueError(
                f"n_red_brigades_per_division must be >= 1, "
                f"got {self.n_red_brigades_per_division}"
            )
        if self.n_red_per_brigade < 1:
            raise ValueError(
                f"n_red_per_brigade must be >= 1, got {self.n_red_per_brigade}"
            )

        # Derived counts
        self.n_blue_brigades: int = self.n_divisions * self.n_brigades_per_division
        self.n_red_brigades: int = (
            self.n_red_divisions * self.n_red_brigades_per_division
        )
        self.n_blue: int = self.n_blue_brigades * self.n_blue_per_brigade
        self.n_red: int = self.n_red_brigades * self.n_red_per_brigade

        self.map_width = float(map_width)
        self.map_height = float(map_height)
        if self.map_width <= 0.0 or self.map_height <= 0.0:
            raise ValueError(
                f"map_width and map_height must both be > 0, "
                f"got map_width={self.map_width}, map_height={self.map_height}"
            )
        self.map_diagonal = math.hypot(self.map_width, self.map_height)
        self.max_steps = int(max_steps)
        if self.max_steps < 1:
            raise ValueError(f"max_steps must be >= 1, got {self.max_steps}")
        self.comm_radius = float(comm_radius)
        self.render_mode = render_mode

        # ── Road network ─────────────────────────────────────────────
        self.road_network: RoadNetwork = (
            road_network
            if road_network is not None
            else RoadNetwork.generate_default(self.map_width, self.map_height)
        )

        # ── Supply network ────────────────────────────────────────────
        self.supply_network: SupplyNetwork = (
            supply_network
            if supply_network is not None
            else SupplyNetwork.generate_default(self.map_width, self.map_height)
        )

        # ── Operational objectives ───────────────────────────────────
        self.objectives: List[OperationalObjective] = (
            objectives
            if objectives is not None
            else self._default_objectives()
        )

        # ── Inner DivisionEnv ────────────────────────────────────────
        self._division = DivisionEnv(
            n_brigades=self.n_blue_brigades,
            n_blue_per_brigade=self.n_blue_per_brigade,
            n_red_brigades=self.n_red_brigades,
            n_red_per_brigade=self.n_red_per_brigade,
            map_width=self.map_width,
            map_height=self.map_height,
            max_steps=self.max_steps,
            red_random=red_random,
            randomize_terrain=randomize_terrain,
            visibility_radius=visibility_radius,
            render_mode=render_mode,
        )
        # Propagate road network to the innermost battalion simulation
        self._set_inner_road_network()

        # n_corps_options mirrors the division option count
        self.n_corps_options: int = self._division.n_div_options

        # ── Action space ────────────────────────────────────────────
        self.action_space = spaces.MultiDiscrete(
            [self.n_corps_options] * self.n_divisions, dtype=np.int64
        )

        # ── Observation space ───────────────────────────────────────
        self._obs_dim: int = _corps_obs_dim(self.n_divisions)
        obs_low, obs_high = self._build_obs_bounds()
        # Store base corps bounds for use in _get_corps_obs() — subclasses
        # may override observation_space, so clipping must use these stored
        # bounds rather than self.observation_space.
        self._corps_obs_low: np.ndarray = obs_low
        self._corps_obs_high: np.ndarray = obs_high
        self.observation_space = spaces.Box(
            low=obs_low, high=obs_high, dtype=np.float32
        )

        # Episode state
        self._corps_steps: int = 0

    # ------------------------------------------------------------------
    # Road network propagation
    # ------------------------------------------------------------------

    def _set_inner_road_network(self) -> None:
        """Attach the road network to the innermost MultiBattalionEnv."""
        inner = self._division._brigade._inner
        inner.road_network = self.road_network

    # ------------------------------------------------------------------
    # Supply network helpers
    # ------------------------------------------------------------------

    def _interdiction_radius(self) -> float:
        """Capture radius used for supply-depot interdiction.

        Matches the radius of the ``CUT_SUPPLY_LINE`` operational objective
        so that any Blue unit that can claim the objective can also capture
        the corresponding depot.
        """
        for obj in self.objectives:
            if obj.obj_type == ObjectiveType.CUT_SUPPLY_LINE:
                return obj.radius
        return min(self.map_width, self.map_height) * 0.05

    def _compute_division_supply_levels(self, inner) -> List[float]:
        """Return the average supply level for each Blue division.

        For each Blue division, computes the mean
        :meth:`~envs.sim.supply_network.SupplyNetwork.get_supply_level`
        across all alive units in that division.  Returns ``0.0`` for a
        division with no alive units.

        Parameters
        ----------
        inner:
            The :class:`~envs.multi_battalion_env.MultiBattalionEnv`
            instance.
        """
        total_per_div = self.n_brigades_per_division * self.n_blue_per_brigade
        levels: List[float] = []
        for i in range(self.n_divisions):
            div_levels: List[float] = []
            for j in range(total_per_div):
                agent_id = f"blue_{i * total_per_div + j}"
                if agent_id in inner._battalions and agent_id in inner._alive:
                    b = inner._battalions[agent_id]
                    div_levels.append(
                        self.supply_network.get_supply_level(b.x, b.y, team=0)
                    )
            levels.append(float(np.mean(div_levels)) if div_levels else 0.0)
        return levels

    # ------------------------------------------------------------------
    # Default objectives
    # ------------------------------------------------------------------

    def _default_objectives(self) -> List[OperationalObjective]:
        """Return three default operational objectives for this map."""
        capture_radius = min(self.map_width, self.map_height) * 0.05
        return [
            OperationalObjective(
                x=self.map_width * 0.5,
                y=self.map_height * 0.5,
                radius=capture_radius,
                obj_type=ObjectiveType.CAPTURE_HEX,
            ),
            OperationalObjective(
                x=self.map_width * 0.8,
                y=self.map_height * 0.5,
                radius=capture_radius,
                obj_type=ObjectiveType.CUT_SUPPLY_LINE,
            ),
            OperationalObjective(
                x=self.map_width * 0.5,
                y=self.map_height * 0.5,
                radius=self.map_height * 0.5,  # covers whole height
                obj_type=ObjectiveType.FIX_AND_FLANK,
            ),
        ]

    # ------------------------------------------------------------------
    # Observation bounds
    # ------------------------------------------------------------------

    def _build_obs_bounds(self) -> tuple[np.ndarray, np.ndarray]:
        lows: list[float] = []
        highs: list[float] = []

        # Corps sector control: [0, 1] × N_CORPS_SECTORS
        lows.extend([0.0] * N_CORPS_SECTORS)
        highs.extend([1.0] * N_CORPS_SECTORS)

        # Per-division status: [avg_strength, avg_morale, alive_ratio]
        for _ in range(self.n_divisions):
            lows.extend([0.0, 0.0, 0.0])
            highs.extend([1.0, 1.0, 1.0])

        # Per-division threat vector: [dist, cos, sin, e_str, e_mor]
        for _ in range(self.n_divisions):
            lows.extend([0.0, -1.0, -1.0, 0.0, 0.0])
            highs.extend([1.0, 1.0, 1.0, 1.0, 1.0])

        # Road usage: [blue_fraction, red_fraction]
        lows.extend([0.0, 0.0])
        highs.extend([1.0, 1.0])

        # Objective control: [0, 1] × N_OBJECTIVES
        lows.extend([0.0] * N_OBJECTIVES)
        highs.extend([1.0] * N_OBJECTIVES)

        # Supply level per Blue division: [0, 1]
        lows.extend([0.0] * self.n_divisions)
        highs.extend([1.0] * self.n_divisions)

        # Step progress: [0, 1]
        lows.append(0.0)
        highs.append(1.0)

        return np.array(lows, dtype=np.float32), np.array(highs, dtype=np.float32)

    # ------------------------------------------------------------------
    # Gymnasium API: reset
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> tuple[np.ndarray, dict]:
        """Reset the environment and return the initial corps observation.

        Parameters
        ----------
        seed:
            RNG seed forwarded to the inner :class:`~envs.division_env.DivisionEnv`.
        options:
            Unused; present for Gymnasium API compatibility.
        """
        if seed is not None:
            super().reset(seed=seed)

        self._division.reset(seed=seed, options=options)
        # Re-attach road network after inner reset (terrain is regenerated)
        self._set_inner_road_network()

        # Reset objective state
        for obj in self.objectives:
            obj.reset()

        # Reset supply network
        self.supply_network.reset()

        self._corps_steps = 0
        return self._get_corps_obs(), {}

    # ------------------------------------------------------------------
    # Gymnasium API: step
    # ------------------------------------------------------------------

    def step(
        self,
        corps_action: np.ndarray,
    ) -> tuple[np.ndarray, float, bool, bool, dict]:
        """Execute one corps macro-step.

        Translates the per-division operational command into a division-level
        action and delegates to the inner :class:`~envs.division_env.DivisionEnv`.

        Parameters
        ----------
        corps_action:
            Integer array of shape ``(n_divisions,)`` with operational
            command indices in ``[0, n_corps_options)``.

        Returns
        -------
        obs : np.ndarray — corps observation after the macro-step
        reward : float — pass-through from DivisionEnv plus objective bonuses
        terminated : bool
        truncated : bool
        info : dict — includes ``corps_steps``, ``division_action``,
            ``objective_rewards``, plus all fields from DivisionEnv
        """
        corps_action = np.asarray(corps_action, dtype=np.int64)

        if corps_action.shape != (self.n_divisions,):
            raise ValueError(
                f"corps_action has shape {corps_action.shape!r}, "
                f"expected ({self.n_divisions},)."
            )

        for i, cmd in enumerate(corps_action):
            if int(cmd) < 0 or int(cmd) >= self.n_corps_options:
                raise ValueError(
                    f"Invalid corps command {int(cmd)!r} for division {i}; "
                    f"expected integer in [0, {self.n_corps_options - 1}]."
                )

        # Translate corps commands → DivisionEnv action
        division_action = self._translate_corps_action(corps_action)

        # Delegate to DivisionEnv
        _div_obs, base_reward, terminated, truncated, div_info = self._division.step(
            division_action
        )

        # Update objective states
        inner = self._division._brigade._inner
        for obj in self.objectives:
            if obj.obj_type != ObjectiveType.FIX_AND_FLANK:
                obj.update(inner)

        # ── Supply network step ───────────────────────────────────────
        # Collect alive unit positions for both teams
        blue_positions = []
        red_positions = []
        for agent_id, b in inner._battalions.items():
            if agent_id not in inner._alive:
                continue
            if agent_id.startswith("blue_"):
                blue_positions.append((b.x, b.y))
            else:
                red_positions.append((b.x, b.y))

        # Check supply-line interdiction BEFORE step() so that captured depots
        # are already dead when consumption and convoy transfers are computed
        # for this tick — ensuring immediate effect on the same step.
        capture_radius = self._interdiction_radius()
        for bx, by in blue_positions:
            self.supply_network.interdict_nearest_depot(
                bx, by, enemy_team=1, capture_radius=capture_radius
            )

        self.supply_network.step(blue_positions, red_positions)

        # Compute supply levels per Blue division for the info dict
        supply_levels = self._compute_division_supply_levels(inner)

        # Compute operational objective rewards
        obj_reward, obj_details = self._compute_objective_rewards(inner)
        total_reward = float(base_reward) + obj_reward

        self._corps_steps += 1

        # Count alive units for operational casualty tracking.
        blue_alive = sum(
            1 for aid in inner._battalions
            if aid in inner._alive and aid.startswith("blue_")
        )
        red_alive = sum(
            1 for aid in inner._battalions
            if aid in inner._alive and not aid.startswith("blue_")
        )

        info: dict = {
            "corps_steps": self._corps_steps,
            "division_action": division_action.tolist(),
            "objective_rewards": obj_details,
            "supply_levels": supply_levels,
            "blue_units_alive": blue_alive,
            "red_units_alive": red_alive,
        }
        info.update(div_info)

        return self._get_corps_obs(), total_reward, terminated, truncated, info

    # ------------------------------------------------------------------
    # Command translation
    # ------------------------------------------------------------------

    def _translate_corps_action(self, corps_action: np.ndarray) -> np.ndarray:
        """Expand a per-division command into a flat per-brigade division action.

        Division *i* covers brigades
        ``[i*n_brigades_per_division : (i+1)*n_brigades_per_division]``.

        Parameters
        ----------
        corps_action:
            Integer array of shape ``(n_divisions,)`` with command indices.

        Returns
        -------
        np.ndarray of shape ``(n_blue_brigades,)`` — option index per brigade.
        """
        division_action = np.empty(self.n_blue_brigades, dtype=np.int64)
        for i in range(self.n_divisions):
            start = i * self.n_brigades_per_division
            end = start + self.n_brigades_per_division
            division_action[start:end] = int(corps_action[i])
        return division_action

    # ------------------------------------------------------------------
    # Objective rewards
    # ------------------------------------------------------------------

    def _compute_objective_rewards(
        self, inner
    ) -> tuple[float, dict]:
        """Compute per-objective reward bonuses for the current step.

        Always returns a stable dict with all three canonical keys
        (``capture_hex``, ``cut_supply_line``, ``fix_and_flank``),
        defaulting to 0.0 for objective types not in ``self.objectives``.

        Returns
        -------
        total_bonus : float
        details : dict  mapping objective name → reward granted this step
        """
        bonus = 0.0
        details: dict[str, float] = {
            "capture_hex": 0.0,
            "cut_supply_line": 0.0,
            "fix_and_flank": 0.0,
        }

        for obj in self.objectives:
            if obj.obj_type == ObjectiveType.CAPTURE_HEX:
                if obj.is_blue_controlled:
                    details["capture_hex"] = _OBJ_CAPTURE_REWARD
                    bonus += _OBJ_CAPTURE_REWARD

            elif obj.obj_type == ObjectiveType.CUT_SUPPLY_LINE:
                if obj.is_blue_controlled:
                    details["cut_supply_line"] = _OBJ_SUPPLY_CUT_REWARD
                    bonus += _OBJ_SUPPLY_CUT_REWARD

            elif obj.obj_type == ObjectiveType.FIX_AND_FLANK:
                active = _detect_fix_and_flank(inner, self.map_height)
                if active:
                    details["fix_and_flank"] = _OBJ_FIX_FLANK_REWARD
                    bonus += _OBJ_FIX_FLANK_REWARD

        return bonus, details

    # ------------------------------------------------------------------
    # Fog-of-war radius hook
    # ------------------------------------------------------------------

    def _get_fog_radius(self) -> float:
        """Return the effective fog-of-war comm radius for threat vectors.

        Subclasses may override this method to implement cavalry
        reconnaissance or other intelligence assets that extend the
        effective communication range for threat vector building.
        """
        return self.comm_radius

    # ------------------------------------------------------------------
    # Corps observation construction
    # ------------------------------------------------------------------

    def _get_corps_obs(self) -> np.ndarray:
        """Build and return the normalised corps observation vector."""
        parts: list[float] = []
        inner = self._division._brigade._inner

        # ── 1. Corps sector control (7 vertical strips) ───────────────
        sector_width = self.map_width / N_CORPS_SECTORS
        for s in range(N_CORPS_SECTORS):
            x_lo = s * sector_width
            x_hi = (s + 1) * sector_width
            blue_str = 0.0
            red_str = 0.0
            for agent_id, b in inner._battalions.items():
                if agent_id not in inner._alive:
                    continue
                in_sector = (x_lo <= b.x < x_hi) or (
                    s == N_CORPS_SECTORS - 1 and b.x == self.map_width
                )
                if in_sector:
                    if agent_id.startswith("blue_"):
                        blue_str += float(b.strength)
                    else:
                        red_str += float(b.strength)
            total = blue_str + red_str
            parts.append(blue_str / total if total > 0.0 else 0.5)

        # ── 2. Per-division status [avg_strength, avg_morale, alive_ratio] ──
        for i in range(self.n_divisions):
            strengths: list[float] = []
            morales: list[float] = []
            alive_count = 0
            total_per_div = self.n_brigades_per_division * self.n_blue_per_brigade
            for j in range(total_per_div):
                agent_id = f"blue_{i * total_per_div + j}"
                if agent_id in inner._battalions and agent_id in inner._alive:
                    b = inner._battalions[agent_id]
                    strengths.append(float(b.strength))
                    morales.append(float(b.morale))
                    alive_count += 1
            avg_str = float(np.mean(strengths)) if strengths else 0.0
            avg_mor = float(np.mean(morales)) if morales else 0.0
            alive_ratio = alive_count / total_per_div
            parts.extend([avg_str, avg_mor, alive_ratio])

        # ── 3. Per-division threat vector ──────────────────────────────
        red_div_centroids = self._get_red_division_centroids(inner)
        blue_div_centroids = self._get_blue_division_centroids(inner)

        for i in range(self.n_divisions):
            cx, cy = blue_div_centroids[i] if blue_div_centroids[i] else (None, None)

            if cx is None or not red_div_centroids:
                # This division is dead or no Red divisions alive — sentinel
                parts.extend([1.0, 0.0, 0.0, 0.0, 0.0])
                continue

            # Find nearest Red division centroid
            best_dist = float("inf")
            best_centroid = None
            best_e_str = 0.0
            best_e_mor = 0.0
            for (rx, ry, e_str, e_mor) in red_div_centroids:
                dx = rx - cx
                dy = ry - cy
                d = math.sqrt(dx * dx + dy * dy)
                if d < best_dist:
                    best_dist = d
                    best_centroid = (rx, ry)
                    best_e_str = e_str
                    best_e_mor = e_mor

            assert best_centroid is not None

            # Apply comm_radius gating (subclasses may override via _get_fog_radius)
            if best_dist > self._get_fog_radius():
                # Enemy beyond communication range — sentinel
                parts.extend([1.0, 0.0, 0.0, 0.0, 0.0])
                continue

            dx = best_centroid[0] - cx
            dy = best_centroid[1] - cy
            bearing = math.atan2(dy, dx)
            parts.append(min(best_dist / self.map_diagonal, 1.0))
            parts.append(math.cos(bearing))
            parts.append(math.sin(bearing))
            parts.append(best_e_str)
            parts.append(best_e_mor)

        # ── 4. Road usage ─────────────────────────────────────────────
        blue_positions: list[tuple[float, float]] = []
        red_positions: list[tuple[float, float]] = []
        for agent_id, b in inner._battalions.items():
            if agent_id not in inner._alive:
                continue
            if agent_id.startswith("blue_"):
                blue_positions.append((b.x, b.y))
            else:
                red_positions.append((b.x, b.y))
        parts.append(self.road_network.fraction_on_road(blue_positions))
        parts.append(self.road_network.fraction_on_road(red_positions))

        # ── 5. Objective control ──────────────────────────────────────
        # Always emit exactly N_OBJECTIVES (3) slots for a stable obs schema.
        # Aggregate by type: last objective of each type wins; missing → 0.5/0.0.
        obj_values: dict[ObjectiveType, float] = {}
        fix_flank_active = _detect_fix_and_flank(inner, self.map_height)
        for obj in self.objectives:
            if obj.obj_type == ObjectiveType.FIX_AND_FLANK:
                obj_values[ObjectiveType.FIX_AND_FLANK] = 1.0 if fix_flank_active else 0.0
            else:
                obj_values[obj.obj_type] = float(obj.control_value)

        parts.append(obj_values.get(ObjectiveType.CAPTURE_HEX, 0.5))
        parts.append(obj_values.get(ObjectiveType.CUT_SUPPLY_LINE, 0.5))
        parts.append(obj_values.get(ObjectiveType.FIX_AND_FLANK, 0.0))

        # ── 6. Supply level per Blue division ─────────────────────────
        supply_levels = self._compute_division_supply_levels(inner)
        parts.extend(supply_levels)

        # ── 7. Step progress ──────────────────────────────────────────
        parts.append(min(inner._step_count / self.max_steps, 1.0))

        obs = np.array(parts, dtype=np.float32)
        return np.clip(obs, self._corps_obs_low, self._corps_obs_high)

    # ------------------------------------------------------------------
    # Division centroid helpers
    # ------------------------------------------------------------------

    def _get_blue_division_centroids(
        self, inner
    ) -> list[Optional[tuple[float, float]]]:
        """Return ``(cx, cy)`` centroid for each Blue division, or ``None`` if dead."""
        centroids: list[Optional[tuple[float, float]]] = []
        total_per_div = self.n_brigades_per_division * self.n_blue_per_brigade
        for i in range(self.n_divisions):
            xs: list[float] = []
            ys: list[float] = []
            for j in range(total_per_div):
                agent_id = f"blue_{i * total_per_div + j}"
                if agent_id in inner._battalions and agent_id in inner._alive:
                    b = inner._battalions[agent_id]
                    xs.append(b.x)
                    ys.append(b.y)
            if xs:
                centroids.append((float(np.mean(xs)), float(np.mean(ys))))
            else:
                centroids.append(None)
        return centroids

    def _get_red_division_centroids(
        self, inner
    ) -> list[tuple[float, float, float, float]]:
        """Return ``(cx, cy, avg_strength, avg_morale)`` for each alive Red division."""
        centroids = []
        total_per_div = self.n_red_brigades_per_division * self.n_red_per_brigade
        for i in range(self.n_red_divisions):
            xs: list[float] = []
            ys: list[float] = []
            strs: list[float] = []
            mors: list[float] = []
            for j in range(total_per_div):
                agent_id = f"red_{i * total_per_div + j}"
                if agent_id in inner._battalions and agent_id in inner._alive:
                    b = inner._battalions[agent_id]
                    xs.append(b.x)
                    ys.append(b.y)
                    strs.append(float(b.strength))
                    mors.append(float(b.morale))
            if xs:
                centroids.append((
                    float(np.mean(xs)),
                    float(np.mean(ys)),
                    float(np.mean(strs)),
                    float(np.mean(mors)),
                ))
        return centroids

    # ------------------------------------------------------------------
    # Gymnasium API: close / render
    # ------------------------------------------------------------------

    def close(self) -> None:
        """Clean up resources."""
        self._division.close()

    def render(self) -> None:
        """Rendering is not yet implemented."""

close()

Clean up resources.

Source code in envs/corps_env.py
def close(self) -> None:
    """Clean up resources."""
    self._division.close()

render()

Rendering is not yet implemented.

Source code in envs/corps_env.py
def render(self) -> None:
    """Rendering is not yet implemented."""

reset(seed=None, options=None)

Reset the environment and return the initial corps observation.

Parameters:

Name Type Description Default
seed Optional[int]

RNG seed forwarded to the inner :class:~envs.division_env.DivisionEnv.

None
options Optional[dict]

Unused; present for Gymnasium API compatibility.

None
Source code in envs/corps_env.py
def reset(
    self,
    seed: Optional[int] = None,
    options: Optional[dict] = None,
) -> tuple[np.ndarray, dict]:
    """Reset the environment and return the initial corps observation.

    Parameters
    ----------
    seed:
        RNG seed forwarded to the inner :class:`~envs.division_env.DivisionEnv`.
    options:
        Unused; present for Gymnasium API compatibility.
    """
    if seed is not None:
        super().reset(seed=seed)

    self._division.reset(seed=seed, options=options)
    # Re-attach road network after inner reset (terrain is regenerated)
    self._set_inner_road_network()

    # Reset objective state
    for obj in self.objectives:
        obj.reset()

    # Reset supply network
    self.supply_network.reset()

    self._corps_steps = 0
    return self._get_corps_obs(), {}
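The reset ordering above matters: the inner env is reset first (which regenerates terrain), the road network is re-attached afterwards, and only then are objectives, the supply network, and the step counter cleared. A minimal sketch of this sequencing with stub components (all class names here are hypothetical stand-ins, not part of the package):

```python
class Stub:
    """Records reset calls; stands in for the inner env, an objective, or the supply net."""
    def __init__(self):
        self.reset_count = 0
        self.last_seed = None

    def reset(self, seed=None, **kwargs):
        self.reset_count += 1
        self.last_seed = seed

class MiniCorps:
    def __init__(self):
        self.inner = Stub()
        self.objectives = [Stub(), Stub()]
        self.supply_network = Stub()
        self.corps_steps = 0

    def reset(self, seed=None):
        self.inner.reset(seed=seed)    # 1. inner env first (terrain regenerated here)
        for obj in self.objectives:    # 2. then per-objective state
            obj.reset()
        self.supply_network.reset()    # 3. then the supply network
        self.corps_steps = 0           # 4. finally the macro-step counter
        return [0.0], {}
```

Resetting the supply network before the inner env would rebuild it against stale terrain, which is why the real `reset` re-attaches the road network only after the inner reset completes.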

step(corps_action)

Execute one corps macro-step.

Translates the per-division operational command into a division-level action and delegates to the inner :class:~envs.division_env.DivisionEnv.

Parameters:

Name Type Description Default
corps_action ndarray

Integer array of shape (n_divisions,) with operational command indices in [0, n_corps_options).

required

Returns:

Name Type Description
obs np.ndarray — corps observation after the macro-step
reward float — pass-through from DivisionEnv plus objective bonuses
terminated bool
truncated bool
info dict — includes ``corps_steps``, ``division_action``, ``objective_rewards``, plus all fields from DivisionEnv

Source code in envs/corps_env.py
def step(
    self,
    corps_action: np.ndarray,
) -> tuple[np.ndarray, float, bool, bool, dict]:
    """Execute one corps macro-step.

    Translates the per-division operational command into a division-level
    action and delegates to the inner :class:`~envs.division_env.DivisionEnv`.

    Parameters
    ----------
    corps_action:
        Integer array of shape ``(n_divisions,)`` with operational
        command indices in ``[0, n_corps_options)``.

    Returns
    -------
    obs : np.ndarray — corps observation after the macro-step
    reward : float — pass-through from DivisionEnv plus objective bonuses
    terminated : bool
    truncated : bool
    info : dict — includes ``corps_steps``, ``division_action``,
        ``objective_rewards``, plus all fields from DivisionEnv
    """
    corps_action = np.asarray(corps_action, dtype=np.int64)

    if corps_action.shape != (self.n_divisions,):
        raise ValueError(
            f"corps_action has shape {corps_action.shape!r}, "
            f"expected ({self.n_divisions},)."
        )

    for i, cmd in enumerate(corps_action):
        if int(cmd) < 0 or int(cmd) >= self.n_corps_options:
            raise ValueError(
                f"Invalid corps command {int(cmd)!r} for division {i}; "
                f"expected integer in [0, {self.n_corps_options - 1}]."
            )

    # Translate corps commands → DivisionEnv action
    division_action = self._translate_corps_action(corps_action)

    # Delegate to DivisionEnv
    _div_obs, base_reward, terminated, truncated, div_info = self._division.step(
        division_action
    )

    # Update objective states
    inner = self._division._brigade._inner
    for obj in self.objectives:
        if obj.obj_type != ObjectiveType.FIX_AND_FLANK:
            obj.update(inner)

    # ── Supply network step ───────────────────────────────────────
    # Collect alive unit positions for both teams
    blue_positions = []
    red_positions = []
    for agent_id, b in inner._battalions.items():
        if agent_id not in inner._alive:
            continue
        if agent_id.startswith("blue_"):
            blue_positions.append((b.x, b.y))
        else:
            red_positions.append((b.x, b.y))

    # Check supply-line interdiction BEFORE step() so that captured depots
    # are already dead when consumption and convoy transfers are computed
    # for this tick — ensuring immediate effect on the same step.
    capture_radius = self._interdiction_radius()
    for bx, by in blue_positions:
        self.supply_network.interdict_nearest_depot(
            bx, by, enemy_team=1, capture_radius=capture_radius
        )

    self.supply_network.step(blue_positions, red_positions)

    # Compute supply levels per Blue division for the info dict
    supply_levels = self._compute_division_supply_levels(inner)

    # Compute operational objective rewards
    obj_reward, obj_details = self._compute_objective_rewards(inner)
    total_reward = float(base_reward) + obj_reward

    self._corps_steps += 1

    # Count alive units for operational casualty tracking.
    blue_alive = sum(
        1 for aid in inner._battalions
        if aid in inner._alive and aid.startswith("blue_")
    )
    red_alive = sum(
        1 for aid in inner._battalions
        if aid in inner._alive and not aid.startswith("blue_")
    )

    info: dict = {
        "corps_steps": self._corps_steps,
        "division_action": division_action.tolist(),
        "objective_rewards": obj_details,
        "supply_levels": supply_levels,
        "blue_units_alive": blue_alive,
        "red_units_alive": red_alive,
    }
    info.update(div_info)

    return self._get_corps_obs(), total_reward, terminated, truncated, info
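The shape and range checks at the top of `step` generalise to any `MultiDiscrete`-style action. A vectorised sketch of the same validation (the `validate_multidiscrete` helper name is hypothetical):

```python
import numpy as np

def validate_multidiscrete(action, n_slots, n_options):
    """Coerce to int64 and check shape and per-element range.

    Raises ValueError on a wrong shape or any element outside
    [0, n_options), mirroring the checks in CorpsEnv.step.
    """
    a = np.asarray(action, dtype=np.int64)
    if a.shape != (n_slots,):
        raise ValueError(
            f"action has shape {a.shape!r}, expected ({n_slots},)."
        )
    if ((a < 0) | (a >= n_options)).any():
        raise ValueError(
            f"action elements must be integers in [0, {n_options - 1}]."
        )
    return a
```

The vectorised range check replaces the per-element Python loop in the source above without changing which inputs are accepted.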

envs.cavalry_corps_env.CavalryCorpsEnv

Bases: CorpsEnv

Gymnasium environment for a corps commander with an independent cavalry arm.

Extends :class:~envs.corps_env.CorpsEnv with cavalry brigades that execute reconnaissance, raiding, and pursuit missions each step. Cavalry intelligence reduces the fog-of-war for allied divisions when units are on RECONNAISSANCE mission.

Parameters:

Name Type Description Default
n_divisions int

Number of Blue infantry divisions.

3
n_brigades_per_division int

Blue infantry brigades per division.

3
n_blue_per_brigade int

Blue battalions per brigade.

4
n_red_divisions Optional[int]

Red force composition (mirrors Blue by default).

None
n_red_brigades_per_division Optional[int]

Red force composition (mirrors Blue by default).

None
n_red_per_brigade Optional[int]

Red force composition (mirrors Blue by default).

None
map_width float

Map dimensions in metres.

CORPS_MAP_WIDTH
map_height float

Map dimensions in metres.

CORPS_MAP_HEIGHT
max_steps int

Episode length cap.

MAX_STEPS
road_network Optional[RoadNetwork]

Optional overrides passed through to :class:~envs.corps_env.CorpsEnv.

None
supply_network Optional[SupplyNetwork]

Optional overrides passed through to :class:~envs.corps_env.CorpsEnv.

None
objectives

Optional overrides passed through to :class:~envs.corps_env.CorpsEnv.

None
comm_radius float

Base inter-division communication radius. Overridden to ∞ when cavalry recon reveals enemy positions.

3000.0
red_random bool

When True Red takes random brigade actions.

True
randomize_terrain bool

Pass-through to inner env.

True
visibility_radius float

Fog-of-war visibility radius for the inner simulation.

1500.0
render_mode Optional[str]

None or "human".

None
n_cavalry_brigades int

Number of Blue cavalry brigades (default 2).

2
cavalry_corps Optional[CavalryCorps]

Optional pre-built :class:~envs.sim.cavalry_corps.CavalryCorps. If None, a default corps is generated at construction time.

None
cav_config Optional[CavalryUnitConfig]

Optional :class:~envs.sim.cavalry_corps.CavalryUnitConfig used when generating the default cavalry corps.

None
Source code in envs/cavalry_corps_env.py
class CavalryCorpsEnv(CorpsEnv):
    """Gymnasium environment for a corps commander with an independent cavalry arm.

    Extends :class:`~envs.corps_env.CorpsEnv` with cavalry brigades that
    execute reconnaissance, raiding, and pursuit missions each step.
    Cavalry intelligence reduces the fog-of-war for allied divisions
    when units are on RECONNAISSANCE mission.

    Parameters
    ----------
    n_divisions:
        Number of Blue infantry divisions.
    n_brigades_per_division:
        Blue infantry brigades per division.
    n_blue_per_brigade:
        Blue battalions per brigade.
    n_red_divisions, n_red_brigades_per_division, n_red_per_brigade:
        Red force composition (mirrors Blue by default).
    map_width, map_height:
        Map dimensions in metres.
    max_steps:
        Episode length cap.
    road_network, supply_network, objectives:
        Optional overrides passed through to :class:`~envs.corps_env.CorpsEnv`.
    comm_radius:
        Base inter-division communication radius.  Overridden to ``∞``
        when cavalry recon reveals enemy positions.
    red_random:
        When ``True`` Red takes random brigade actions.
    randomize_terrain:
        Pass-through to inner env.
    visibility_radius:
        Fog-of-war visibility radius for the inner simulation.
    render_mode:
        ``None`` or ``"human"``.
    n_cavalry_brigades:
        Number of Blue cavalry brigades (default ``2``).
    cavalry_corps:
        Optional pre-built :class:`~envs.sim.cavalry_corps.CavalryCorps`.
        If ``None``, a default corps is generated at construction time.
    cav_config:
        Optional :class:`~envs.sim.cavalry_corps.CavalryUnitConfig` used
        when generating the default cavalry corps.
    """

    metadata: dict = {"render_modes": ["human"], "name": "cavalry_corps_v0"}

    def __init__(
        self,
        n_divisions: int = 3,
        n_brigades_per_division: int = 3,
        n_blue_per_brigade: int = 4,
        n_red_divisions: Optional[int] = None,
        n_red_brigades_per_division: Optional[int] = None,
        n_red_per_brigade: Optional[int] = None,
        map_width: float = CORPS_MAP_WIDTH,
        map_height: float = CORPS_MAP_HEIGHT,
        max_steps: int = MAX_STEPS,
        road_network: Optional[RoadNetwork] = None,
        supply_network: Optional[SupplyNetwork] = None,
        objectives=None,
        comm_radius: float = 3_000.0,
        red_random: bool = True,
        randomize_terrain: bool = True,
        visibility_radius: float = 1_500.0,
        render_mode: Optional[str] = None,
        n_cavalry_brigades: int = 2,
        cavalry_corps: Optional[CavalryCorps] = None,
        cav_config: Optional[CavalryUnitConfig] = None,
    ) -> None:
        # ── Build base CorpsEnv ──────────────────────────────────────
        super().__init__(
            n_divisions=n_divisions,
            n_brigades_per_division=n_brigades_per_division,
            n_blue_per_brigade=n_blue_per_brigade,
            n_red_divisions=n_red_divisions,
            n_red_brigades_per_division=n_red_brigades_per_division,
            n_red_per_brigade=n_red_per_brigade,
            map_width=map_width,
            map_height=map_height,
            max_steps=max_steps,
            road_network=road_network,
            supply_network=supply_network,
            objectives=objectives,
            comm_radius=comm_radius,
            red_random=red_random,
            randomize_terrain=randomize_terrain,
            visibility_radius=visibility_radius,
            render_mode=render_mode,
        )

        # ── Cavalry configuration ────────────────────────────────────
        if cavalry_corps is not None:
            # Derive n_cavalry_brigades from the provided corps to ensure
            # the action/observation spaces are consistent with the sim state.
            try:
                n_cav_from_corps = len(cavalry_corps.units)
            except AttributeError as exc:
                raise ValueError(
                    "Provided cavalry_corps must be a valid CavalryCorps instance "
                    "with a 'units' attribute."
                ) from exc
            if n_cav_from_corps < 1:
                raise ValueError(
                    "Provided cavalry_corps must contain at least one cavalry unit."
                )
            self.n_cavalry_brigades: int = int(n_cav_from_corps)
            self._cavalry: CavalryCorps = cavalry_corps
        else:
            if int(n_cavalry_brigades) < 1:
                raise ValueError(
                    f"n_cavalry_brigades must be >= 1, got {n_cavalry_brigades!r}"
                )
            self.n_cavalry_brigades: int = int(n_cavalry_brigades)
            self._cavalry: CavalryCorps = CavalryCorps.generate_default(
                map_width=self.map_width,
                map_height=self.map_height,
                n_brigades=self.n_cavalry_brigades,
                team=0,
                config=cav_config,
            )
        self._last_cav_report: CavalryReport = CavalryReport([], 0, 0, 0.0)

        # ── Override action space ────────────────────────────────────
        # Corps commands (n_divisions) + cavalry missions (n_cavalry_brigades)
        self.action_space = spaces.MultiDiscrete(
            [self.n_corps_options] * self.n_divisions
            + [N_CAVALRY_MISSIONS] * self.n_cavalry_brigades,
            dtype=np.int64,
        )

        # ── Override observation space ───────────────────────────────
        obs_low, obs_high = self._build_cav_obs_bounds()
        self.observation_space = spaces.Box(
            low=obs_low, high=obs_high, dtype=np.float32
        )

    # ------------------------------------------------------------------
    # Fog-of-war hook (overrides CorpsEnv._get_fog_radius)
    # ------------------------------------------------------------------

    def _get_fog_radius(self) -> float:
        """Lift the comm_radius gating when cavalry recon has revealed enemies.

        When at least one enemy unit has been spotted by a RECONNAISSANCE
        cavalry brigade, the effective fog radius is set to ``math.inf`` —
        all allied divisions receive accurate threat vectors rather than
        sentinels.  When no enemies are revealed, the base ``comm_radius``
        is used.
        """
        if self._last_cav_report.revealed_enemy_positions:
            return math.inf
        return self.comm_radius

    # ------------------------------------------------------------------
    # Gymnasium API: reset
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> tuple[np.ndarray, dict]:
        """Reset the environment and return the initial cavalry corps observation.

        Calls :meth:`~envs.corps_env.CorpsEnv.reset` on the base env,
        resets cavalry unit positions, and returns the extended observation.
        """
        # Reset base corps env (also resets inner env, road/supply networks)
        _, info = super().reset(seed=seed, options=options)

        # Reset cavalry positions to default spread
        for i, unit in enumerate(self._cavalry.units[: self.n_cavalry_brigades]):
            unit.x = self.map_width * 0.2
            unit.y = self.map_height * (i + 1) / (self.n_cavalry_brigades + 1)
            unit.theta = 0.0
            unit.strength = 1.0
            unit.alive = True
            unit.mission = CavalryMission.IDLE

        self._last_cav_report = CavalryReport([], 0, 0, 0.0)
        return self._build_cav_obs(), info

    # ------------------------------------------------------------------
    # Gymnasium API: step
    # ------------------------------------------------------------------

    def step(
        self, action: np.ndarray
    ) -> tuple[np.ndarray, float, bool, bool, dict]:
        """Execute one cavalry-corps macro-step.

        Splits the combined action into corps commands and cavalry mission
        assignments.  Infantry and supply logic executes first (via the
        base :class:`~envs.corps_env.CorpsEnv`); cavalry then acts on the
        updated battlefield state.

        Parameters
        ----------
        action:
            Integer array of shape
            ``(n_divisions + n_cavalry_brigades,)``.
            Elements ``[:n_divisions]`` are standard corps operational
            commands.  Elements ``[n_divisions:]`` assign a
            :class:`~envs.sim.cavalry_corps.CavalryMission` to each
            cavalry brigade.

        Returns
        -------
        obs : np.ndarray — extended cavalry corps observation
        reward : float — base corps reward (cavalry adds no extra reward)
        terminated : bool
        truncated : bool
        info : dict — base corps info plus ``"cavalry"`` sub-dict
        """
        action = np.asarray(action, dtype=np.int64)
        expected_len = self.n_divisions + self.n_cavalry_brigades
        if action.shape != (expected_len,):
            raise ValueError(
                f"action has shape {action.shape!r}, "
                f"expected ({expected_len},)."
            )

        # Split action
        corps_action = action[: self.n_divisions]
        cav_action = action[self.n_divisions :]

        # Assign cavalry missions for this step
        for i, unit in enumerate(self._cavalry.units[: self.n_cavalry_brigades]):
            if unit.alive:
                unit.mission = CavalryMission(int(cav_action[i]))

        # ── Execute base corps step (infantry + supply) ───────────────
        _, base_reward, terminated, truncated, info = super().step(corps_action)

        # ── Execute cavalry step on updated battlefield state ─────────
        inner = self._division._brigade._inner
        self._last_cav_report = self._cavalry.step(inner, self.supply_network)

        # ── Augment info with cavalry report ─────────────────────────
        info["cavalry"] = {
            "depots_raided": self._last_cav_report.depots_raided,
            "routed_units_pursued": self._last_cav_report.routed_units_pursued,
            "pursuit_damage": self._last_cav_report.pursuit_damage,
            "n_revealed_enemies": len(
                self._last_cav_report.revealed_enemy_positions
            ),
        }

        return self._build_cav_obs(), base_reward, terminated, truncated, info

    # ------------------------------------------------------------------
    # Observation construction
    # ------------------------------------------------------------------

    def _build_cav_obs(self) -> np.ndarray:
        """Build and return the full cavalry corps observation vector.

        Concatenates:
        1. Base corps observation (with cavalry-enhanced fog of war).
        2. Per-cavalry-brigade state.
        3. Cavalry intelligence summary.
        """
        # Base corps observation uses _get_fog_radius() via _get_corps_obs()
        base_obs: np.ndarray = self._get_corps_obs()

        # ── Per-cavalry-brigade state ─────────────────────────────────
        cav_parts: List[float] = []
        for unit in self._cavalry.units[: self.n_cavalry_brigades]:
            if unit.alive:
                cav_parts.extend(
                    [
                        unit.x / self.map_width,
                        unit.y / self.map_height,
                        float(int(unit.mission)) / (N_CAVALRY_MISSIONS - 1),
                        unit.strength,
                    ]
                )
            else:
                cav_parts.extend([0.0, 0.0, 0.0, 0.0])

        # ── Cavalry intelligence summary ──────────────────────────────
        revealed = self._last_cav_report.revealed_enemy_positions
        n_total = max(
            1,
            self.n_red_divisions
            * self.n_red_brigades_per_division
            * self.n_red_per_brigade,
        )
        n_revealed_norm = min(1.0, len(revealed) / n_total)

        if revealed:
            rx_norm = (
                float(np.mean([p[0] for p in revealed])) / self.map_width
            )
            ry_norm = (
                float(np.mean([p[1] for p in revealed])) / self.map_height
            )
        else:
            rx_norm, ry_norm = 0.0, 0.0

        cav_parts.extend([n_revealed_norm, rx_norm, ry_norm])

        cav_arr = np.array(cav_parts, dtype=np.float32)
        obs = np.concatenate([base_obs, cav_arr])
        return np.clip(obs, self.observation_space.low, self.observation_space.high)

    # ------------------------------------------------------------------
    # Observation bounds
    # ------------------------------------------------------------------

    def _build_cav_obs_bounds(
        self,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return ``(low, high)`` bounds for the full cavalry corps observation."""
        base_low, base_high = self._build_obs_bounds()

        # Extra dims: all in [0, 1]
        n_extra = self.n_cavalry_brigades * _CAV_UNIT_OBS_DIM + _CAV_INTEL_DIM
        extra_low = np.zeros(n_extra, dtype=np.float32)
        extra_high = np.ones(n_extra, dtype=np.float32)

        return (
            np.concatenate([base_low, extra_low]),
            np.concatenate([base_high, extra_high]),
        )

reset(seed=None, options=None)

Reset the environment and return the initial cavalry corps observation.

Calls :meth:~envs.corps_env.CorpsEnv.reset on the base env, resets cavalry unit positions, and returns the extended observation.

Source code in envs/cavalry_corps_env.py
def reset(
    self,
    seed: Optional[int] = None,
    options: Optional[dict] = None,
) -> tuple[np.ndarray, dict]:
    """Reset the environment and return the initial cavalry corps observation.

    Calls :meth:`~envs.corps_env.CorpsEnv.reset` on the base env,
    resets cavalry unit positions, and returns the extended observation.
    """
    # Reset base corps env (also resets inner env, road/supply networks)
    _, info = super().reset(seed=seed, options=options)

    # Reset cavalry positions to default spread
    for i, unit in enumerate(self._cavalry.units[: self.n_cavalry_brigades]):
        unit.x = self.map_width * 0.2
        unit.y = self.map_height * (i + 1) / (self.n_cavalry_brigades + 1)
        unit.theta = 0.0
        unit.strength = 1.0
        unit.alive = True
        unit.mission = CavalryMission.IDLE

    self._last_cav_report = CavalryReport([], 0, 0, 0.0)
    return self._build_cav_obs(), info
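The reset above places every cavalry brigade at 20 % of the map width and spaces them evenly along the y-axis, excluding the map edges. The placement rule in isolation (the function name is a hypothetical illustration, not package API):

```python
def cavalry_spawn_positions(n_brigades, map_width, map_height):
    """Fixed x at 20% of map width; y evenly spaced, edges excluded.

    Brigade i gets y = map_height * (i + 1) / (n_brigades + 1), so
    units never spawn on the top or bottom map boundary.
    """
    return [
        (map_width * 0.2, map_height * (i + 1) / (n_brigades + 1))
        for i in range(n_brigades)
    ]
```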

step(action)

Execute one cavalry-corps macro-step.

Splits the combined action into corps commands and cavalry mission assignments. Infantry and supply logic executes first (via the base :class:~envs.corps_env.CorpsEnv); cavalry then acts on the updated battlefield state.

Parameters:

Name Type Description Default
action ndarray

Integer array of shape (n_divisions + n_cavalry_brigades,). Elements [:n_divisions] are standard corps operational commands. Elements [n_divisions:] assign a :class:~envs.sim.cavalry_corps.CavalryMission to each cavalry brigade.

required

Returns:

Name Type Description
obs np.ndarray — extended cavalry corps observation
reward float — base corps reward (cavalry adds no extra reward)
terminated bool
truncated bool
info dict — base corps info plus ``"cavalry"`` sub-dict
Source code in envs/cavalry_corps_env.py
def step(
    self, action: np.ndarray
) -> tuple[np.ndarray, float, bool, bool, dict]:
    """Execute one cavalry-corps macro-step.

    Splits the combined action into corps commands and cavalry mission
    assignments.  Infantry and supply logic executes first (via the
    base :class:`~envs.corps_env.CorpsEnv`); cavalry then acts on the
    updated battlefield state.

    Parameters
    ----------
    action:
        Integer array of shape
        ``(n_divisions + n_cavalry_brigades,)``.
        Elements ``[:n_divisions]`` are standard corps operational
        commands.  Elements ``[n_divisions:]`` assign a
        :class:`~envs.sim.cavalry_corps.CavalryMission` to each
        cavalry brigade.

    Returns
    -------
    obs : np.ndarray — extended cavalry corps observation
    reward : float — base corps reward (cavalry adds no extra reward)
    terminated : bool
    truncated : bool
    info : dict — base corps info plus ``"cavalry"`` sub-dict
    """
    action = np.asarray(action, dtype=np.int64)
    expected_len = self.n_divisions + self.n_cavalry_brigades
    if action.shape != (expected_len,):
        raise ValueError(
            f"action has shape {action.shape!r}, "
            f"expected ({expected_len},)."
        )

    # Split action
    corps_action = action[: self.n_divisions]
    cav_action = action[self.n_divisions :]

    # Assign cavalry missions for this step
    for i, unit in enumerate(self._cavalry.units[: self.n_cavalry_brigades]):
        if unit.alive:
            unit.mission = CavalryMission(int(cav_action[i]))

    # ── Execute base corps step (infantry + supply) ───────────────
    _, base_reward, terminated, truncated, info = super().step(corps_action)

    # ── Execute cavalry step on updated battlefield state ─────────
    inner = self._division._brigade._inner
    self._last_cav_report = self._cavalry.step(inner, self.supply_network)

    # ── Augment info with cavalry report ─────────────────────────
    info["cavalry"] = {
        "depots_raided": self._last_cav_report.depots_raided,
        "routed_units_pursued": self._last_cav_report.routed_units_pursued,
        "pursuit_damage": self._last_cav_report.pursuit_damage,
        "n_revealed_enemies": len(
            self._last_cav_report.revealed_enemy_positions
        ),
    }

    return self._build_cav_obs(), base_reward, terminated, truncated, info
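The head/tail split of the combined action can be shown standalone. The `Mission` enum below is a hypothetical stand-in for `CavalryMission` (the real enum's members and count come from the package):

```python
from enum import IntEnum
import numpy as np

class Mission(IntEnum):  # hypothetical stand-in for CavalryMission
    IDLE = 0
    RECONNAISSANCE = 1
    RAID = 2

def split_action(action, n_divisions):
    """Head → corps operational commands; tail → one mission per cavalry brigade."""
    a = np.asarray(action, dtype=np.int64)
    corps_action = a[:n_divisions]
    missions = [Mission(int(m)) for m in a[n_divisions:]]
    return corps_action, missions
```

In the env itself the tail assignments are only applied to brigades that are still alive, so a mission index for a destroyed brigade is silently ignored.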

envs.artillery_corps_env.ArtilleryCorpsEnv

Bases: CorpsEnv

Gymnasium environment for a corps commander with an independent artillery arm.

Extends :class:~envs.corps_env.CorpsEnv with artillery batteries that execute grand battery, counter-battery, siege, and fortification missions each step.

Parameters:

Name Type Description Default
n_divisions int

Number of Blue infantry divisions.

3
n_brigades_per_division int

Blue infantry brigades per division.

3
n_blue_per_brigade int

Blue battalions per brigade.

4
n_red_divisions Optional[int]

Red force composition (mirrors Blue by default).

None
n_red_brigades_per_division Optional[int]

Red force composition (mirrors Blue by default).

None
n_red_per_brigade Optional[int]

Red force composition (mirrors Blue by default).

None
map_width float

Map dimensions in metres.

CORPS_MAP_WIDTH
map_height float

Map dimensions in metres.

CORPS_MAP_HEIGHT
max_steps int

Episode length cap.

MAX_STEPS
road_network Optional[RoadNetwork]

Optional overrides passed through to :class:~envs.corps_env.CorpsEnv.

None
supply_network Optional[SupplyNetwork]

Optional overrides passed through to :class:~envs.corps_env.CorpsEnv.

None
objectives

Optional overrides passed through to :class:~envs.corps_env.CorpsEnv.

None
comm_radius float

Base inter-division communication radius.

3000.0
red_random bool

When True Red takes random brigade actions.

True
randomize_terrain bool

Pass-through to inner env.

True
visibility_radius float

Fog-of-war visibility radius for the inner simulation.

1500.0
render_mode Optional[str]

None or "human".

None
n_artillery_batteries int

Number of Blue artillery batteries (default 4).

4
artillery_corps Optional[ArtilleryCorps]

Optional pre-built :class:~envs.sim.artillery_corps.ArtilleryCorps. If None, a default corps is generated at construction time.

None
art_config Optional[ArtilleryUnitConfig]

Optional :class:~envs.sim.artillery_corps.ArtilleryUnitConfig used when generating the default artillery corps.

None
n_red_artillery_batteries Optional[int]

Number of Red artillery batteries for counter-battery targeting (default equals n_artillery_batteries).

None
Source code in envs/artillery_corps_env.py
class ArtilleryCorpsEnv(CorpsEnv):
    """Gymnasium environment for a corps commander with an independent artillery arm.

    Extends :class:`~envs.corps_env.CorpsEnv` with artillery batteries that
    execute grand battery, counter-battery, siege, and fortification missions
    each step.

    Parameters
    ----------
    n_divisions:
        Number of Blue infantry divisions.
    n_brigades_per_division:
        Blue infantry brigades per division.
    n_blue_per_brigade:
        Blue battalions per brigade.
    n_red_divisions, n_red_brigades_per_division, n_red_per_brigade:
        Red force composition (mirrors Blue by default).
    map_width, map_height:
        Map dimensions in metres.
    max_steps:
        Episode length cap.
    road_network, supply_network, objectives:
        Optional overrides passed through to :class:`~envs.corps_env.CorpsEnv`.
    comm_radius:
        Base inter-division communication radius.
    red_random:
        When ``True`` Red takes random brigade actions.
    randomize_terrain:
        Pass-through to inner env.
    visibility_radius:
        Fog-of-war visibility radius for the inner simulation.
    render_mode:
        ``None`` or ``"human"``.
    n_artillery_batteries:
        Number of Blue artillery batteries (default ``4``).
    artillery_corps:
        Optional pre-built :class:`~envs.sim.artillery_corps.ArtilleryCorps`.
        If ``None``, a default corps is generated at construction time.
    art_config:
        Optional :class:`~envs.sim.artillery_corps.ArtilleryUnitConfig` used
        when generating the default artillery corps.
    n_red_artillery_batteries:
        Number of Red artillery batteries for counter-battery targeting
        (default equals *n_artillery_batteries*).
    """

    metadata: dict = {"render_modes": ["human"], "name": "artillery_corps_v0"}

    def __init__(
        self,
        n_divisions: int = 3,
        n_brigades_per_division: int = 3,
        n_blue_per_brigade: int = 4,
        n_red_divisions: Optional[int] = None,
        n_red_brigades_per_division: Optional[int] = None,
        n_red_per_brigade: Optional[int] = None,
        map_width: float = CORPS_MAP_WIDTH,
        map_height: float = CORPS_MAP_HEIGHT,
        max_steps: int = MAX_STEPS,
        road_network: Optional[RoadNetwork] = None,
        supply_network: Optional[SupplyNetwork] = None,
        objectives=None,
        comm_radius: float = 3_000.0,
        red_random: bool = True,
        randomize_terrain: bool = True,
        visibility_radius: float = 1_500.0,
        render_mode: Optional[str] = None,
        n_artillery_batteries: int = 4,
        artillery_corps: Optional[ArtilleryCorps] = None,
        art_config: Optional[ArtilleryUnitConfig] = None,
        n_red_artillery_batteries: Optional[int] = None,
    ) -> None:
        # ── Build base CorpsEnv ──────────────────────────────────────
        super().__init__(
            n_divisions=n_divisions,
            n_brigades_per_division=n_brigades_per_division,
            n_blue_per_brigade=n_blue_per_brigade,
            n_red_divisions=n_red_divisions,
            n_red_brigades_per_division=n_red_brigades_per_division,
            n_red_per_brigade=n_red_per_brigade,
            map_width=map_width,
            map_height=map_height,
            max_steps=max_steps,
            road_network=road_network,
            supply_network=supply_network,
            objectives=objectives,
            comm_radius=comm_radius,
            red_random=red_random,
            randomize_terrain=randomize_terrain,
            visibility_radius=visibility_radius,
            render_mode=render_mode,
        )

        # ── Blue artillery configuration ─────────────────────────────
        if artillery_corps is not None:
            try:
                n_art_from_corps = len(artillery_corps.units)
            except AttributeError as exc:
                raise ValueError(
                    "Provided artillery_corps must be a valid ArtilleryCorps "
                    "instance with a 'units' attribute."
                ) from exc
            if n_art_from_corps < 1:
                raise ValueError(
                    "Provided artillery_corps must contain at least one battery."
                )
            self.n_artillery_batteries: int = int(n_art_from_corps)
            self._artillery: ArtilleryCorps = artillery_corps
        else:
            if int(n_artillery_batteries) < 1:
                raise ValueError(
                    f"n_artillery_batteries must be >= 1, "
                    f"got {n_artillery_batteries!r}"
                )
            self.n_artillery_batteries: int = int(n_artillery_batteries)
            self._artillery: ArtilleryCorps = ArtilleryCorps.generate_default(
                map_width=self.map_width,
                map_height=self.map_height,
                n_batteries=self.n_artillery_batteries,
                team=0,
                config=art_config,
            )

        # ── Red artillery (for counter-battery targeting) ────────────
        n_red_art = (
            n_red_artillery_batteries
            if n_red_artillery_batteries is not None
            else self.n_artillery_batteries
        )
        if int(n_red_art) < 1:
            raise ValueError(
                f"n_red_artillery_batteries must be >= 1, got {n_red_art!r}"
            )
        self._red_artillery: ArtilleryCorps = ArtilleryCorps.generate_default(
            map_width=self.map_width,
            map_height=self.map_height,
            n_batteries=int(n_red_art),
            team=1,
            config=ArtilleryUnitConfig(team=1),
        )

        self._last_art_report: ArtilleryReport = ArtilleryReport(
            morale_damage_dealt=0.0,
            guns_silenced=0,
            fortification_damage=0.0,
            fortifications_completed=0,
        )

        # ── Override action space ────────────────────────────────────
        self.action_space = spaces.MultiDiscrete(
            [self.n_corps_options] * self.n_divisions
            + [N_ARTILLERY_MISSIONS] * self.n_artillery_batteries,
            dtype=np.int64,
        )

        # ── Override observation space ───────────────────────────────
        obs_low, obs_high = self._build_art_obs_bounds()
        self.observation_space = spaces.Box(
            low=obs_low, high=obs_high, dtype=np.float32
        )

    # ------------------------------------------------------------------
    # Gymnasium API: reset
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> tuple[np.ndarray, dict]:
        """Reset the environment and return the initial artillery corps observation.

        Calls :meth:`~envs.corps_env.CorpsEnv.reset` on the base env,
        resets artillery positions, and returns the extended observation.
        """
        _, info = super().reset(seed=seed, options=options)

        # Reset Blue artillery to default positions
        for i, unit in enumerate(
            self._artillery.units[: self.n_artillery_batteries]
        ):
            unit.x = self.map_width * 0.25
            unit.y = (
                self.map_height
                * (i + 1)
                / (self.n_artillery_batteries + 1)
            )
            unit.theta = 0.0
            unit.strength = 1.0
            unit.alive = True
            unit.mission = ArtilleryMission.IDLE
            unit._fortify_progress = 0
        self._artillery.fortifications.clear()

        # Reset Red artillery
        for i, unit in enumerate(self._red_artillery.units):
            unit.x = self.map_width * 0.75
            unit.y = (
                self.map_height
                * (i + 1)
                / (len(self._red_artillery.units) + 1)
            )
            unit.theta = math.pi
            unit.strength = 1.0
            unit.alive = True
            unit.mission = ArtilleryMission.IDLE
            unit._fortify_progress = 0
        self._red_artillery.fortifications.clear()

        self._last_art_report = ArtilleryReport(
            morale_damage_dealt=0.0,
            guns_silenced=0,
            fortification_damage=0.0,
            fortifications_completed=0,
        )
        return self._build_art_obs(), info

    # ------------------------------------------------------------------
    # Gymnasium API: step
    # ------------------------------------------------------------------

    def step(
        self, action: np.ndarray
    ) -> tuple[np.ndarray, float, bool, bool, dict]:
        """Execute one artillery-corps macro-step.

        Splits the combined action into corps commands and artillery mission
        assignments.  Infantry and supply logic executes first (via the
        base :class:`~envs.corps_env.CorpsEnv`); artillery then acts on the
        updated battlefield state.

        Parameters
        ----------
        action:
            Integer array of shape
            ``(n_divisions + n_artillery_batteries,)``.
            Elements ``[:n_divisions]`` are standard corps operational
            commands.  Elements ``[n_divisions:]`` assign an
            :class:`~envs.sim.artillery_corps.ArtilleryMission` to each
            battery.

        Returns
        -------
        obs : np.ndarray — extended artillery corps observation
        reward : float — base corps reward + fortification reward shaping
        terminated : bool
        truncated : bool
        info : dict — base corps info plus ``"artillery"`` sub-dict
        """
        action = np.asarray(action, dtype=np.int64)
        expected_len = self.n_divisions + self.n_artillery_batteries
        if action.shape != (expected_len,):
            raise ValueError(
                f"action has shape {action.shape!r}, "
                f"expected ({expected_len},)."
            )

        # Split action
        corps_action = action[: self.n_divisions]
        art_action = action[self.n_divisions :]

        # Assign Blue artillery missions for this step
        for i, unit in enumerate(
            self._artillery.units[: self.n_artillery_batteries]
        ):
            if unit.alive:
                unit.mission = ArtilleryMission(int(art_action[i]))
            else:
                unit.mission = ArtilleryMission.IDLE

        # ── Execute base corps step (infantry + supply) ───────────────
        _, base_reward, terminated, truncated, info = super().step(corps_action)

        # ── Execute Blue artillery step ───────────────────────────────
        inner = self._division._brigade._inner
        self._last_art_report = self._artillery.step(
            inner,
            enemy_artillery=list(self._red_artillery.units),
            enemy_fortifications=list(self._red_artillery.fortifications),
        )

        # ── Reward shaping: bonus for completing fortifications ───────
        fort_bonus = self._last_art_report.fortifications_completed * _FORT_COMPLETION_BONUS
        total_reward = base_reward + fort_bonus

        # ── Augment info with artillery report ────────────────────────
        info["artillery"] = {
            "morale_damage_dealt": self._last_art_report.morale_damage_dealt,
            "guns_silenced": self._last_art_report.guns_silenced,
            "fortification_damage": self._last_art_report.fortification_damage,
            "fortifications_completed": self._last_art_report.fortifications_completed,
            "n_blue_forts": len(self._artillery.fortifications),
            "n_red_forts": len(self._red_artillery.fortifications),
        }

        return (
            self._build_art_obs(),
            total_reward,
            terminated,
            truncated,
            info,
        )

    # ------------------------------------------------------------------
    # Observation construction
    # ------------------------------------------------------------------

    def _build_art_obs(self) -> np.ndarray:
        """Build and return the full artillery corps observation vector.

        Concatenates:
        1. Base corps observation.
        2. Per-artillery-battery state.
        3. Artillery operational summary.
        4. Fortification slot states (Blue forts, zero-padded).
        """
        base_obs: np.ndarray = self._get_corps_obs()

        # ── Per-battery state ─────────────────────────────────────────
        art_parts: List[float] = []
        for unit in self._artillery.units[: self.n_artillery_batteries]:
            if unit.alive:
                art_parts.extend(
                    [
                        unit.x / self.map_width,
                        unit.y / self.map_height,
                        float(int(unit.mission)) / (N_ARTILLERY_MISSIONS - 1),
                        unit.strength,
                    ]
                )
            else:
                art_parts.extend([0.0, 0.0, 0.0, 0.0])

        # ── Artillery operational summary ─────────────────────────────
        r = self._last_art_report
        art_parts.extend(
            [
                min(1.0, r.morale_damage_dealt / _MORALE_DMG_NORM),
                min(1.0, float(r.guns_silenced) / max(1, len(self._red_artillery.units))),
                min(1.0, r.fortification_damage / _FORT_DMG_NORM),
                min(1.0, float(r.fortifications_completed)),
            ]
        )

        # ── Fortification slots (n_artillery_batteries slots, Blue forts) ──
        forts = self._artillery.fortifications
        for slot in range(self.n_artillery_batteries):
            if slot < len(forts):
                fort = forts[slot]
                art_parts.extend(
                    [
                        fort.x / self.map_width,
                        fort.y / self.map_height,
                        fort.hp,
                    ]
                )
            else:
                art_parts.extend([0.0, 0.0, 0.0])

        art_arr = np.array(art_parts, dtype=np.float32)
        obs = np.concatenate([base_obs, art_arr])
        return np.clip(obs, self.observation_space.low, self.observation_space.high)

    # ------------------------------------------------------------------
    # Observation bounds
    # ------------------------------------------------------------------

    def _build_art_obs_bounds(
        self,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return ``(low, high)`` bounds for the full artillery corps observation."""
        base_low, base_high = self._build_obs_bounds()

        # Extra dims: all in [0, 1]
        n_extra = (
            self.n_artillery_batteries * _ART_UNIT_OBS_DIM
            + _ART_SUMMARY_DIM
            + self.n_artillery_batteries * _ART_FORT_OBS_DIM
        )
        extra_low = np.zeros(n_extra, dtype=np.float32)
        extra_high = np.ones(n_extra, dtype=np.float32)

        return (
            np.concatenate([base_low, extra_low]),
            np.concatenate([base_high, extra_high]),
        )
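The combined action space built in `__init__` concatenates one corps command per division with one mission index per battery. The layout can be sketched standalone with NumPy (the option counts below are illustrative stand-ins, not the library's real `n_corps_options` or `N_ARTILLERY_MISSIONS` constants):

```python
import numpy as np

# Illustrative sizes only: 7 and 5 are stand-ins for n_corps_options
# and N_ARTILLERY_MISSIONS respectively.
n_divisions, n_corps_options = 3, 7
n_batteries, n_missions = 4, 5

# MultiDiscrete nvec: corps commands first, then one mission slot per battery.
nvec = np.array([n_corps_options] * n_divisions + [n_missions] * n_batteries)
assert nvec.tolist() == [7, 7, 7, 5, 5, 5, 5]

# Any integer vector with 0 <= action[i] < nvec[i] is a valid action.
rng = np.random.default_rng(0)
action = rng.integers(0, nvec)
assert np.all((action >= 0) & (action < nvec))
```

Sampling from the env's own `action_space` yields vectors of exactly this shape.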

reset(seed=None, options=None)

Reset the environment and return the initial artillery corps observation.

Calls CorpsEnv.reset on the base env, resets artillery positions, and returns the extended observation.

Source code in envs/artillery_corps_env.py
def reset(
    self,
    seed: Optional[int] = None,
    options: Optional[dict] = None,
) -> tuple[np.ndarray, dict]:
    """Reset the environment and return the initial artillery corps observation.

    Calls :meth:`~envs.corps_env.CorpsEnv.reset` on the base env,
    resets artillery positions, and returns the extended observation.
    """
    _, info = super().reset(seed=seed, options=options)

    # Reset Blue artillery to default positions
    for i, unit in enumerate(
        self._artillery.units[: self.n_artillery_batteries]
    ):
        unit.x = self.map_width * 0.25
        unit.y = (
            self.map_height
            * (i + 1)
            / (self.n_artillery_batteries + 1)
        )
        unit.theta = 0.0
        unit.strength = 1.0
        unit.alive = True
        unit.mission = ArtilleryMission.IDLE
        unit._fortify_progress = 0
    self._artillery.fortifications.clear()

    # Reset Red artillery
    for i, unit in enumerate(self._red_artillery.units):
        unit.x = self.map_width * 0.75
        unit.y = (
            self.map_height
            * (i + 1)
            / (len(self._red_artillery.units) + 1)
        )
        unit.theta = math.pi
        unit.strength = 1.0
        unit.alive = True
        unit.mission = ArtilleryMission.IDLE
        unit._fortify_progress = 0
    self._red_artillery.fortifications.clear()

    self._last_art_report = ArtilleryReport(
        morale_damage_dealt=0.0,
        guns_silenced=0,
        fortification_damage=0.0,
        fortifications_completed=0,
    )
    return self._build_art_obs(), info
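The reset above spreads each side's batteries over evenly spaced interior points of the y-axis, so no battery spawns on a map edge. The spacing formula can be checked in isolation (the map height here is an illustrative value; the env supplies the real one):

```python
map_height = 10_000.0   # illustrative; the env's map_height is used in practice
n_batteries = 4

# y positions used by reset(): map_height * (i + 1) / (n_batteries + 1)
ys = [map_height * (i + 1) / (n_batteries + 1) for i in range(n_batteries)]
assert ys == [2000.0, 4000.0, 6000.0, 8000.0]  # interior points, evenly spaced
```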

step(action)

Execute one artillery-corps macro-step.

Splits the combined action into corps commands and artillery mission assignments. Infantry and supply logic executes first (via the base CorpsEnv); artillery then acts on the updated battlefield state.

Parameters:

- action (ndarray, required) — Integer array of shape (n_divisions + n_artillery_batteries,). Elements [:n_divisions] are standard corps operational commands. Elements [n_divisions:] assign an ArtilleryMission to each battery.

Returns:

- obs (np.ndarray) — extended artillery corps observation
- reward (float) — base corps reward + fortification reward shaping
- terminated (bool)
- truncated (bool)
- info (dict) — base corps info plus "artillery" sub-dict
Source code in envs/artillery_corps_env.py
def step(
    self, action: np.ndarray
) -> tuple[np.ndarray, float, bool, bool, dict]:
    """Execute one artillery-corps macro-step.

    Splits the combined action into corps commands and artillery mission
    assignments.  Infantry and supply logic executes first (via the
    base :class:`~envs.corps_env.CorpsEnv`); artillery then acts on the
    updated battlefield state.

    Parameters
    ----------
    action:
        Integer array of shape
        ``(n_divisions + n_artillery_batteries,)``.
        Elements ``[:n_divisions]`` are standard corps operational
        commands.  Elements ``[n_divisions:]`` assign an
        :class:`~envs.sim.artillery_corps.ArtilleryMission` to each
        battery.

    Returns
    -------
    obs : np.ndarray — extended artillery corps observation
    reward : float — base corps reward + fortification reward shaping
    terminated : bool
    truncated : bool
    info : dict — base corps info plus ``"artillery"`` sub-dict
    """
    action = np.asarray(action, dtype=np.int64)
    expected_len = self.n_divisions + self.n_artillery_batteries
    if action.shape != (expected_len,):
        raise ValueError(
            f"action has shape {action.shape!r}, "
            f"expected ({expected_len},)."
        )

    # Split action
    corps_action = action[: self.n_divisions]
    art_action = action[self.n_divisions :]

    # Assign Blue artillery missions for this step
    for i, unit in enumerate(
        self._artillery.units[: self.n_artillery_batteries]
    ):
        if unit.alive:
            unit.mission = ArtilleryMission(int(art_action[i]))
        else:
            unit.mission = ArtilleryMission.IDLE

    # ── Execute base corps step (infantry + supply) ───────────────
    _, base_reward, terminated, truncated, info = super().step(corps_action)

    # ── Execute Blue artillery step ───────────────────────────────
    inner = self._division._brigade._inner
    self._last_art_report = self._artillery.step(
        inner,
        enemy_artillery=list(self._red_artillery.units),
        enemy_fortifications=list(self._red_artillery.fortifications),
    )

    # ── Reward shaping: bonus for completing fortifications ───────
    fort_bonus = self._last_art_report.fortifications_completed * _FORT_COMPLETION_BONUS
    total_reward = base_reward + fort_bonus

    # ── Augment info with artillery report ────────────────────────
    info["artillery"] = {
        "morale_damage_dealt": self._last_art_report.morale_damage_dealt,
        "guns_silenced": self._last_art_report.guns_silenced,
        "fortification_damage": self._last_art_report.fortification_damage,
        "fortifications_completed": self._last_art_report.fortifications_completed,
        "n_blue_forts": len(self._artillery.fortifications),
        "n_red_forts": len(self._red_artillery.fortifications),
    }

    return (
        self._build_art_obs(),
        total_reward,
        terminated,
        truncated,
        info,
    )
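The split performed at the top of step() can be sketched on its own: the first n_divisions entries go to the base corps env, the rest become per-battery mission indices (sizes and values below are illustrative):

```python
import numpy as np

n_divisions, n_batteries = 3, 4  # illustrative sizes
action = np.array([2, 0, 1, 4, 4, 0, 3], dtype=np.int64)
assert action.shape == (n_divisions + n_batteries,)

corps_action = action[:n_divisions]   # division commands: [2, 0, 1]
art_action = action[n_divisions:]     # battery missions:  [4, 4, 0, 3]
assert corps_action.tolist() == [2, 0, 1]
assert art_action.tolist() == [4, 4, 0, 3]
```

An action of the wrong length raises ValueError before any state changes, so a malformed policy output cannot partially advance the simulation.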

envs.multi_battalion_env.MultiBattalionEnv

Bases: ParallelEnv

PettingZoo ParallelEnv for NvN battalion combat.

Parameters:

- n_blue (int, default 2) — Number of Blue (team 0) battalions. Must be ≥ 1.
- n_red (int, default 2) — Number of Red (team 1) battalions. Must be ≥ 1.
- map_width (float, default MAP_WIDTH) — Map width in metres (default 1 km).
- map_height (float, default MAP_HEIGHT) — Map height in metres (default 1 km).
- max_steps (int, default MAX_STEPS) — Episode length cap (default 500).
- terrain (Optional[TerrainMap], default None) — Optional fixed TerrainMap. When supplied, randomize_terrain is forced to False and this map is used for every episode.
- randomize_terrain (bool, default True) — When True (default, if no fixed terrain is supplied) a new procedural terrain is generated from the seeded RNG at the start of each episode.
- hill_speed_factor (float, default 0.5) — Movement speed multiplier on maximum-elevation terrain. Must be in (0, 1]; 0.5 means half speed on the highest hills.
- visibility_radius (float, default VISIBILITY_RADIUS) — Fog-of-war cutoff distance (metres). Enemy units beyond this distance have their strength and morale hidden in observations.
- render_mode (Optional[str], default None) — Currently unused; reserved for future rendering support.
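Per-agent observation and global-state sizes follow directly from the force composition, per the layout documented in the source below (6 self features, 5 per other unit, 1 normalised step counter). For the default 2v2 setup:

```python
# Observation layout: 6 self features + 5 per other unit + 1 step counter.
n_blue, n_red = 2, 2          # the constructor defaults
n_total = n_blue + n_red
obs_dim = 6 + 5 * (n_total - 1) + 1
assert obs_dim == 22

# Global state: 6 features per unit plus the step counter.
state_dim = 6 * n_total + 1
assert state_dim == 25
```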
Source code in envs/multi_battalion_env.py
class MultiBattalionEnv(ParallelEnv):
    """PettingZoo ``ParallelEnv`` for NvN battalion combat.

    Parameters
    ----------
    n_blue:
        Number of Blue (team 0) battalions.  Must be ``≥ 1``.
    n_red:
        Number of Red (team 1) battalions.  Must be ``≥ 1``.
    map_width, map_height:
        Map dimensions in metres (default 1 km × 1 km).
    max_steps:
        Episode length cap (default 500).
    terrain:
        Optional fixed :class:`~envs.sim.terrain.TerrainMap`.  When
        supplied *randomize_terrain* is forced to ``False`` and this map
        is used for every episode.
    randomize_terrain:
        When ``True`` (default, if no fixed *terrain* is supplied) a new
        procedural terrain is generated from the seeded RNG at the start
        of each episode.
    hill_speed_factor:
        Movement speed multiplier on maximum-elevation terrain.  Must be
        in ``(0, 1]``; ``0.5`` means half speed on the highest hills.
    visibility_radius:
        Fog-of-war cutoff distance (metres).  Enemy units beyond this
        distance have their strength and morale hidden in observations.
    render_mode:
        Currently unused; reserved for future rendering support.
    """

    metadata: dict = {"render_modes": [], "name": "multi_battalion_v0"}

    def __init__(
        self,
        n_blue: int = 2,
        n_red: int = 2,
        map_width: float = MAP_WIDTH,
        map_height: float = MAP_HEIGHT,
        max_steps: int = MAX_STEPS,
        terrain: Optional[TerrainMap] = None,
        randomize_terrain: bool = True,
        hill_speed_factor: float = 0.5,
        visibility_radius: float = VISIBILITY_RADIUS,
        render_mode: Optional[str] = None,
    ) -> None:
        # ------------------------------------------------------------------
        # Argument validation
        # ------------------------------------------------------------------
        if int(n_blue) < 1:
            raise ValueError(f"n_blue must be >= 1, got {n_blue}")
        if int(n_red) < 1:
            raise ValueError(f"n_red must be >= 1, got {n_red}")
        if float(map_width) <= 0:
            raise ValueError(f"map_width must be positive, got {map_width}")
        if float(map_height) <= 0:
            raise ValueError(f"map_height must be positive, got {map_height}")
        if int(max_steps) < 1:
            raise ValueError(f"max_steps must be >= 1, got {max_steps}")
        if not (0.0 < float(hill_speed_factor) <= 1.0):
            raise ValueError(
                f"hill_speed_factor must be in (0, 1], got {hill_speed_factor}"
            )
        if float(visibility_radius) <= 0:
            raise ValueError(
                f"visibility_radius must be positive, got {visibility_radius}"
            )

        self.n_blue = int(n_blue)
        self.n_red = int(n_red)
        self.map_width = float(map_width)
        self.map_height = float(map_height)
        self.map_diagonal = math.hypot(self.map_width, self.map_height)
        self.max_steps = int(max_steps)
        self.hill_speed_factor = float(hill_speed_factor)
        self.visibility_radius = float(visibility_radius)
        self.randomize_terrain = bool(randomize_terrain) and (terrain is None)
        self._supplied_terrain: Optional[TerrainMap] = terrain
        self.terrain: TerrainMap = (
            terrain if terrain is not None
            else TerrainMap.flat(map_width, map_height)
        )
        self.render_mode = render_mode

        # Optional road network — when set, battalions on roads gain a
        # movement-speed bonus (see :data:`~envs.sim.road_network.ROAD_SPEED_BONUS`).
        # Can be set after construction: ``env.road_network = my_network``.
        self.road_network: Optional[RoadNetwork] = None

        # ------------------------------------------------------------------
        # PettingZoo required: possible_agents (fixed for the lifetime of env)
        # ------------------------------------------------------------------
        self.possible_agents: list[str] = (
            [f"blue_{i}" for i in range(self.n_blue)]
            + [f"red_{i}" for i in range(self.n_red)]
        )

        # ------------------------------------------------------------------
        # Observation / action space construction
        # ------------------------------------------------------------------
        # obs_dim = 6 (self) + 5 * (n_total - 1) (others) + 1 (step_norm)
        n_total = self.n_blue + self.n_red
        self._obs_dim: int = 6 + 5 * (n_total - 1) + 1

        # Global state dim: 6 per agent + 1 step_norm
        self._state_dim: int = 6 * n_total + 1

        # Per-unit observation bounds
        # Self state: [x/w, y/h, cos, sin, strength, morale]
        self_low = np.array([0.0, 0.0, -1.0, -1.0, 0.0, 0.0], dtype=np.float32)
        self_high = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], dtype=np.float32)
        # Other unit block: [dist_norm, cos_bearing, sin_bearing, strength, morale]
        other_unit_low = np.array([0.0, -1.0, -1.0, 0.0, 0.0], dtype=np.float32)
        other_unit_high = np.array([1.0, 1.0, 1.0, 1.0, 1.0], dtype=np.float32)
        n_others = n_total - 1
        obs_low = np.concatenate(
            [self_low, np.tile(other_unit_low, n_others),
             np.array([0.0], dtype=np.float32)]
        )
        obs_high = np.concatenate(
            [self_high, np.tile(other_unit_high, n_others),
             np.array([1.0], dtype=np.float32)]
        )
        # Single shared observation space instance (same for every agent)
        self._obs_space = spaces.Box(low=obs_low, high=obs_high, dtype=np.float32)
        # Single shared action space instance
        self._act_space = spaces.Box(
            low=np.array([-1.0, -1.0, 0.0], dtype=np.float32),
            high=np.array([1.0, 1.0, 1.0], dtype=np.float32),
            dtype=np.float32,
        )

        # ------------------------------------------------------------------
        # Internal episode state (populated by reset())
        # ------------------------------------------------------------------
        self._battalions: dict[str, Battalion] = {}
        self._combat_states: dict[str, CombatState] = {}
        self._alive: set[str] = set()   # set of agent IDs still alive
        self._step_count: int = 0
        self._rng: np.random.Generator = np.random.default_rng()

        # PettingZoo: live agents list (modified by reset/step)
        self.agents: list[str] = []

    # ------------------------------------------------------------------
    # PettingZoo API: spaces (must return the same object each call)
    # ------------------------------------------------------------------

    def observation_space(self, agent: str) -> spaces.Box:
        """Return observation space for *agent* (same object every call)."""
        return self._obs_space

    def action_space(self, agent: str) -> spaces.Box:
        """Return action space for *agent* (same object every call)."""
        return self._act_space

    # ------------------------------------------------------------------
    # PettingZoo API: reset
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> tuple[dict[str, np.ndarray], dict[str, dict]]:
        """Reset the environment and return initial observations.

        Parameters
        ----------
        seed:
            RNG seed.  Passing the same seed always produces the same
            terrain layout and starting positions.
        options:
            Currently unused; accepted for API compatibility.

        Returns
        -------
        observations : dict[agent_id, np.ndarray]
        infos        : dict[agent_id, dict]
        """
        self._rng = np.random.default_rng(seed)

        # Terrain
        if self.randomize_terrain:
            self.terrain = TerrainMap.generate_random(
                rng=self._rng,
                width=self.map_width,
                height=self.map_height,
            )
        elif self._supplied_terrain is not None:
            self.terrain = self._supplied_terrain

        # Reset live agents
        self.agents = list(self.possible_agents)
        self._alive = set(self.possible_agents)
        self._step_count = 0
        self._battalions = {}
        self._combat_states = {}

        # Spawn Blue agents in the western half, facing roughly east
        for i in range(self.n_blue):
            agent_id = f"blue_{i}"
            x = float(self._rng.uniform(0.1 * self.map_width, 0.4 * self.map_width))
            y = float(self._rng.uniform(0.1 * self.map_height, 0.9 * self.map_height))
            theta = float(self._rng.uniform(-math.pi / 4, math.pi / 4))
            self._battalions[agent_id] = Battalion(
                x=x, y=y, theta=theta, strength=1.0, team=0
            )
            self._combat_states[agent_id] = CombatState()

        # Spawn Red agents in the eastern half, facing roughly west
        for i in range(self.n_red):
            agent_id = f"red_{i}"
            x = float(self._rng.uniform(0.6 * self.map_width, 0.9 * self.map_width))
            y = float(self._rng.uniform(0.1 * self.map_height, 0.9 * self.map_height))
            theta = float(math.pi + self._rng.uniform(-math.pi / 4, math.pi / 4))
            self._battalions[agent_id] = Battalion(
                x=x, y=y, theta=theta, strength=1.0, team=1
            )
            self._combat_states[agent_id] = CombatState()

        observations = {agent: self._get_obs(agent) for agent in self.agents}
        infos: dict[str, dict] = {agent: {} for agent in self.agents}
        return observations, infos

    # ------------------------------------------------------------------
    # PettingZoo API: step
    # ------------------------------------------------------------------

    def step(
        self,
        actions: dict[str, np.ndarray],
    ) -> tuple[
        dict[str, np.ndarray],
        dict[str, float],
        dict[str, bool],
        dict[str, bool],
        dict[str, dict],
    ]:
        """Advance the environment one step.

        Parameters
        ----------
        actions:
            Dict mapping each live agent ID to its action array of shape
            ``(3,)``: ``[move, rotate, fire]``.  Agents absent from the
            dict are treated as no-op (zeros).

        Returns
        -------
        observations, rewards, terminations, truncations, infos
            All keyed by agent ID; contain entries for every agent that
            was alive at the **start** of this step.
        """
        if not self.agents:
            return {}, {}, {}, {}, {}

        current_agents = list(self.agents)

        # --- 1. Apply movement for each live agent ---
        for agent_id in current_agents:
            action = np.asarray(
                actions.get(agent_id, np.zeros(3, dtype=np.float32)),
                dtype=np.float32,
            )
            move_cmd = float(np.clip(action[0], -1.0, 1.0))
            rotate_cmd = float(np.clip(action[1], -1.0, 1.0))

            battalion = self._battalions[agent_id]
            battalion.rotate(rotate_cmd * battalion.max_turn_rate)
            terrain_mod = self.terrain.get_speed_modifier(
                battalion.x, battalion.y, self.hill_speed_factor
            )
            road_mod = 1.0
            if self.road_network is not None:
                road_mod = self.road_network.get_speed_modifier(battalion.x, battalion.y)
            speed_mod = terrain_mod * road_mod
            vx = math.cos(battalion.theta) * move_cmd * battalion.max_speed * speed_mod
            vy = math.sin(battalion.theta) * move_cmd * battalion.max_speed * speed_mod
            # Battalion.move() clamps velocity magnitude to battalion.max_speed.
            # Temporarily raise the effective cap so road bonuses > 1.0 take effect.
            original_max_speed = battalion.max_speed
            try:
                battalion.max_speed = original_max_speed * max(1.0, road_mod)
                battalion.move(vx, vy, dt=DT)
            finally:
                battalion.max_speed = original_max_speed
            battalion.x = float(np.clip(battalion.x, 0.0, self.map_width))
            battalion.y = float(np.clip(battalion.y, 0.0, self.map_height))

        # --- 2. Combat resolution (simultaneous) ---
        for cs in self._combat_states.values():
            cs.reset_step_accumulators()

        blue_agents = [a for a in current_agents if a.startswith("blue_")]
        red_agents = [a for a in current_agents if a.startswith("red_")]

        # Compute all raw damages first (simultaneous resolution)
        # raw_damage[attacker_id][target_id] = raw damage value
        raw_damage: dict[str, dict[str, float]] = {a: {} for a in current_agents}

        for attacker_id in blue_agents:
            fire_cmd = float(
                np.clip(
                    actions.get(attacker_id, np.zeros(3, dtype=np.float32))[2],
                    0.0, 1.0,
                )
            )
            for target_id in red_agents:
                raw = compute_fire_damage(
                    self._battalions[attacker_id],
                    self._battalions[target_id],
                    intensity=fire_cmd,
                )
                raw = self.terrain.apply_cover_modifier(
                    self._battalions[target_id].x,
                    self._battalions[target_id].y,
                    raw,
                )
                raw_damage[attacker_id][target_id] = raw

        for attacker_id in red_agents:
            fire_cmd = float(
                np.clip(
                    actions.get(attacker_id, np.zeros(3, dtype=np.float32))[2],
                    0.0, 1.0,
                )
            )
            for target_id in blue_agents:
                raw = compute_fire_damage(
                    self._battalions[attacker_id],
                    self._battalions[target_id],
                    intensity=fire_cmd,
                )
                raw = self.terrain.apply_cover_modifier(
                    self._battalions[target_id].x,
                    self._battalions[target_id].y,
                    raw,
                )
                raw_damage[attacker_id][target_id] = raw

        # Sum incoming damage per target and apply once (preserves simultaneity)
        damage_dealt: dict[str, float] = {a: 0.0 for a in current_agents}
        damage_received: dict[str, float] = {a: 0.0 for a in current_agents}

        for target_id in current_agents:
            total_raw: float = sum(
                raw_damage[att].get(target_id, 0.0)
                for att in current_agents
                if att != target_id
            )
            if total_raw <= 0.0:
                continue

            actual = apply_casualties(
                self._battalions[target_id],
                self._combat_states[target_id],
                total_raw,
            )
            damage_received[target_id] = actual

            # Credit each attacker proportionally
            for attacker_id in current_agents:
                attacker_raw = raw_damage[attacker_id].get(target_id, 0.0)
                if attacker_raw > 0.0:
                    damage_dealt[attacker_id] += actual * (attacker_raw / total_raw)

        # --- 3. Morale checks ---
        for agent_id in current_agents:
            morale_check(self._combat_states[agent_id], rng=self._rng)
            # Sync Battalion fields from CombatState
            battalion = self._battalions[agent_id]
            cs = self._combat_states[agent_id]
            battalion.morale = cs.morale
            battalion.routed = cs.is_routing

        self._step_count += 1

        # --- 4. Termination / truncation ---
        # Individual termination: routed or effectively destroyed
        individually_done: dict[str, bool] = {
            agent_id: (
                self._combat_states[agent_id].is_routing
                or self._battalions[agent_id].strength <= DESTROYED_THRESHOLD
            )
            for agent_id in current_agents
        }

        # Team-level: surviving-side members are also terminated when the
        # opposing team is completely eliminated this step.
        blue_all_done = all(individually_done.get(a, True) for a in blue_agents)
        red_all_done = all(individually_done.get(a, True) for a in red_agents)

        terminated: dict[str, bool] = {}
        truncated: dict[str, bool] = {}
        for agent_id in current_agents:
            # Terminate if individually eliminated OR the battle is decided
            terminated[agent_id] = (
                individually_done[agent_id]
                or (blue_all_done and bool(blue_agents))
                or (red_all_done and bool(red_agents))
            )
            # Truncate remaining live agents at max_steps
            truncated[agent_id] = (
                not terminated[agent_id]
                and self._step_count >= self.max_steps
            )

        # --- 5. Rewards ---
        rewards: dict[str, float] = {}
        for agent_id in current_agents:
            is_blue = agent_id.startswith("blue_")
            r = (
                damage_dealt[agent_id] * 5.0
                - damage_received[agent_id] * 5.0
                - 0.01  # time penalty
            )
            if terminated[agent_id] and not truncated[agent_id]:
                if is_blue:
                    if red_all_done and not blue_all_done:
                        r += 10.0   # blue wins
                    elif blue_all_done and not red_all_done:
                        r -= 10.0   # blue loses
                else:
                    if blue_all_done and not red_all_done:
                        r += 10.0   # red wins
                    elif red_all_done and not blue_all_done:
                        r -= 10.0   # red loses
            rewards[agent_id] = float(r)

        # --- 6. Update _alive so _get_obs sees newly-dead agents as zero blocks ---
        self._alive -= {
            a for a in current_agents if terminated[a] or truncated[a]
        }

        # --- 7. Build observations (after updating _alive so dead-agent blocks
        #         are correctly zeroed in the observations of surviving agents) ---
        observations = {agent: self._get_obs(agent) for agent in current_agents}
        infos: dict[str, dict] = {
            agent: {
                "damage_dealt": float(damage_dealt[agent]),
                "damage_received": float(damage_received[agent]),
                "step_count": self._step_count,
            }
            for agent in current_agents
        }

        # --- 8. Update live-agents list ---
        self.agents = [
            a for a in self.agents
            if not terminated.get(a, False) and not truncated.get(a, False)
        ]

        return observations, rewards, terminated, truncated, infos

    # ------------------------------------------------------------------
    # PettingZoo API: state
    # ------------------------------------------------------------------

    def state(self) -> np.ndarray:
        """Return the global state tensor for centralized critics.

        Concatenates ``[x/w, y/h, cos θ, sin θ, strength, morale]`` for
        every agent in ``possible_agents`` order, followed by
        ``step / max_steps``.  Dead-agent slots are all-zero.

        Returns
        -------
        np.ndarray of shape ``(6 * (n_blue + n_red) + 1,)`` and dtype
        ``float32``.
        """
        parts: list[float] = []
        for agent_id in self.possible_agents:
            if agent_id in self._battalions and agent_id in self._alive:
                b = self._battalions[agent_id]
                parts.extend([
                    b.x / self.map_width,
                    b.y / self.map_height,
                    math.cos(b.theta),
                    math.sin(b.theta),
                    float(b.strength),
                    float(b.morale),
                ])
            else:
                parts.extend([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
        parts.append(min(self._step_count / self.max_steps, 1.0))
        return np.array(parts, dtype=np.float32)

    def get_coordination_metrics(
        self,
        support_radius: float = 300.0,
    ) -> dict[str, float]:
        """Compute per-step coordination metrics for the current environment state.

        Calls :func:`~envs.metrics.coordination.compute_all` on the currently
        alive and non-routed Blue and Red battalions.

        Parameters
        ----------
        support_radius:
            Distance threshold (metres) for
            :func:`~envs.metrics.coordination.mutual_support_score`.

        Returns
        -------
        dict[str, float]
            Keys: ``"coordination/flanking_ratio"``,
            ``"coordination/fire_concentration"``,
            ``"coordination/mutual_support_score"``.
        """
        blue = [
            b
            for agent_id, b in self._battalions.items()
            if agent_id.startswith("blue_") and not b.routed and b.strength > DESTROYED_THRESHOLD
        ]
        red = [
            b
            for agent_id, b in self._battalions.items()
            if agent_id.startswith("red_") and not b.routed and b.strength > DESTROYED_THRESHOLD
        ]
        return _compute_coordination(blue, red, support_radius=support_radius)

    # ------------------------------------------------------------------
    # PettingZoo API: close / render (stubs)
    # ------------------------------------------------------------------

    def close(self) -> None:
        """Clean up resources."""

    def render(self) -> None:
        """Rendering is not yet implemented."""

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _get_obs(self, agent_id: str) -> np.ndarray:
        """Build the normalised observation vector for *agent_id*.

        Layout: self(6) | allies(5 each) | enemies(5 each) | step_norm(1)

        Allies are ordered by agent index; enemies likewise.
        Dead agents and fog-of-war enemies produce zero/clamped blocks.
        """
        b = self._battalions[agent_id]
        is_blue = agent_id.startswith("blue_")

        # Self state
        self_state = np.array(
            [
                b.x / self.map_width,
                b.y / self.map_height,
                math.cos(b.theta),
                math.sin(b.theta),
                float(b.strength),
                float(b.morale),
            ],
            dtype=np.float32,
        )

        # Other units: allies first (own team minus self), then enemies
        ally_prefix = "blue_" if is_blue else "red_"
        enemy_prefix = "red_" if is_blue else "blue_"

        allies = [
            a for a in self.possible_agents
            if a != agent_id and a.startswith(ally_prefix)
        ]
        enemies = [
            a for a in self.possible_agents
            if a.startswith(enemy_prefix)
        ]
        other_order = allies + enemies

        other_features: list[float] = []
        for other_id in other_order:
            # Dead unit → all zeros
            if other_id not in self._alive:
                other_features.extend([0.0, 0.0, 0.0, 0.0, 0.0])
                continue

            other = self._battalions[other_id]
            dx = other.x - b.x
            dy = other.y - b.y
            dist = math.sqrt(dx ** 2 + dy ** 2)
            dist_norm = min(dist / self.map_diagonal, 1.0)
            bearing = math.atan2(dy, dx)
            is_enemy = other_id.startswith(enemy_prefix)

            if is_enemy and dist > self.visibility_radius:
                # Fog of war: report max distance, hide state
                other_features.extend([1.0, 0.0, 0.0, 0.0, 0.0])
            else:
                other_features.extend([
                    dist_norm,
                    math.cos(bearing),
                    math.sin(bearing),
                    float(other.strength),
                    float(other.morale),
                ])

        step_norm = min(self._step_count / self.max_steps, 1.0)
        obs = np.concatenate(
            [self_state, np.array(other_features, dtype=np.float32), [step_norm]],
            dtype=np.float32,
        )
        return np.clip(obs, self._obs_space.low, self._obs_space.high)
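The fog-of-war rule in `_get_obs` — enemies beyond `visibility_radius` collapse to a sentinel block `[1, 0, 0, 0, 0]` — can be sketched in isolation. This is a minimal stand-alone version using plain NumPy; the `map_diagonal` and `visibility_radius` values below are placeholders for illustration, not environment defaults:

```python
import math

import numpy as np

def enemy_feature_block(dx, dy, strength, morale, map_diagonal, visibility_radius):
    """Five-feature observation block for one enemy unit, with fog of war applied."""
    dist = math.hypot(dx, dy)
    if dist > visibility_radius:
        # Beyond visibility: report max distance, hide all other state.
        return np.array([1.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float32)
    bearing = math.atan2(dy, dx)
    return np.array(
        [min(dist / map_diagonal, 1.0),
         math.cos(bearing), math.sin(bearing),
         strength, morale],
        dtype=np.float32,
    )

# Visible enemy 300 m due east: normalised distance 0.3, bearing 0 rad.
visible = enemy_feature_block(300.0, 0.0, 0.8, 0.9,
                              map_diagonal=1000.0, visibility_radius=500.0)
# Hidden enemy 800 m away: sentinel block [1, 0, 0, 0, 0].
hidden = enemy_feature_block(800.0, 0.0, 0.8, 0.9,
                             map_diagonal=1000.0, visibility_radius=500.0)
```

Note that a hidden enemy is indistinguishable from a dead one except for the leading distance feature, which is `1.0` for fog of war and `0.0` for a dead unit.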

action_space(agent)

Return action space for agent (same object every call).

Source code in envs/multi_battalion_env.py
def action_space(self, agent: str) -> spaces.Box:
    """Return action space for *agent* (same object every call)."""
    return self._act_space

close()

Clean up resources.

Source code in envs/multi_battalion_env.py
def close(self) -> None:
    """Clean up resources."""

get_coordination_metrics(support_radius=300.0)

Compute per-step coordination metrics for the current environment state.

Calls `envs.metrics.coordination.compute_all` on the currently alive and non-routed Blue and Red battalions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `support_radius` | `float` | Distance threshold (metres) for `envs.metrics.coordination.mutual_support_score`. | `300.0` |

Returns:

| Type | Description |
| --- | --- |
| `dict[str, float]` | Keys: `"coordination/flanking_ratio"`, `"coordination/fire_concentration"`, `"coordination/mutual_support_score"`. |

Source code in envs/multi_battalion_env.py
def get_coordination_metrics(
    self,
    support_radius: float = 300.0,
) -> dict[str, float]:
    """Compute per-step coordination metrics for the current environment state.

    Calls :func:`~envs.metrics.coordination.compute_all` on the currently
    alive and non-routed Blue and Red battalions.

    Parameters
    ----------
    support_radius:
        Distance threshold (metres) for
        :func:`~envs.metrics.coordination.mutual_support_score`.

    Returns
    -------
    dict[str, float]
        Keys: ``"coordination/flanking_ratio"``,
        ``"coordination/fire_concentration"``,
        ``"coordination/mutual_support_score"``.
    """
    blue = [
        b
        for agent_id, b in self._battalions.items()
        if agent_id.startswith("blue_") and not b.routed and b.strength > DESTROYED_THRESHOLD
    ]
    red = [
        b
        for agent_id, b in self._battalions.items()
        if agent_id.startswith("red_") and not b.routed and b.strength > DESTROYED_THRESHOLD
    ]
    return _compute_coordination(blue, red, support_radius=support_radius)
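The alive/non-routed filter above can be reproduced with a minimal stand-in `Battalion`. The `routed` and `strength` field names follow the source; the `DESTROYED_THRESHOLD` value here is an assumption (the real constant lives in the envs package):

```python
from dataclasses import dataclass

DESTROYED_THRESHOLD = 0.05  # assumed value for illustration

@dataclass
class Battalion:
    strength: float
    routed: bool = False

battalions = {
    "blue_0": Battalion(strength=0.9),
    "blue_1": Battalion(strength=0.02),             # effectively destroyed: excluded
    "red_0": Battalion(strength=0.7, routed=True),  # routing: excluded
    "red_1": Battalion(strength=0.6),
}

blue = [b for a, b in battalions.items()
        if a.startswith("blue_") and not b.routed and b.strength > DESTROYED_THRESHOLD]
red = [b for a, b in battalions.items()
       if a.startswith("red_") and not b.routed and b.strength > DESTROYED_THRESHOLD]
```

Only `blue_0` and `red_1` survive the filter, so the metrics reflect units that can still meaningfully coordinate.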

observation_space(agent)

Return observation space for agent (same object every call).

Source code in envs/multi_battalion_env.py
def observation_space(self, agent: str) -> spaces.Box:
    """Return observation space for *agent* (same object every call)."""
    return self._obs_space

render()

Rendering is not yet implemented.

Source code in envs/multi_battalion_env.py
def render(self) -> None:
    """Rendering is not yet implemented."""

reset(seed=None, options=None)

Reset the environment and return initial observations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `seed` | `Optional[int]` | RNG seed. Passing the same seed always produces the same terrain layout and starting positions. | `None` |
| `options` | `Optional[dict]` | Currently unused; accepted for API compatibility. | `None` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `observations` | `dict[agent_id, ndarray]` | Initial observation for each live agent. |
| `infos` | `dict[agent_id, dict]` | Empty info dict for each live agent. |
Source code in envs/multi_battalion_env.py
def reset(
    self,
    seed: Optional[int] = None,
    options: Optional[dict] = None,
) -> tuple[dict[str, np.ndarray], dict[str, dict]]:
    """Reset the environment and return initial observations.

    Parameters
    ----------
    seed:
        RNG seed.  Passing the same seed always produces the same
        terrain layout and starting positions.
    options:
        Currently unused; accepted for API compatibility.

    Returns
    -------
    observations : dict[agent_id, np.ndarray]
    infos        : dict[agent_id, dict]
    """
    self._rng = np.random.default_rng(seed)

    # Terrain
    if self.randomize_terrain:
        self.terrain = TerrainMap.generate_random(
            rng=self._rng,
            width=self.map_width,
            height=self.map_height,
        )
    elif self._supplied_terrain is not None:
        self.terrain = self._supplied_terrain

    # Reset live agents
    self.agents = list(self.possible_agents)
    self._alive = set(self.possible_agents)
    self._step_count = 0
    self._battalions = {}
    self._combat_states = {}

    # Spawn Blue agents in the western half, facing roughly east
    for i in range(self.n_blue):
        agent_id = f"blue_{i}"
        x = float(self._rng.uniform(0.1 * self.map_width, 0.4 * self.map_width))
        y = float(self._rng.uniform(0.1 * self.map_height, 0.9 * self.map_height))
        theta = float(self._rng.uniform(-math.pi / 4, math.pi / 4))
        self._battalions[agent_id] = Battalion(
            x=x, y=y, theta=theta, strength=1.0, team=0
        )
        self._combat_states[agent_id] = CombatState()

    # Spawn Red agents in the eastern half, facing roughly west
    for i in range(self.n_red):
        agent_id = f"red_{i}"
        x = float(self._rng.uniform(0.6 * self.map_width, 0.9 * self.map_width))
        y = float(self._rng.uniform(0.1 * self.map_height, 0.9 * self.map_height))
        theta = float(math.pi + self._rng.uniform(-math.pi / 4, math.pi / 4))
        self._battalions[agent_id] = Battalion(
            x=x, y=y, theta=theta, strength=1.0, team=1
        )
        self._combat_states[agent_id] = CombatState()

    observations = {agent: self._get_obs(agent) for agent in self.agents}
    infos: dict[str, dict] = {agent: {} for agent in self.agents}
    return observations, infos
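The opposed spawn geometry in `reset` can be reproduced in isolation: Blue samples from the western 10–40 % band facing within ±45° of east (θ = 0), Red from the eastern 60–90 % band facing within ±45° of west (θ = π). A sketch with placeholder map dimensions:

```python
import math

import numpy as np

MAP_W, MAP_H = 1000.0, 1000.0  # placeholder dimensions for illustration
rng = np.random.default_rng(42)

# Blue spawn: western band, roughly east-facing.
blue_x = float(rng.uniform(0.1 * MAP_W, 0.4 * MAP_W))
blue_y = float(rng.uniform(0.1 * MAP_H, 0.9 * MAP_H))
blue_theta = float(rng.uniform(-math.pi / 4, math.pi / 4))

# Red spawn: eastern band, roughly west-facing (offset by pi).
red_x = float(rng.uniform(0.6 * MAP_W, 0.9 * MAP_W))
red_y = float(rng.uniform(0.1 * MAP_H, 0.9 * MAP_H))
red_theta = float(math.pi + rng.uniform(-math.pi / 4, math.pi / 4))
```

Because both headings are drawn within ±π/4 of the opponent's direction, the two sides always start broadly facing each other across the map's east–west axis.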

state()

Return the global state tensor for centralized critics.

Concatenates [x/w, y/h, cos θ, sin θ, strength, morale] for every agent in possible_agents order, followed by step / max_steps. Dead-agent slots are all-zero.

Returns:

| Type | Description |
| --- | --- |
| `np.ndarray` | Shape `(6 * (n_blue + n_red) + 1,)`, dtype `float32`. |
Source code in envs/multi_battalion_env.py
def state(self) -> np.ndarray:
    """Return the global state tensor for centralized critics.

    Concatenates ``[x/w, y/h, cos θ, sin θ, strength, morale]`` for
    every agent in ``possible_agents`` order, followed by
    ``step / max_steps``.  Dead-agent slots are all-zero.

    Returns
    -------
    np.ndarray of shape ``(6 * (n_blue + n_red) + 1,)`` and dtype
    ``float32``.
    """
    parts: list[float] = []
    for agent_id in self.possible_agents:
        if agent_id in self._battalions and agent_id in self._alive:
            b = self._battalions[agent_id]
            parts.extend([
                b.x / self.map_width,
                b.y / self.map_height,
                math.cos(b.theta),
                math.sin(b.theta),
                float(b.strength),
                float(b.morale),
            ])
        else:
            parts.extend([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    parts.append(min(self._step_count / self.max_steps, 1.0))
    return np.array(parts, dtype=np.float32)
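A quick sanity check of the `state()` vector length: six features per possible agent plus the normalised step counter. The 2v2 team sizes below are chosen purely for illustration:

```python
import numpy as np

def state_dim(n_blue: int, n_red: int) -> int:
    # 6 features per possible agent (x/w, y/h, cos θ, sin θ, strength, morale)
    # plus one trailing step/max_steps entry.
    return 6 * (n_blue + n_red) + 1

# A 2v2 scenario yields a state vector of length 6 * 4 + 1 = 25.
zeros_state = np.zeros(state_dim(2, 2), dtype=np.float32)
```

Dead agents keep their slot (zero-filled), so this length is fixed for the lifetime of the environment, which is what a centralized critic expects.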

step(actions)

Advance the environment one step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `actions` | `dict[str, ndarray]` | Dict mapping each live agent ID to its action array of shape `(3,)`: `[move, rotate, fire]`. Agents absent from the dict are treated as no-op (zeros). | *required* |

Returns:

| Type | Description |
| --- | --- |
| `(observations, rewards, terminations, truncations, infos)` | All keyed by agent ID; contain entries for every agent that was alive at the **start** of this step. |

Source code in envs/multi_battalion_env.py
def step(
    self,
    actions: dict[str, np.ndarray],
) -> tuple[
    dict[str, np.ndarray],
    dict[str, float],
    dict[str, bool],
    dict[str, bool],
    dict[str, dict],
]:
    """Advance the environment one step.

    Parameters
    ----------
    actions:
        Dict mapping each live agent ID to its action array of shape
        ``(3,)``: ``[move, rotate, fire]``.  Agents absent from the
        dict are treated as no-op (zeros).

    Returns
    -------
    observations, rewards, terminations, truncations, infos
        All keyed by agent ID; contain entries for every agent that
        was alive at the **start** of this step.
    """
    if not self.agents:
        return {}, {}, {}, {}, {}

    current_agents = list(self.agents)

    # --- 1. Apply movement for each live agent ---
    for agent_id in current_agents:
        action = np.asarray(
            actions.get(agent_id, np.zeros(3, dtype=np.float32)),
            dtype=np.float32,
        )
        move_cmd = float(np.clip(action[0], -1.0, 1.0))
        rotate_cmd = float(np.clip(action[1], -1.0, 1.0))

        battalion = self._battalions[agent_id]
        battalion.rotate(rotate_cmd * battalion.max_turn_rate)
        terrain_mod = self.terrain.get_speed_modifier(
            battalion.x, battalion.y, self.hill_speed_factor
        )
        road_mod = 1.0
        if self.road_network is not None:
            road_mod = self.road_network.get_speed_modifier(battalion.x, battalion.y)
        speed_mod = terrain_mod * road_mod
        vx = math.cos(battalion.theta) * move_cmd * battalion.max_speed * speed_mod
        vy = math.sin(battalion.theta) * move_cmd * battalion.max_speed * speed_mod
        # Battalion.move() clamps velocity magnitude to battalion.max_speed.
        # Temporarily raise the effective cap so road bonuses > 1.0 take effect.
        original_max_speed = battalion.max_speed
        try:
            battalion.max_speed = original_max_speed * max(1.0, road_mod)
            battalion.move(vx, vy, dt=DT)
        finally:
            battalion.max_speed = original_max_speed
        battalion.x = float(np.clip(battalion.x, 0.0, self.map_width))
        battalion.y = float(np.clip(battalion.y, 0.0, self.map_height))

    # --- 2. Combat resolution (simultaneous) ---
    for cs in self._combat_states.values():
        cs.reset_step_accumulators()

    blue_agents = [a for a in current_agents if a.startswith("blue_")]
    red_agents = [a for a in current_agents if a.startswith("red_")]

    # Compute all raw damages first (simultaneous resolution)
    # raw_damage[attacker_id][target_id] = raw damage value
    raw_damage: dict[str, dict[str, float]] = {a: {} for a in current_agents}

    for attacker_id in blue_agents:
        fire_cmd = float(
            np.clip(
                actions.get(attacker_id, np.zeros(3, dtype=np.float32))[2],
                0.0, 1.0,
            )
        )
        for target_id in red_agents:
            raw = compute_fire_damage(
                self._battalions[attacker_id],
                self._battalions[target_id],
                intensity=fire_cmd,
            )
            raw = self.terrain.apply_cover_modifier(
                self._battalions[target_id].x,
                self._battalions[target_id].y,
                raw,
            )
            raw_damage[attacker_id][target_id] = raw

    for attacker_id in red_agents:
        fire_cmd = float(
            np.clip(
                actions.get(attacker_id, np.zeros(3, dtype=np.float32))[2],
                0.0, 1.0,
            )
        )
        for target_id in blue_agents:
            raw = compute_fire_damage(
                self._battalions[attacker_id],
                self._battalions[target_id],
                intensity=fire_cmd,
            )
            raw = self.terrain.apply_cover_modifier(
                self._battalions[target_id].x,
                self._battalions[target_id].y,
                raw,
            )
            raw_damage[attacker_id][target_id] = raw

    # Sum incoming damage per target and apply once (preserves simultaneity)
    damage_dealt: dict[str, float] = {a: 0.0 for a in current_agents}
    damage_received: dict[str, float] = {a: 0.0 for a in current_agents}

    for target_id in current_agents:
        total_raw: float = sum(
            raw_damage[att].get(target_id, 0.0)
            for att in current_agents
            if att != target_id
        )
        if total_raw <= 0.0:
            continue

        actual = apply_casualties(
            self._battalions[target_id],
            self._combat_states[target_id],
            total_raw,
        )
        damage_received[target_id] = actual

        # Credit each attacker proportionally
        for attacker_id in current_agents:
            attacker_raw = raw_damage[attacker_id].get(target_id, 0.0)
            if attacker_raw > 0.0:
                damage_dealt[attacker_id] += actual * (attacker_raw / total_raw)

    # --- 3. Morale checks ---
    for agent_id in current_agents:
        morale_check(self._combat_states[agent_id], rng=self._rng)
        # Sync Battalion fields from CombatState
        battalion = self._battalions[agent_id]
        cs = self._combat_states[agent_id]
        battalion.morale = cs.morale
        battalion.routed = cs.is_routing

    self._step_count += 1

    # --- 4. Termination / truncation ---
    # Individual termination: routed or effectively destroyed
    individually_done: dict[str, bool] = {
        agent_id: (
            self._combat_states[agent_id].is_routing
            or self._battalions[agent_id].strength <= DESTROYED_THRESHOLD
        )
        for agent_id in current_agents
    }

    # Team-level: surviving-side members are also terminated when the
    # opposing team is completely eliminated this step.
    blue_all_done = all(individually_done.get(a, True) for a in blue_agents)
    red_all_done = all(individually_done.get(a, True) for a in red_agents)

    terminated: dict[str, bool] = {}
    truncated: dict[str, bool] = {}
    for agent_id in current_agents:
        # Terminate if individually eliminated OR the battle is decided
        terminated[agent_id] = (
            individually_done[agent_id]
            or (blue_all_done and bool(blue_agents))
            or (red_all_done and bool(red_agents))
        )
        # Truncate remaining live agents at max_steps
        truncated[agent_id] = (
            not terminated[agent_id]
            and self._step_count >= self.max_steps
        )

    # --- 5. Rewards ---
    rewards: dict[str, float] = {}
    for agent_id in current_agents:
        is_blue = agent_id.startswith("blue_")
        r = (
            damage_dealt[agent_id] * 5.0
            - damage_received[agent_id] * 5.0
            - 0.01  # time penalty
        )
        if terminated[agent_id] and not truncated[agent_id]:
            if is_blue:
                if red_all_done and not blue_all_done:
                    r += 10.0   # blue wins
                elif blue_all_done and not red_all_done:
                    r -= 10.0   # blue loses
            else:
                if blue_all_done and not red_all_done:
                    r += 10.0   # red wins
                elif red_all_done and not blue_all_done:
                    r -= 10.0   # red loses
        rewards[agent_id] = float(r)

    # --- 6. Update _alive so _get_obs sees newly-dead agents as zero blocks ---
    self._alive -= {
        a for a in current_agents if terminated[a] or truncated[a]
    }

    # --- 7. Build observations (after updating _alive so dead-agent blocks
    #         are correctly zeroed in the observations of surviving agents) ---
    observations = {agent: self._get_obs(agent) for agent in current_agents}
    infos: dict[str, dict] = {
        agent: {
            "damage_dealt": float(damage_dealt[agent]),
            "damage_received": float(damage_received[agent]),
            "step_count": self._step_count,
        }
        for agent in current_agents
    }

    # --- 8. Update live-agents list ---
    self.agents = [
        a for a in self.agents
        if not terminated.get(a, False) and not truncated.get(a, False)
    ]

    return observations, rewards, terminated, truncated, infos
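The simultaneous combat resolution in step 2 — sum all raw damage per target, apply it once, then credit each attacker pro rata — can be sketched standalone. The `cap` below stands in for whatever `apply_casualties` does internally (its exact behaviour is an assumption); the crediting arithmetic mirrors the source:

```python
def resolve_simultaneous(raw_damage: dict[str, dict[str, float]], cap: float = 0.3):
    """raw_damage[attacker][target] -> raw value; returns (dealt, received)."""
    agents = list(raw_damage)
    dealt = {a: 0.0 for a in agents}
    received = {a: 0.0 for a in agents}
    for target in agents:
        total_raw = sum(raw_damage[att].get(target, 0.0)
                        for att in agents if att != target)
        if total_raw <= 0.0:
            continue
        actual = min(total_raw, cap)  # placeholder for apply_casualties()
        received[target] = actual
        # Credit each attacker in proportion to its share of the raw total.
        for att in agents:
            share = raw_damage[att].get(target, 0.0)
            if share > 0.0:
                dealt[att] += actual * (share / total_raw)
    return dealt, received

raw = {
    "blue_0": {"red_0": 0.3},
    "blue_1": {"red_0": 0.1},
    "red_0": {},
}
dealt, received = resolve_simultaneous(raw)
# red_0 absorbs min(0.4, 0.3) = 0.3; blue_0 is credited 75 %, blue_1 25 %.
```

Applying the summed damage once, rather than attacker-by-attacker, is what preserves simultaneity: no attacker's fire is weakened by casualties inflicted by an ally in the same step.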

Reward shaping

envs.reward.RewardWeights dataclass

Configurable multipliers for each reward component.

All weights are applied by multiplying against their respective raw component value before summing into the total reward. Set a weight to 0.0 to disable that component entirely.

Parameters:

delta_enemy_strength : float, default 5.0
    Multiplier applied to the fraction of enemy strength destroyed in a step (dmg_b2r). Encourages the agent to deal damage.

delta_own_strength : float, default 5.0
    Multiplier applied to the fraction of own strength lost in a step (dmg_r2b). The contribution is negated before summing.

survival_bonus : float, default 0.0
    Per-step bonus scaled by Blue's current strength. Set to 0.0 (default) to disable; a small positive value (e.g. 0.005) rewards staying alive longer.

win_bonus : float, default 10.0
    Terminal reward added when Blue wins (Red routed or destroyed).

loss_penalty : float, default -10.0
    Terminal reward added when Blue loses (Blue routed or destroyed). Should be negative.

time_penalty : float, default -0.01
    Constant added every step. A small negative value (e.g. -0.01) discourages unnecessary stalling.
Source code in envs/reward.py
@dataclass
class RewardWeights:
    """Configurable multipliers for each reward component.

    All weights are applied by multiplying against their respective raw
    component value before summing into the total reward.  Set a weight
    to ``0.0`` to disable that component entirely.

    Parameters
    ----------
    delta_enemy_strength:
        Multiplier applied to the fraction of enemy strength destroyed in a
        step (``dmg_b2r``).  Encourages the agent to deal damage.
    delta_own_strength:
        Multiplier applied to the fraction of own strength lost in a step
        (``dmg_r2b``).  The contribution is negated before summing.
    survival_bonus:
        Per-step bonus scaled by Blue's current strength.  Set to ``0.0``
        (default) to disable; a small positive value (e.g. ``0.005``)
        rewards staying alive longer.
    win_bonus:
        Terminal reward added when Blue wins (Red routed or destroyed).
    loss_penalty:
        Terminal reward added when Blue loses (Blue routed or destroyed).
        Should be negative.
    time_penalty:
        Constant added every step.  A small negative value (e.g. ``-0.01``)
        discourages unnecessary stalling.
    """

    delta_enemy_strength: float = 5.0
    delta_own_strength: float = 5.0
    survival_bonus: float = 0.0
    win_bonus: float = 10.0
    loss_penalty: float = -10.0
    time_penalty: float = -0.01
    enemy_routed_bonus: float = 0.0
    """Bonus added each step that the enemy is in a routing state.

    A positive value (e.g. ``2.0``) encourages pursuit of routing enemies and
    exploitation of broken units.  Set to ``0.0`` (default) to disable.
    """
    own_routing_penalty: float = 0.0
    """Penalty added each step that the agent's own unit is routing.

    A negative value (e.g. ``-2.0``) discourages allowing the agent's
    battalion to be broken.  Set to ``0.0`` (default) to disable.
    """

enemy_routed_bonus = 0.0 class-attribute instance-attribute

Bonus added each step that the enemy is in a routing state.

A positive value (e.g. 2.0) encourages pursuit of routing enemies and exploitation of broken units. Set to 0.0 (default) to disable.

own_routing_penalty = 0.0 class-attribute instance-attribute

Penalty added each step that the agent's own unit is routing.

A negative value (e.g. -2.0) discourages allowing the agent's battalion to be broken. Set to 0.0 (default) to disable.

envs.reward.RewardComponents dataclass

Per-component reward breakdown for a single environment step.

Use components.total to obtain the scalar reward to return from env.step(). The individual fields can be logged to W&B for analysis of the learning signal.

Source code in envs/reward.py
@dataclass
class RewardComponents:
    """Per-component reward breakdown for a single environment step.

    Use ``components.total`` to obtain the scalar reward to return from
    ``env.step()``.  The individual fields can be logged to W&B for
    analysis of the learning signal.
    """

    delta_enemy_strength: float = 0.0
    delta_own_strength: float = 0.0
    survival_bonus: float = 0.0
    win_bonus: float = 0.0
    loss_penalty: float = 0.0
    time_penalty: float = 0.0
    enemy_routed_bonus: float = 0.0
    own_routing_penalty: float = 0.0

    @property
    def total(self) -> float:
        """Scalar sum of all reward components."""
        return (
            self.delta_enemy_strength
            + self.delta_own_strength
            + self.survival_bonus
            + self.win_bonus
            + self.loss_penalty
            + self.time_penalty
            + self.enemy_routed_bonus
            + self.own_routing_penalty
        )

    def as_dict(self) -> dict[str, float]:
        """Return a ``{component_name: value}`` mapping suitable for logging."""
        return {
            "reward/delta_enemy_strength": self.delta_enemy_strength,
            "reward/delta_own_strength": self.delta_own_strength,
            "reward/survival_bonus": self.survival_bonus,
            "reward/win_bonus": self.win_bonus,
            "reward/loss_penalty": self.loss_penalty,
            "reward/time_penalty": self.time_penalty,
            "reward/enemy_routed_bonus": self.enemy_routed_bonus,
            "reward/own_routing_penalty": self.own_routing_penalty,
            "reward/total": self.total,
        }

total property

Scalar sum of all reward components.

as_dict()

Return a {component_name: value} mapping suitable for logging.

Source code in envs/reward.py
def as_dict(self) -> dict[str, float]:
    """Return a ``{component_name: value}`` mapping suitable for logging."""
    return {
        "reward/delta_enemy_strength": self.delta_enemy_strength,
        "reward/delta_own_strength": self.delta_own_strength,
        "reward/survival_bonus": self.survival_bonus,
        "reward/win_bonus": self.win_bonus,
        "reward/loss_penalty": self.loss_penalty,
        "reward/time_penalty": self.time_penalty,
        "reward/enemy_routed_bonus": self.enemy_routed_bonus,
        "reward/own_routing_penalty": self.own_routing_penalty,
        "reward/total": self.total,
    }

envs.reward.compute_reward(*, dmg_b2r, dmg_r2b, blue_strength, blue_won, blue_lost, weights, enemy_routed=False, own_routing=False)

Compute all reward components for a single environment step.

Parameters:

dmg_b2r : float, required
    Damage dealt by Blue to Red this step (strength-fraction units, typically in [0, 1]).

dmg_r2b : float, required
    Damage dealt by Red to Blue this step.

blue_strength : float, required
    Blue's current strength after casualties (used for the survival bonus), in [0, 1].

blue_won : bool, required
    True when the episode terminates with Red defeated.

blue_lost : bool, required
    True when the episode terminates with Blue defeated.

weights : RewardWeights, required
    RewardWeights multipliers for each component.

enemy_routed : bool, default False
    True when the enemy (Red) is currently routing this step. Triggers the enemy_routed_bonus component when non-zero in weights.

own_routing : bool, default False
    True when the agent's own unit (Blue) is currently routing. Triggers the own_routing_penalty component when non-zero in weights.

Returns:

RewardComponents
    Individual reward components; call .total for the scalar sum or .as_dict() for a loggable mapping.

Source code in envs/reward.py
def compute_reward(
    *,
    dmg_b2r: float,
    dmg_r2b: float,
    blue_strength: float,
    blue_won: bool,
    blue_lost: bool,
    weights: RewardWeights,
    enemy_routed: bool = False,
    own_routing: bool = False,
) -> RewardComponents:
    """Compute all reward components for a single environment step.

    Parameters
    ----------
    dmg_b2r:
        Damage dealt by Blue to Red this step (strength-fraction units,
        typically in ``[0, 1]``).
    dmg_r2b:
        Damage dealt by Red to Blue this step.
    blue_strength:
        Blue's current strength after casualties (used for the survival
        bonus), in ``[0, 1]``.
    blue_won:
        ``True`` when the episode terminates with Red defeated.
    blue_lost:
        ``True`` when the episode terminates with Blue defeated.
    weights:
        :class:`RewardWeights` multipliers for each component.
    enemy_routed:
        ``True`` when the enemy (Red) is currently routing this step.
        Triggers the ``enemy_routed_bonus`` component when non-zero in
        *weights*.  Defaults to ``False``.
    own_routing:
        ``True`` when the agent's own unit (Blue) is currently routing.
        Triggers the ``own_routing_penalty`` component when non-zero in
        *weights*.  Defaults to ``False``.

    Returns
    -------
    RewardComponents
        Individual reward components; call ``.total`` for the scalar sum
        or ``.as_dict()`` for a loggable mapping.
    """
    comps = RewardComponents()
    comps.delta_enemy_strength = dmg_b2r * weights.delta_enemy_strength
    comps.delta_own_strength = -dmg_r2b * weights.delta_own_strength
    comps.survival_bonus = blue_strength * weights.survival_bonus
    comps.time_penalty = weights.time_penalty
    if blue_won:
        comps.win_bonus = weights.win_bonus
    if blue_lost:
        comps.loss_penalty = weights.loss_penalty
    if enemy_routed:
        comps.enemy_routed_bonus = weights.enemy_routed_bonus
    if own_routing:
        comps.own_routing_penalty = weights.own_routing_penalty
    return comps
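To make the arithmetic concrete, here is a minimal standalone replica of the component sum using the default weights shown above. It mirrors compute_reward for illustration only and does not import the envs package:

```python
# Standalone replica of the reward arithmetic, using the default
# RewardWeights values. Illustrative only; not the envs package itself.
def reward_total(
    dmg_b2r: float,
    dmg_r2b: float,
    blue_strength: float,
    blue_won: bool = False,
    blue_lost: bool = False,
) -> float:
    r = dmg_b2r * 5.0           # delta_enemy_strength
    r += -dmg_r2b * 5.0         # delta_own_strength (negated)
    r += blue_strength * 0.0    # survival_bonus (disabled by default)
    r += -0.01                  # time_penalty
    if blue_won:
        r += 10.0               # win_bonus
    if blue_lost:
        r += -10.0              # loss_penalty
    return r

# Blue destroys 2 % of Red, loses 1 % itself, episode continues:
print(round(reward_total(0.02, 0.01, 0.9), 4))  # 0.04
```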

Configuration types

envs.battalion_env.LogisticsConfig dataclass

Configurable parameters for the supply, ammunition, and fatigue model.

Pass an instance to BattalionEnv (via the logistics_config parameter) to enable the full logistics model. All parameters have historically-informed defaults; adjust them to tune scenario difficulty.

Parameters:

initial_ammo : float, default DEFAULT_INITIAL_AMMO
    Starting ammunition level in [0, 1]. 1.0 = full load-out.

initial_food : float, default DEFAULT_INITIAL_FOOD
    Starting food/water level in [0, 1]. 1.0 = full rations.

ammo_per_volley : float, default 0.01
    Ammunition consumed per unit of fire intensity per step. A value of 0.01 means firing at full intensity for 100 steps exhausts the load-out. Scaled by intensity so partial volleys cost less.

food_per_step : float, default 0.0005
    Food/water consumed per simulation step regardless of activity. Models the steady drain of rations on campaign.

fatigue_per_move_step : float, default 0.002
    Fatigue accumulated per step when the unit is moving (> 0).

fatigue_per_fire_step : float, default 0.001
    Fatigue accumulated per step when the unit is firing (> 0).

fatigue_recovery_per_halt_step : float, default 0.003
    Fatigue recovered per step when the unit is stationary and not firing. Should exceed fatigue_per_move_step so that resting eventually restores full readiness.

resupply_radius : float, default 100.0
    Distance in metres within which a battalion can draw on a supply wagon. The unit must be within this radius and the wagon must be alive.

ammo_resupply_rate : float, default 0.02
    Ammunition recovered per step while within resupply radius.

food_resupply_rate : float, default 0.01
    Food/water recovered per step while within resupply radius.

low_ammo_accuracy_penalty : float, default 0.5
    Multiplicative fire-intensity modifier applied when ammo is below CRITICAL_AMMO_THRESHOLD. A value of 0.5 halves the effective volley when critically low.

fatigue_speed_penalty : float, default 0.3
    Maximum fractional speed reduction at full fatigue (1.0). At fatigue f, speed is multiplied by 1 − fatigue_speed_penalty × max(0, (f − onset) / (1 − onset)).

fatigue_accuracy_penalty : float, default 0.2
    Maximum fractional fire-intensity reduction at full fatigue. Applies the same capped formula as fatigue_speed_penalty.

enable_resupply : bool, default True
    When False the check_resupply function is a no-op. Useful for short training episodes where supply logistics would end episodes too quickly.

wagon_speed : float, default 10.0
    Maximum movement speed of supply wagons (metres per step at dt=1). Wagons are slow: ~10 m/step vs 50 m/step for infantry.

wagon_max_strength : float, default 1.0
    Initial strength of a new supply wagon in [0, 1].
Source code in envs/sim/logistics.py
@dataclass
class LogisticsConfig:
    """Configurable parameters for the supply, ammunition, and fatigue model.

    Pass an instance to :class:`~envs.battalion_env.BattalionEnv` (via the
    *logistics_config* parameter) to enable the full logistics model.  All
    parameters have historically-informed defaults; adjust them to tune
    scenario difficulty.

    Parameters
    ----------
    initial_ammo:
        Starting ammunition level in ``[0, 1]``.  ``1.0`` = full load-out.
    initial_food:
        Starting food/water level in ``[0, 1]``.  ``1.0`` = full rations.
    ammo_per_volley:
        Ammunition consumed per unit of fire intensity per step.  A value
        of ``0.01`` means firing at full intensity for 100 steps exhausts
        the load-out.  Scaled by intensity so partial volleys cost less.
    food_per_step:
        Food/water consumed per simulation step regardless of activity.
        Models the steady drain of rations on campaign.
    fatigue_per_move_step:
        Fatigue accumulated per step when the unit is moving (``> 0``).
    fatigue_per_fire_step:
        Fatigue accumulated per step when the unit is firing (``> 0``).
    fatigue_recovery_per_halt_step:
        Fatigue recovered per step when the unit is stationary and not
        firing.  Should exceed ``fatigue_per_move_step`` so that resting
        eventually restores full readiness.
    resupply_radius:
        Distance in metres within which a battalion can draw on a supply
        wagon.  The unit must be within this radius *and* the wagon must be
        alive.
    ammo_resupply_rate:
        Ammunition recovered per step while within resupply radius.
    food_resupply_rate:
        Food/water recovered per step while within resupply radius.
    low_ammo_accuracy_penalty:
        Multiplicative fire-intensity modifier applied when ammo is below
        :data:`CRITICAL_AMMO_THRESHOLD`.  A value of ``0.5`` halves the
        effective volley when critically low.
    fatigue_speed_penalty:
        Maximum fractional speed reduction at full fatigue (``1.0``).
        At fatigue ``f``, speed is multiplied by
        ``1 − fatigue_speed_penalty × max(0, (f − onset) / (1 − onset))``.
    fatigue_accuracy_penalty:
        Maximum fractional fire-intensity reduction at full fatigue.
        Applies the same capped formula as *fatigue_speed_penalty*.
    enable_resupply:
        When ``False`` the :func:`check_resupply` function is a no-op.
        Useful for short training episodes where supply logistics would
        end episodes too quickly.
    wagon_speed:
        Maximum movement speed of supply wagons (metres per step at dt=1).
        Wagons are slow: ~10 m/step vs 50 m/step for infantry.
    wagon_max_strength:
        Initial strength of a new supply wagon in ``[0, 1]``.
    """

    # Initial levels
    initial_ammo: float = DEFAULT_INITIAL_AMMO
    initial_food: float = DEFAULT_INITIAL_FOOD

    # Consumption rates
    ammo_per_volley: float = 0.01
    food_per_step: float = 0.0005

    # Fatigue rates
    fatigue_per_move_step: float = 0.002
    fatigue_per_fire_step: float = 0.001
    fatigue_recovery_per_halt_step: float = 0.003

    # Resupply
    resupply_radius: float = 100.0
    ammo_resupply_rate: float = 0.02
    food_resupply_rate: float = 0.01
    enable_resupply: bool = True

    # Performance modifiers
    low_ammo_accuracy_penalty: float = 0.5
    fatigue_speed_penalty: float = 0.3
    fatigue_accuracy_penalty: float = 0.2

    # Supply wagon properties
    wagon_speed: float = 10.0
    wagon_max_strength: float = 1.0

    def __post_init__(self) -> None:
        if not (0.0 <= self.initial_ammo <= 1.0):
            raise ValueError(
                f"initial_ammo must be in [0, 1], got {self.initial_ammo}"
            )
        if not (0.0 <= self.initial_food <= 1.0):
            raise ValueError(
                f"initial_food must be in [0, 1], got {self.initial_food}"
            )
        if self.ammo_per_volley < 0.0:
            raise ValueError(
                f"ammo_per_volley must be >= 0, got {self.ammo_per_volley}"
            )
        if self.food_per_step < 0.0:
            raise ValueError(
                f"food_per_step must be >= 0, got {self.food_per_step}"
            )
        if self.fatigue_per_move_step < 0.0:
            raise ValueError(
                f"fatigue_per_move_step must be >= 0, got {self.fatigue_per_move_step}"
            )
        if self.fatigue_per_fire_step < 0.0:
            raise ValueError(
                f"fatigue_per_fire_step must be >= 0, got {self.fatigue_per_fire_step}"
            )
        if self.fatigue_recovery_per_halt_step < 0.0:
            raise ValueError(
                f"fatigue_recovery_per_halt_step must be >= 0, "
                f"got {self.fatigue_recovery_per_halt_step}"
            )
        if self.resupply_radius <= 0.0:
            raise ValueError(
                f"resupply_radius must be positive, got {self.resupply_radius}"
            )
        if not (0.0 <= self.ammo_resupply_rate <= 1.0):
            raise ValueError(
                f"ammo_resupply_rate must be in [0, 1], got {self.ammo_resupply_rate}"
            )
        if not (0.0 <= self.food_resupply_rate <= 1.0):
            raise ValueError(
                f"food_resupply_rate must be in [0, 1], got {self.food_resupply_rate}"
            )
        if not (0.0 <= self.low_ammo_accuracy_penalty <= 1.0):
            raise ValueError(
                f"low_ammo_accuracy_penalty must be in [0, 1], "
                f"got {self.low_ammo_accuracy_penalty}"
            )
        if not (0.0 <= self.fatigue_speed_penalty <= 1.0):
            raise ValueError(
                f"fatigue_speed_penalty must be in [0, 1], "
                f"got {self.fatigue_speed_penalty}"
            )
        if not (0.0 <= self.fatigue_accuracy_penalty <= 1.0):
            raise ValueError(
                f"fatigue_accuracy_penalty must be in [0, 1], "
                f"got {self.fatigue_accuracy_penalty}"
            )
        if self.wagon_speed <= 0.0:
            raise ValueError(
                f"wagon_speed must be positive, got {self.wagon_speed}"
            )
        if not (0.0 < self.wagon_max_strength <= 1.0):
            raise ValueError(
                f"wagon_max_strength must be in (0, 1], got {self.wagon_max_strength}"
            )
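The fatigue speed formula can be sketched directly. The onset value below is a hypothetical placeholder; the real onset constant is defined in envs/sim/logistics.py and is not shown in this excerpt:

```python
def fatigue_speed_multiplier(fatigue: float, penalty: float = 0.3,
                             onset: float = 0.5) -> float:
    # Speed is multiplied by 1 - penalty * max(0, (f - onset) / (1 - onset)).
    # `onset` here is a hypothetical placeholder value.
    return 1.0 - penalty * max(0.0, (fatigue - onset) / (1.0 - onset))

print(fatigue_speed_multiplier(0.0))  # 1.0  (fresh unit, no penalty)
print(fatigue_speed_multiplier(1.0))  # 0.7  (full fatigue, 30 % slower)
```

Below the onset the multiplier stays at 1.0; above it the penalty ramps linearly to the configured maximum.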

envs.battalion_env.MoraleConfig dataclass

Configurable parameters for the morale state machine.

Pass an instance to update_morale, SimEngine, or BattalionEnv to enable the full morale stressor model. All parameters have sensible defaults; tweak them to adjust scenario difficulty or historical calibration.

Parameters:

cohesion_threshold : float, default DEFAULT_COHESION_THRESHOLD
    Morale below this value causes cohesion loss. The cohesion_modifier function returns a value below 1.0 when morale is beneath this threshold.

rout_threshold : float, default DEFAULT_ROUT_THRESHOLD
    Morale below this value triggers the probabilistic routing check. Must be < cohesion_threshold.

base_recovery_rate : float, default 0.01
    Passive morale recovery per step when the unit is not under fire. Increase to make recovery faster (easier scenario).

distance_recovery_bonus : float, default 0.01
    Additional recovery per step, scaled by enemy_dist / safe_distance (capped at 1.0). Rewards moving away from the enemy.

friendly_support_bonus : float, default 0.015
    Per-step morale bonus when a friendly unit is within commander_range metres. Rewards cohesive formations.

commander_proximity_bonus : float, default 0.02
    Per-step morale bonus when the brigade commander is within commander_range metres. Represents the Napoleonic leadership effect.

commander_range : float, default 300.0
    Radius in metres within which the commander or a friendly unit must be to grant a proximity bonus.

rally_threshold_multiplier : float, default 2.0
    A routing unit can only attempt to rally once its morale has recovered to at least rout_threshold × rally_threshold_multiplier.

rally_probability : float, default 0.05
    Probability per step that a unit attempts to rally (given it has reached the rally morale gate).

rout_speed_multiplier : float, default 1.5
    Routing units move at max_speed × rout_speed_multiplier.

safe_distance : float, default 400.0
    Reference distance (metres) used to normalise the distance recovery bonus. Units at or beyond this distance from the enemy receive the full distance_recovery_bonus.
Source code in envs/sim/morale.py
@dataclass
class MoraleConfig:
    """Configurable parameters for the morale state machine.

    Pass an instance to :func:`update_morale`, :class:`~envs.sim.engine.SimEngine`,
    or :class:`~envs.battalion_env.BattalionEnv` to enable the full morale
    stressor model.  All parameters have sensible defaults; tweak them to
    adjust scenario difficulty or historical calibration.

    Parameters
    ----------
    cohesion_threshold:
        Morale below this value causes cohesion loss.  The
        :func:`cohesion_modifier` function returns a value below ``1.0``
        when morale is beneath this threshold.
    rout_threshold:
        Morale below this value triggers the probabilistic routing check.
        Must be ``< cohesion_threshold``.
    base_recovery_rate:
        Passive morale recovery per step when the unit is not under fire.
        Increase to make recovery faster (easier scenario).
    distance_recovery_bonus:
        Additional recovery per step, scaled by ``enemy_dist / safe_distance``
        (capped at ``1.0``).  Rewards moving away from the enemy.
    friendly_support_bonus:
        Per-step morale bonus when a friendly unit is within
        ``commander_range`` metres.  Rewards cohesive formations.
    commander_proximity_bonus:
        Per-step morale bonus when the brigade commander is within
        ``commander_range`` metres.  Represents Napoleonic leadership effect.
    commander_range:
        Radius in metres within which the commander or a friendly unit must
        be to grant a proximity bonus.
    rally_threshold_multiplier:
        A routing unit can only attempt to rally once its morale has
        recovered to at least ``rout_threshold × rally_threshold_multiplier``.
    rally_probability:
        Probability per step that a unit attempts to rally (given it has
        reached the rally morale gate).
    rout_speed_multiplier:
        Routing units move at ``max_speed × rout_speed_multiplier``.
    safe_distance:
        Reference distance (metres) used to normalise the distance recovery
        bonus.  Units at or beyond this distance from the enemy receive the
        full ``distance_recovery_bonus``.
    """

    cohesion_threshold: float = DEFAULT_COHESION_THRESHOLD
    rout_threshold: float = DEFAULT_ROUT_THRESHOLD

    # Recovery rates
    base_recovery_rate: float = 0.01
    distance_recovery_bonus: float = 0.01
    friendly_support_bonus: float = 0.015
    commander_proximity_bonus: float = 0.02
    commander_range: float = 300.0

    # Rally mechanics
    rally_threshold_multiplier: float = 2.0
    rally_probability: float = 0.05

    # Rout movement
    rout_speed_multiplier: float = 1.5
    safe_distance: float = 400.0

    def __post_init__(self) -> None:
        if not (0.0 < self.rout_threshold < self.cohesion_threshold <= 1.0):
            raise ValueError(
                f"Thresholds must satisfy 0 < rout_threshold ({self.rout_threshold}) "
                f"< cohesion_threshold ({self.cohesion_threshold}) <= 1.0"
            )
        if self.base_recovery_rate < 0:
            raise ValueError(f"base_recovery_rate must be >= 0, got {self.base_recovery_rate}")
        if self.distance_recovery_bonus < 0:
            raise ValueError(
                f"distance_recovery_bonus must be >= 0, got {self.distance_recovery_bonus}"
            )
        if self.friendly_support_bonus < 0:
            raise ValueError(
                f"friendly_support_bonus must be >= 0, got {self.friendly_support_bonus}"
            )
        if self.commander_proximity_bonus < 0:
            raise ValueError(
                f"commander_proximity_bonus must be >= 0, got {self.commander_proximity_bonus}"
            )
        if self.rally_threshold_multiplier < 0:
            raise ValueError(
                f"rally_threshold_multiplier must be >= 0, got {self.rally_threshold_multiplier}"
            )
        if not (0.0 <= self.rally_probability <= 1.0):
            raise ValueError(
                f"rally_probability must be in [0, 1], got {self.rally_probability}"
            )
        if self.rout_speed_multiplier <= 0:
            raise ValueError(
                f"rout_speed_multiplier must be > 0, got {self.rout_speed_multiplier}"
            )
        if self.safe_distance <= 0:
            raise ValueError(f"safe_distance must be > 0, got {self.safe_distance}")
        if self.commander_range <= 0:
            raise ValueError(f"commander_range must be > 0, got {self.commander_range}")
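A minimal sketch of how the rally gate and attempt probability combine under the defaults above. The rout threshold used here is a hypothetical stand-in for DEFAULT_ROUT_THRESHOLD, which is defined in envs/sim/morale.py:

```python
import random

ROUT_THRESHOLD = 0.2        # hypothetical stand-in for DEFAULT_ROUT_THRESHOLD
RALLY_MULTIPLIER = 2.0      # rally_threshold_multiplier default
RALLY_PROBABILITY = 0.05    # rally_probability default

# A routing unit only starts rolling rally checks once morale recovers to
# the gate; each step after that it rallies with probability 0.05.
rally_gate = ROUT_THRESHOLD * RALLY_MULTIPLIER   # 0.4 morale required
mean_wait = 1.0 / RALLY_PROBABILITY              # ~20 steps between rallies

def try_rally(morale: float, rng: random.Random) -> bool:
    return morale >= rally_gate and rng.random() < RALLY_PROBABILITY

rng = random.Random(0)
# Below the gate a unit can never rally, regardless of the dice:
assert not any(try_rally(0.3, rng) for _ in range(1000))
```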

envs.battalion_env.WeatherConfig dataclass

All tunable parameters for the weather and time-of-day system.

Pass an instance to BattalionEnv via the weather_config argument to enable the full weather model.

Parameters:

fixed_condition : Optional[WeatherCondition], default None
    Force a specific WeatherCondition for every episode. None (the default) selects a condition randomly using condition_weights at each reset().

fixed_time_of_day : Optional[TimeOfDay], default None
    Force a specific TimeOfDay for every episode. None (the default) draws a random time of day at each reset().

steps_per_time_of_day : int, default 0
    Number of simulation steps between time-of-day transitions. 0 (the default) disables progression — the time of day stays fixed for the entire episode. A positive value causes DAWN → DAY → DUSK → NIGHT → DAWN cycling.

condition_weights : List[float], default [0.40, 0.25, 0.15, 0.10, 0.10]
    Sampling weights for WeatherCondition values [CLEAR, OVERCAST, RAIN, FOG, SNOW]. Values need not sum to 1; they are normalised internally. Ignored when fixed_condition is set.

base_visibility_range : float, default DEFAULT_BASE_VISIBILITY_RANGE
    Base visibility range in metres before weather modification (> 0). The combined visibility_fraction is multiplied by this value to produce the effective sight radius in the simulation.
Source code in envs/sim/weather.py
@dataclass
class WeatherConfig:
    """All tunable parameters for the weather and time-of-day system.

    Pass an instance to :class:`~envs.battalion_env.BattalionEnv` via the
    *weather_config* argument to enable the full weather model.

    Parameters
    ----------
    fixed_condition:
        Force a specific :class:`WeatherCondition` for every episode.
        ``None`` (the default) selects a condition randomly using
        *condition_weights* at each ``reset()``.
    fixed_time_of_day:
        Force a specific :class:`TimeOfDay` for every episode.
        ``None`` (the default) draws a random time of day at each ``reset()``.
    steps_per_time_of_day:
        Number of simulation steps between time-of-day transitions.
        ``0`` (the default) disables progression — the time of day stays
        fixed for the entire episode.  A positive value causes DAWN → DAY →
        DUSK → NIGHT → DAWN cycling.
    condition_weights:
        Sampling weights for :class:`WeatherCondition` values
        ``[CLEAR, OVERCAST, RAIN, FOG, SNOW]``.  Values need not sum to 1;
        they are normalised internally.  Ignored when *fixed_condition* is set.
    base_visibility_range:
        Base visibility range in metres before weather modification (> 0).
        The combined ``visibility_fraction`` is multiplied by this value to
        produce the effective sight radius in the simulation.
    """

    fixed_condition: Optional[WeatherCondition] = None
    fixed_time_of_day: Optional[TimeOfDay] = None
    steps_per_time_of_day: int = 0
    condition_weights: List[float] = field(
        default_factory=lambda: [0.40, 0.25, 0.15, 0.10, 0.10]
    )
    base_visibility_range: float = DEFAULT_BASE_VISIBILITY_RANGE

    def __post_init__(self) -> None:
        if self.fixed_condition is not None and not isinstance(
            self.fixed_condition, WeatherCondition
        ):
            raise ValueError(
                f"fixed_condition must be a WeatherCondition instance, "
                f"got {type(self.fixed_condition)!r}"
            )
        if self.fixed_time_of_day is not None and not isinstance(
            self.fixed_time_of_day, TimeOfDay
        ):
            raise ValueError(
                f"fixed_time_of_day must be a TimeOfDay instance, "
                f"got {type(self.fixed_time_of_day)!r}"
            )
        if len(self.condition_weights) != NUM_CONDITIONS:
            raise ValueError(
                f"condition_weights must have {NUM_CONDITIONS} entries "
                f"(one per WeatherCondition), got {len(self.condition_weights)}"
            )
        if any(w < 0.0 for w in self.condition_weights):
            raise ValueError(
                "All condition_weights must be non-negative, "
                f"got {self.condition_weights}"
            )
        if sum(self.condition_weights) <= 0.0:
            raise ValueError(
                "condition_weights must have a positive sum, "
                f"got {self.condition_weights}"
            )
        if self.steps_per_time_of_day < 0:
            raise ValueError(
                f"steps_per_time_of_day must be >= 0, got {self.steps_per_time_of_day}"
            )
        if self.base_visibility_range <= 0.0:
            raise ValueError(
                f"base_visibility_range must be > 0, got {self.base_visibility_range}"
            )
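Because the weights are normalised internally, they can be sampled as-is; Python's stdlib random.choices behaves the same way. A standalone sketch in which plain strings stand in for the WeatherCondition enum:

```python
import random
from collections import Counter

# Default condition_weights, ordered [CLEAR, OVERCAST, RAIN, FOG, SNOW]
conditions = ["CLEAR", "OVERCAST", "RAIN", "FOG", "SNOW"]
weights = [0.40, 0.25, 0.15, 0.10, 0.10]

rng = random.Random(42)
# random.choices accepts unnormalised weights, like condition_weights does.
counts = Counter(rng.choices(conditions, weights=weights, k=10_000))
# CLEAR should dominate at roughly 40 % of draws.
```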

envs.battalion_env.Formation

Bases: IntEnum

Discrete formation states for a Napoleonic infantry battalion.

The integer values are stable indices used in observation vectors and action spaces; do not change them without updating downstream code.

==========  ===  ====================================================
Formation   Int  Historical role
==========  ===  ====================================================
LINE        0    Two- or three-rank line — maximum firepower front.
COLUMN      1    Attack or march column — highest movement speed.
SQUARE      2    Hollow square — impenetrable to unsupported cavalry.
SKIRMISH    3    Extended order — loose screen, independent aimed fire.
==========  ===  ====================================================

Source code in envs/sim/formations.py
class Formation(IntEnum):
    """Discrete formation states for a Napoleonic infantry battalion.

    The integer values are stable indices used in observation vectors and
    action spaces; do **not** change them without updating downstream code.

    ==========  ===  ====================================================
    Formation   Int  Historical role
    ==========  ===  ====================================================
    LINE        0    Two- or three-rank line — maximum firepower front.
    COLUMN      1    Attack or march column — highest movement speed.
    SQUARE      2    Hollow square — impenetrable to unsupported cavalry.
    SKIRMISH    3    Extended order — loose screen, independent aimed fire.
    ==========  ===  ====================================================
    """

    LINE = 0
    COLUMN = 1
    SQUARE = 2
    SKIRMISH = 3
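Since the integer values are stable observation indices, a formation can be embedded as a one-hot block. A minimal sketch, with the enum redeclared so the snippet is self-contained; the one-hot layout itself is illustrative, not necessarily the environment's exact encoding:

```python
from enum import IntEnum

class Formation(IntEnum):  # redeclared from envs/sim/formations.py
    LINE = 0
    COLUMN = 1
    SQUARE = 2
    SKIRMISH = 3

def formation_one_hot(f: Formation) -> list[float]:
    """One-hot block keyed by the enum's stable integer index."""
    vec = [0.0] * len(Formation)
    vec[int(f)] = 1.0
    return vec

print(formation_one_hot(Formation.SQUARE))  # [0.0, 0.0, 1.0, 0.0]
```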

Simulation primitives

envs.sim.engine.SimEngine

Run a 1v1 battalion episode to completion.

Each call to step():

  1. Resets per-step damage accumulators on both CombatState objects.
  2. Computes fire damage from both sides simultaneously (before applying either, so there are no ordering effects on damage calculation).
  3. Applies terrain cover to reduce incoming damage at each unit's position.
  4. Applies casualties and updates strength.
  5. Runs a morale check for each unit, potentially triggering routing.

The episode ends (is_over() returns True) when:

  • Either side's CombatState reports is_routing = True, or
  • Either side's strength falls to DESTROYED_THRESHOLD or below, or
  • step_count reaches max_steps.
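
The order-independence of step 2 can be illustrated in isolation. This is a simplified standalone model; the toy damage formula and the numbers are illustrative only, not the real compute_fire_damage logic:

```python
# Toy units: strength in [0, 1]; damage scales with the *shooter's*
# strength as captured before either side takes casualties.
blue_strength = 0.8
red_strength = 0.6

def toy_fire_damage(shooter_strength: float, rate: float = 0.1) -> float:
    # Stand-in for compute_fire_damage(): proportional to current strength.
    return shooter_strength * rate

# Compute both raw damages first (simultaneous resolution) ...
raw_blue_to_red = toy_fire_damage(blue_strength)
raw_red_to_blue = toy_fire_damage(red_strength)

# ... then apply casualties, so neither side fires at reduced strength.
red_strength = max(0.0, red_strength - raw_blue_to_red)
blue_strength = max(0.0, blue_strength - raw_red_to_blue)
```

Had red's casualties been applied before computing its return fire, red would have dealt less damage purely because of evaluation order.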

Parameters:

Name Type Description Default
blue Battalion

The two opposing battalions. Modified in-place by each step.

required
red Battalion

The two opposing battalions. Modified in-place by each step.

required
terrain Optional[TerrainMap]

Optional envs.sim.terrain.TerrainMap. Defaults to a flat 1 km × 1 km open plain.

None
max_steps int

Hard cap on episode length (default 500, matching acceptance criterion AC-1 of Epic E1.2).

500
rng Optional[Generator]

Seeded random generator. Defaults to a fresh unseeded generator. Pass a seeded generator for reproducible results.

None
morale_config Optional[MoraleConfig]

Optional morale configuration. When provided, the enhanced update_morale model is used (flank stressors, distance-based recovery, forced rout movement); when None, the basic morale_check applies.

None
Source code in envs/sim/engine.py
class SimEngine:
    """Run a 1v1 battalion episode to completion.

    Each :meth:`step`:

    1. Resets per-step damage accumulators on both :class:`CombatState` objects.
    2. Computes fire damage from both sides simultaneously (before applying
       either, so there are no ordering effects on damage calculation).
    3. Applies terrain cover to reduce incoming damage at each unit's position.
    4. Applies casualties and updates strength.
    5. Runs a morale check for each unit, potentially triggering routing.

    The episode ends (:meth:`is_over` returns ``True``) when:

    * Either side's :class:`CombatState` reports ``is_routing = True``, or
    * Either side's strength falls to :data:`DESTROYED_THRESHOLD` or below, or
    * ``step_count`` reaches ``max_steps``.

    Parameters
    ----------
    blue, red:
        The two opposing battalions.  Modified in-place by each step.
    terrain:
        Optional :class:`~envs.sim.terrain.TerrainMap`.  Defaults to a
        flat 1 km × 1 km open plain.
    max_steps:
        Hard cap on episode length (default 500, matching acceptance
        criterion AC-1 of Epic E1.2).
    rng:
        Seeded random generator.  Defaults to a fresh unseeded generator.
        Pass a seeded generator for reproducible results.
    """

    def __init__(
        self,
        blue: Battalion,
        red: Battalion,
        terrain: Optional[TerrainMap] = None,
        max_steps: int = 500,
        rng: Optional[np.random.Generator] = None,
        morale_config: Optional[MoraleConfig] = None,
    ) -> None:
        self.blue = blue
        self.red = red
        self.terrain: TerrainMap = terrain if terrain is not None else TerrainMap.flat(1000.0, 1000.0)
        self.max_steps = max_steps
        self.rng: np.random.Generator = rng if rng is not None else np.random.default_rng()
        self.morale_config: Optional[MoraleConfig] = morale_config

        self.blue_state = CombatState()
        self.red_state = CombatState()
        self.step_count: int = 0

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _is_done(self, unit: Battalion, state: CombatState) -> bool:
        """Return ``True`` if *unit* is routed or effectively destroyed."""
        return state.is_routing or unit.strength <= DESTROYED_THRESHOLD

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def is_over(self) -> bool:
        """Return ``True`` when the episode should end."""
        return (
            self._is_done(self.blue, self.blue_state)
            or self._is_done(self.red, self.red_state)
            or self.step_count >= self.max_steps
        )

    def step(self) -> dict:
        """Advance one simulation step.

        Returns
        -------
        dict with keys:
            ``blue_damage_dealt`` – actual damage blue dealt to red this step.
            ``red_damage_dealt``  – actual damage red dealt to blue this step.
            ``blue_routing``      – whether blue is routing after this step.
            ``red_routing``       – whether red is routing after this step.
        """
        # 1. Reset per-step accumulators
        self.blue_state.reset_step_accumulators()
        self.red_state.reset_step_accumulators()

        # 2. Compute raw damages simultaneously (uses each shooter's current
        #    strength so neither side benefits from a favourable firing order)
        raw_blue_to_red = compute_fire_damage(self.blue, self.red, intensity=1.0)
        raw_red_to_blue = compute_fire_damage(self.red, self.blue, intensity=1.0)

        # 3. Apply terrain cover at each *target's* position
        raw_blue_to_red = self.terrain.apply_cover_modifier(
            self.red.x, self.red.y, raw_blue_to_red
        )
        raw_red_to_blue = self.terrain.apply_cover_modifier(
            self.blue.x, self.blue.y, raw_red_to_blue
        )

        # 4. Apply casualties simultaneously
        actual_blue_to_red = apply_casualties(self.red, self.red_state, raw_blue_to_red)
        actual_red_to_blue = apply_casualties(self.blue, self.blue_state, raw_red_to_blue)

        # Track shots fired only when a volley actually occurs
        if raw_blue_to_red > 0.0:
            self.blue_state.shots_fired += 1
        if raw_red_to_blue > 0.0:
            self.red_state.shots_fired += 1

        # Compute enemy distance for distance-based morale recovery
        dx = self.blue.x - self.red.x
        dy = self.blue.y - self.red.y
        enemy_dist = float(np.sqrt(dx * dx + dy * dy))

        # 5. Morale checks — use enhanced update_morale when a MoraleConfig is
        #    provided, otherwise fall back to the basic morale_check.
        if self.morale_config is not None:
            mc = self.morale_config
            blue_flank = compute_flank_stressor(
                self.red.x, self.red.y,
                self.blue.x, self.blue.y, self.blue.theta,
                actual_red_to_blue,
            )
            red_flank = compute_flank_stressor(
                self.blue.x, self.blue.y,
                self.red.x, self.red.y, self.red.theta,
                actual_blue_to_red,
            )
            blue_routing = update_morale(
                self.blue_state,
                enemy_dist=enemy_dist,
                config=mc,
                flank_penalty=blue_flank,
                rng=self.rng,
            )
            red_routing = update_morale(
                self.red_state,
                enemy_dist=enemy_dist,
                config=mc,
                flank_penalty=red_flank,
                rng=self.rng,
            )
        else:
            blue_routing = morale_check(self.blue_state, rng=self.rng)
            red_routing = morale_check(self.red_state, rng=self.rng)

        # 6. Apply forced rout movement (overrides normal movement)
        if self.morale_config is not None:
            if blue_routing:
                vx, vy = rout_velocity(
                    self.blue.x, self.blue.y,
                    self.red.x, self.red.y,
                    self.blue.max_speed,
                    self.morale_config,
                )
                self.blue.move(vx, vy, dt=DT)
            if red_routing:
                vx, vy = rout_velocity(
                    self.red.x, self.red.y,
                    self.blue.x, self.blue.y,
                    self.red.max_speed,
                    self.morale_config,
                )
                self.red.move(vx, vy, dt=DT)

        # Keep Battalion.morale and Battalion.routed in sync with CombatState
        self.blue.morale = self.blue_state.morale
        self.red.morale = self.red_state.morale
        self.blue.routed = self.blue_state.is_routing
        self.red.routed = self.red_state.is_routing

        self.step_count += 1

        return {
            "blue_damage_dealt": actual_blue_to_red,
            "red_damage_dealt": actual_red_to_blue,
            "blue_routing": blue_routing,
            "red_routing": red_routing,
        }

    def run(self) -> EpisodeResult:
        """Run the episode to completion and return a result summary."""
        while not self.is_over():
            self.step()
        return self._make_result()

    def _make_result(self) -> EpisodeResult:
        blue_done = self._is_done(self.blue, self.blue_state)
        red_done = self._is_done(self.red, self.red_state)

        if blue_done and not red_done:
            winner: int | None = 1  # red wins
        elif red_done and not blue_done:
            winner = 0  # blue wins
        else:
            winner = None  # draw: simultaneous rout/destruction or time-out

        return EpisodeResult(
            winner=winner,
            steps=self.step_count,
            blue_strength=self.blue.strength,
            red_strength=self.red.strength,
            blue_morale=self.blue_state.morale,
            red_morale=self.red_state.morale,
            blue_routed=self.blue_state.is_routing,
            red_routed=self.red_state.is_routing,
        )

is_over()

Return True when the episode should end.

Source code in envs/sim/engine.py
def is_over(self) -> bool:
    """Return ``True`` when the episode should end."""
    return (
        self._is_done(self.blue, self.blue_state)
        or self._is_done(self.red, self.red_state)
        or self.step_count >= self.max_steps
    )

run()

Run the episode to completion and return a result summary.

Source code in envs/sim/engine.py
def run(self) -> EpisodeResult:
    """Run the episode to completion and return a result summary."""
    while not self.is_over():
        self.step()
    return self._make_result()

step()

Advance one simulation step.

Returns:

Type Description
dict with keys:

blue_damage_dealt – actual damage blue dealt to red this step.
red_damage_dealt – actual damage red dealt to blue this step.
blue_routing – whether blue is routing after this step.
red_routing – whether red is routing after this step.

Source code in envs/sim/engine.py
def step(self) -> dict:
    """Advance one simulation step.

    Returns
    -------
    dict with keys:
        ``blue_damage_dealt`` – actual damage blue dealt to red this step.
        ``red_damage_dealt``  – actual damage red dealt to blue this step.
        ``blue_routing``      – whether blue is routing after this step.
        ``red_routing``       – whether red is routing after this step.
    """
    # 1. Reset per-step accumulators
    self.blue_state.reset_step_accumulators()
    self.red_state.reset_step_accumulators()

    # 2. Compute raw damages simultaneously (uses each shooter's current
    #    strength so neither side benefits from a favourable firing order)
    raw_blue_to_red = compute_fire_damage(self.blue, self.red, intensity=1.0)
    raw_red_to_blue = compute_fire_damage(self.red, self.blue, intensity=1.0)

    # 3. Apply terrain cover at each *target's* position
    raw_blue_to_red = self.terrain.apply_cover_modifier(
        self.red.x, self.red.y, raw_blue_to_red
    )
    raw_red_to_blue = self.terrain.apply_cover_modifier(
        self.blue.x, self.blue.y, raw_red_to_blue
    )

    # 4. Apply casualties simultaneously
    actual_blue_to_red = apply_casualties(self.red, self.red_state, raw_blue_to_red)
    actual_red_to_blue = apply_casualties(self.blue, self.blue_state, raw_red_to_blue)

    # Track shots fired only when a volley actually occurs
    if raw_blue_to_red > 0.0:
        self.blue_state.shots_fired += 1
    if raw_red_to_blue > 0.0:
        self.red_state.shots_fired += 1

    # Compute enemy distance for distance-based morale recovery
    dx = self.blue.x - self.red.x
    dy = self.blue.y - self.red.y
    enemy_dist = float(np.sqrt(dx * dx + dy * dy))

    # 5. Morale checks — use enhanced update_morale when a MoraleConfig is
    #    provided, otherwise fall back to the basic morale_check.
    if self.morale_config is not None:
        mc = self.morale_config
        blue_flank = compute_flank_stressor(
            self.red.x, self.red.y,
            self.blue.x, self.blue.y, self.blue.theta,
            actual_red_to_blue,
        )
        red_flank = compute_flank_stressor(
            self.blue.x, self.blue.y,
            self.red.x, self.red.y, self.red.theta,
            actual_blue_to_red,
        )
        blue_routing = update_morale(
            self.blue_state,
            enemy_dist=enemy_dist,
            config=mc,
            flank_penalty=blue_flank,
            rng=self.rng,
        )
        red_routing = update_morale(
            self.red_state,
            enemy_dist=enemy_dist,
            config=mc,
            flank_penalty=red_flank,
            rng=self.rng,
        )
    else:
        blue_routing = morale_check(self.blue_state, rng=self.rng)
        red_routing = morale_check(self.red_state, rng=self.rng)

    # 6. Apply forced rout movement (overrides normal movement)
    if self.morale_config is not None:
        if blue_routing:
            vx, vy = rout_velocity(
                self.blue.x, self.blue.y,
                self.red.x, self.red.y,
                self.blue.max_speed,
                self.morale_config,
            )
            self.blue.move(vx, vy, dt=DT)
        if red_routing:
            vx, vy = rout_velocity(
                self.red.x, self.red.y,
                self.blue.x, self.blue.y,
                self.red.max_speed,
                self.morale_config,
            )
            self.red.move(vx, vy, dt=DT)

    # Keep Battalion.morale and Battalion.routed in sync with CombatState
    self.blue.morale = self.blue_state.morale
    self.red.morale = self.red_state.morale
    self.blue.routed = self.blue_state.is_routing
    self.red.routed = self.red_state.is_routing

    self.step_count += 1

    return {
        "blue_damage_dealt": actual_blue_to_red,
        "red_damage_dealt": actual_red_to_blue,
        "blue_routing": blue_routing,
        "red_routing": red_routing,
    }

envs.sim.engine.EpisodeResult dataclass

Summary returned by SimEngine.run().

Attributes:

Name Type Description
winner int | None

0 if blue wins, 1 if red wins, None for a draw (both sides ended simultaneously, or max steps was reached).

steps int

Number of simulation steps taken.

blue_strength, red_strength

Final strength values in [0, 1].

blue_morale, red_morale

Final morale values in [0, 1].

blue_routed, red_routed

Whether each side is routing at episode end.

Source code in envs/sim/engine.py
@dataclass
class EpisodeResult:
    """Summary returned by :meth:`SimEngine.run`.

    Attributes
    ----------
    winner:
        ``0`` if blue wins, ``1`` if red wins, ``None`` for a draw (both
        sides ended simultaneously, or max steps was reached).
    steps:
        Number of simulation steps taken.
    blue_strength, red_strength:
        Final strength values in ``[0, 1]``.
    blue_morale, red_morale:
        Final morale values in ``[0, 1]``.
    blue_routed, red_routed:
        Whether each side is routing at episode end.
    """

    winner: int | None
    steps: int
    blue_strength: float
    red_strength: float
    blue_morale: float
    red_morale: float
    blue_routed: bool
    red_routed: bool
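
The winner encoding can be reproduced with a standalone mirror of the dataclass plus the branch structure of _make_result() shown above; the field values below are made up for illustration:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class EpisodeResult:  # standalone mirror of envs.sim.engine.EpisodeResult
    winner: Optional[int]
    steps: int
    blue_strength: float
    red_strength: float
    blue_morale: float
    red_morale: float
    blue_routed: bool
    red_routed: bool

def decide_winner(blue_done: bool, red_done: bool) -> Optional[int]:
    # Same decision logic as SimEngine._make_result().
    if blue_done and not red_done:
        return 1   # red wins
    if red_done and not blue_done:
        return 0   # blue wins
    return None    # draw: simultaneous rout/destruction or time-out

result = EpisodeResult(
    winner=decide_winner(blue_done=False, red_done=True),
    steps=137,
    blue_strength=0.55, red_strength=0.04,
    blue_morale=0.62, red_morale=0.10,
    blue_routed=False, red_routed=True,
)
```

Note the asymmetry: winner 0 means blue won (red was done), winner 1 means red won, and both-done or neither-done collapses to None.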

HRL options framework

envs.options.MacroAction

Bases: IntEnum

Indices for the six standard macro-actions in the SMDP vocabulary.

Source code in envs/options.py
class MacroAction(IntEnum):
    """Indices for the six standard macro-actions in the SMDP vocabulary."""

    ADVANCE_SECTOR = 0
    DEFEND_POSITION = 1
    FLANK_LEFT = 2
    FLANK_RIGHT = 3
    WITHDRAW = 4
    CONCENTRATE_FIRE = 5
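
A minimal standalone mirror showing how the discrete macro-action indices map back to names, e.g. when logging policy outputs:

```python
from enum import IntEnum

class MacroAction(IntEnum):
    # Mirrors envs.options.MacroAction.
    ADVANCE_SECTOR = 0
    DEFEND_POSITION = 1
    FLANK_LEFT = 2
    FLANK_RIGHT = 3
    WITHDRAW = 4
    CONCENTRATE_FIRE = 5

# Decode a raw index sampled from a Discrete(6) action space.
action_idx = 4
label = MacroAction(action_idx).name
```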

envs.options.Option dataclass

An SMDP Option — a temporally-extended macro-action.

Parameters:

Name Type Description Default
name str

Human-readable label used in logging and debugging.

required
initiation_set Callable[[ndarray], bool]

Callable (obs: np.ndarray) -> bool. Returns True if this option may be initiated from the current local observation.

required
policy Callable[[ndarray], ndarray]

Callable (obs: np.ndarray) -> np.ndarray. Maps the agent's current local observation to a primitive action of shape (3,): [move ∈ [-1,1], rotate ∈ [-1,1], fire ∈ [0,1]].

required
termination Callable[[ndarray, int], bool]

Callable (obs: np.ndarray, steps_active: int) -> bool. Returns True when the option should end. steps_active counts how many primitive steps this option has been running (≥ 1 on first call after the initial step).

required
max_steps int

Hard cap on option duration in primitive steps. The option is forced to terminate once steps_active >= max_steps regardless of the termination callable.

50
Source code in envs/options.py
@dataclass
class Option:
    """An SMDP Option — a temporally-extended macro-action.

    Parameters
    ----------
    name:
        Human-readable label used in logging and debugging.
    initiation_set:
        Callable ``(obs: np.ndarray) -> bool``.  Returns ``True`` if this
        option may be initiated from the current local observation.
    policy:
        Callable ``(obs: np.ndarray) -> np.ndarray``.  Maps the agent's
        current local observation to a primitive action of shape ``(3,)``:
        ``[move ∈ [-1,1], rotate ∈ [-1,1], fire ∈ [0,1]]``.
    termination:
        Callable ``(obs: np.ndarray, steps_active: int) -> bool``.  Returns
        ``True`` when the option should end.  ``steps_active`` counts how
        many primitive steps this option has been running (≥ 1 on first
        call after the initial step).
    max_steps:
        Hard cap on option duration in primitive steps.  The option is
        forced to terminate once ``steps_active >= max_steps`` regardless
        of the ``termination`` callable.
    """

    name: str
    initiation_set: Callable[[np.ndarray], bool]
    policy: Callable[[np.ndarray], np.ndarray]
    termination: Callable[[np.ndarray, int], bool]
    max_steps: int = 50

    def can_initiate(self, obs: np.ndarray) -> bool:
        """Return ``True`` if this option can be initiated from *obs*."""
        return bool(self.initiation_set(obs))

    def get_action(self, obs: np.ndarray) -> np.ndarray:
        """Return a primitive action from this option's policy for *obs*.

        The returned action is validated to ensure it matches the expected
        shape ``(3,)`` and contains only finite values.  This provides
        earlier and more informative errors than downstream index failures
        in the environment step logic.

        Raises
        ------
        ValueError
            If the policy returns an action with shape other than ``(3,)``
            or containing non-finite values (NaN or Inf).
        """
        action = np.asarray(self.policy(obs), dtype=np.float32)

        if action.shape != (3,):
            raise ValueError(
                f"Option '{self.name}' policy returned action with shape "
                f"{action.shape!r}, but expected shape (3,)."
            )

        if not np.all(np.isfinite(action)):
            raise ValueError(
                f"Option '{self.name}' policy returned non-finite action "
                f"values: {action!r}"
            )

        return action

    def should_terminate(self, obs: np.ndarray, steps_active: int) -> bool:
        """Return ``True`` if this option should terminate.

        Terminates when the hard cap ``max_steps`` is reached **or** the
        caller-supplied ``termination`` callable returns ``True``.
        """
        if steps_active >= self.max_steps:
            return True
        return bool(self.termination(obs, steps_active))

can_initiate(obs)

Return True if this option can be initiated from obs.

Source code in envs/options.py
def can_initiate(self, obs: np.ndarray) -> bool:
    """Return ``True`` if this option can be initiated from *obs*."""
    return bool(self.initiation_set(obs))

get_action(obs)

Return a primitive action from this option's policy for obs.

The returned action is validated to ensure it matches the expected shape (3,) and contains only finite values. This provides earlier and more informative errors than downstream index failures in the environment step logic.

Raises:

Type Description
ValueError

If the policy returns an action with shape other than (3,) or containing non-finite values (NaN or Inf).

Source code in envs/options.py
def get_action(self, obs: np.ndarray) -> np.ndarray:
    """Return a primitive action from this option's policy for *obs*.

    The returned action is validated to ensure it matches the expected
    shape ``(3,)`` and contains only finite values.  This provides
    earlier and more informative errors than downstream index failures
    in the environment step logic.

    Raises
    ------
    ValueError
        If the policy returns an action with shape other than ``(3,)``
        or containing non-finite values (NaN or Inf).
    """
    action = np.asarray(self.policy(obs), dtype=np.float32)

    if action.shape != (3,):
        raise ValueError(
            f"Option '{self.name}' policy returned action with shape "
            f"{action.shape!r}, but expected shape (3,)."
        )

    if not np.all(np.isfinite(action)):
        raise ValueError(
            f"Option '{self.name}' policy returned non-finite action "
            f"values: {action!r}"
        )

    return action

should_terminate(obs, steps_active)

Return True if this option should terminate.

Terminates when the hard cap max_steps is reached or the caller-supplied termination callable returns True.

Source code in envs/options.py
def should_terminate(self, obs: np.ndarray, steps_active: int) -> bool:
    """Return ``True`` if this option should terminate.

    Terminates when the hard cap ``max_steps`` is reached **or** the
    caller-supplied ``termination`` callable returns ``True``.
    """
    if steps_active >= self.max_steps:
        return True
    return bool(self.termination(obs, steps_active))
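
To show the Option contract end-to-end, here is a self-contained sketch with a trimmed-down Option mirror (same fields and validation as above) and a toy "hold and fire" option; the observation layout and morale threshold are illustrative assumptions:

```python
from dataclasses import dataclass
from typing import Callable
import numpy as np

@dataclass
class Option:  # trimmed mirror of envs.options.Option
    name: str
    initiation_set: Callable[[np.ndarray], bool]
    policy: Callable[[np.ndarray], np.ndarray]
    termination: Callable[[np.ndarray, int], bool]
    max_steps: int = 50

    def get_action(self, obs: np.ndarray) -> np.ndarray:
        action = np.asarray(self.policy(obs), dtype=np.float32)
        if action.shape != (3,):
            raise ValueError(f"expected shape (3,), got {action.shape!r}")
        if not np.all(np.isfinite(action)):
            raise ValueError(f"non-finite action: {action!r}")
        return action

    def should_terminate(self, obs: np.ndarray, steps_active: int) -> bool:
        return steps_active >= self.max_steps or bool(self.termination(obs, steps_active))

MORALE = 5  # assumed morale index in this toy observation layout

hold_fire = Option(
    name="hold_and_fire",
    initiation_set=lambda obs: float(obs[MORALE]) >= 0.3,
    policy=lambda obs: np.array([0.0, 0.0, 1.0], dtype=np.float32),  # stand still, full fire
    termination=lambda obs, s: float(obs[MORALE]) < 0.3,
    max_steps=20,
)

obs = np.zeros(6, dtype=np.float32)
obs[MORALE] = 0.8
action = hold_fire.get_action(obs)
```

The shape and finiteness checks in get_action mean a buggy policy fails loudly at the option boundary rather than deep inside the environment step.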

envs.options.make_default_options(max_steps=30)

Return the six-element default macro-action vocabulary.

All option policies use only the agent's self-state (obs[0:6]) which occupies fixed positions in every observation regardless of team size. This makes the options environment-agnostic.

Parameters:

Name Type Description Default
max_steps int

Maximum primitive steps per option execution (time-limit termination). Flanking options use max_steps // 2 as their cap to keep them short and decisive.

30

Returns:

Type Description
list[Option]

Six Option objects in MacroAction index order: [advance_sector, defend_position, flank_left, flank_right, withdraw, concentrate_fire].

Source code in envs/options.py
def make_default_options(max_steps: int = 30) -> list[Option]:
    """Return the six-element default macro-action vocabulary.

    All option policies use only the agent's **self-state** (``obs[0:6]``)
    which occupies fixed positions in every observation regardless of team
    size.  This makes the options environment-agnostic.

    Parameters
    ----------
    max_steps:
        Maximum primitive steps per option execution (time-limit
        termination).  Flanking options use ``max_steps // 2`` as their
        cap to keep them short and decisive.

    Returns
    -------
    list[Option]
        Six ``Option`` objects in :class:`MacroAction` index order:
        ``[advance_sector, defend_position, flank_left, flank_right,
        withdraw, concentrate_fire]``.
    """
    if max_steps < 1:
        raise ValueError(
            f"max_steps must be >= 1, got {max_steps}. "
            "Non-positive values would create options that terminate immediately "
            "and make flanking durations inconsistent with the documented behaviour."
        )

    # ------------------------------------------------------------------
    # Shared termination predicates
    # ------------------------------------------------------------------

    def _routing(obs: np.ndarray, _steps: int) -> bool:
        """Unit is routing (morale below rout threshold)."""
        return float(obs[OBS_MORALE]) < _ROUT_THRESHOLD

    def _low_strength(obs: np.ndarray, _steps: int) -> bool:
        """Unit has taken significant casualties."""
        return float(obs[OBS_STRENGTH]) < _LOW_STRENGTH

    def _recovered(obs: np.ndarray, _steps: int) -> bool:
        """Unit morale has recovered enough to stop withdrawing."""
        return float(obs[OBS_MORALE]) > _SAFE_MORALE

    # Flanking options run for half the normal duration
    flank_max = max(1, max_steps // 2)

    # ------------------------------------------------------------------
    # ADVANCE_SECTOR (index 0)
    # Move forward aggressively with suppression fire.
    # Distinct pattern: high forward speed + moderate fire.
    # ------------------------------------------------------------------
    advance = Option(
        name="advance_sector",
        initiation_set=lambda obs: float(obs[OBS_MORALE]) >= _LOW_MORALE,
        policy=lambda obs: np.array([_ADVANCE_SPEED, 0.0, _SUPPRESSION_FIRE], dtype=np.float32),
        termination=lambda obs, s: _routing(obs, s) or _low_strength(obs, s),
        max_steps=max_steps,
    )

    # ------------------------------------------------------------------
    # DEFEND_POSITION (index 1)
    # Hold ground with maximum sustained fire, no movement.
    # Distinct pattern: zero movement + full fire.
    # ------------------------------------------------------------------
    defend = Option(
        name="defend_position",
        initiation_set=lambda obs: True,
        policy=lambda obs: np.array([0.0, 0.0, 1.0], dtype=np.float32),
        termination=_routing,
        max_steps=max_steps,
    )

    # ------------------------------------------------------------------
    # FLANK_LEFT (index 2)
    # Move while rotating counter-clockwise to attack the enemy's flank.
    # Distinct pattern: moderate forward + full left rotation + no fire.
    # ------------------------------------------------------------------
    flank_left = Option(
        name="flank_left",
        initiation_set=lambda obs: float(obs[OBS_MORALE]) >= _LOW_MORALE,
        policy=lambda obs: np.array([_FLANK_SPEED, -1.0, 0.0], dtype=np.float32),
        termination=_routing,
        max_steps=flank_max,
    )

    # ------------------------------------------------------------------
    # FLANK_RIGHT (index 3)
    # Move while rotating clockwise to attack the enemy's flank.
    # Distinct pattern: moderate forward + full right rotation + no fire.
    # ------------------------------------------------------------------
    flank_right = Option(
        name="flank_right",
        initiation_set=lambda obs: float(obs[OBS_MORALE]) >= _LOW_MORALE,
        policy=lambda obs: np.array([_FLANK_SPEED, 1.0, 0.0], dtype=np.float32),
        termination=_routing,
        max_steps=flank_max,
    )

    # ------------------------------------------------------------------
    # WITHDRAW (index 4)
    # Retreat at full speed away from enemies; no fire.
    # Initiation gated on low morale or low strength.
    # Distinct pattern: full backward speed + no fire.
    # ------------------------------------------------------------------
    withdraw = Option(
        name="withdraw",
        initiation_set=lambda obs: (
            float(obs[OBS_MORALE]) < _LOW_MORALE
            or float(obs[OBS_STRENGTH]) < _LOW_STRENGTH
        ),
        policy=lambda obs: np.array([-1.0, 0.0, 0.0], dtype=np.float32),
        termination=_recovered,
        max_steps=max_steps,
    )

    # ------------------------------------------------------------------
    # CONCENTRATE_FIRE (index 5)
    # Stationary with maximum sustained fire and slight tracking rotation.
    # Distinct pattern: zero movement + tracking rotation + full fire.
    # ------------------------------------------------------------------
    concentrate = Option(
        name="concentrate_fire",
        initiation_set=lambda obs: True,
        policy=lambda obs: np.array([0.0, _TRACKING_ROTATE, 1.0], dtype=np.float32),
        termination=lambda obs, s: _routing(obs, s) or _low_strength(obs, s),
        max_steps=max_steps,
    )

    return [advance, defend, flank_left, flank_right, withdraw, concentrate]
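
Putting the pieces together, a macro-step executes one option's policy repeatedly until its termination fires or the step cap is hit. This standalone sketch stubs both the option and the environment dynamics (the real loop lives in SMDPWrapper.step):

```python
import numpy as np

class TinyOption:
    # Duck-typed stand-in for envs.options.Option: exposes get_action /
    # should_terminate with the same signatures.
    def __init__(self, max_steps: int):
        self.max_steps = max_steps

    def get_action(self, obs: np.ndarray) -> np.ndarray:
        return np.array([1.0, 0.0, 0.2], dtype=np.float32)  # advance + light fire

    def should_terminate(self, obs: np.ndarray, steps_active: int) -> bool:
        # Stop at the hard cap or once the unit reaches the far map edge.
        return steps_active >= self.max_steps or float(obs[0]) >= 1.0

def run_option(option, obs, env_step):
    """Run one option to termination; returns (obs, primitive steps used)."""
    steps_active = 0
    while True:
        action = option.get_action(obs)
        obs = env_step(obs, action)
        steps_active += 1
        if option.should_terminate(obs, steps_active):
            return obs, steps_active

def toy_env_step(obs, action):
    # Stub dynamics: the move component advances normalised x by 0.25/step.
    new = obs.copy()
    new[0] = min(1.0, float(new[0]) + 0.25 * float(action[0]))
    return new

obs = np.zeros(17, dtype=np.float32)
obs, used = run_option(TinyOption(max_steps=30), obs, toy_env_step)
```

Here one macro-step consumed four primitive steps, which is exactly the ratio tracked by SMDPWrapper.temporal_abstraction_ratio.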

envs.smdp_wrapper.SMDPWrapper

Bases: ParallelEnv

PettingZoo ParallelEnv wrapping MultiBattalionEnv with SMDP options.

Parameters:

Name Type Description Default
env MultiBattalionEnv

The underlying envs.multi_battalion_env.MultiBattalionEnv instance.

required
options Optional[list[Option]]

Option vocabulary — a list of envs.options.Option objects. Index i corresponds to macro-action i. When None, envs.options.make_default_options is used to build the standard six-element vocabulary.

None
Source code in envs/smdp_wrapper.py
class SMDPWrapper(ParallelEnv):
    """PettingZoo ParallelEnv wrapping MultiBattalionEnv with SMDP options.

    Parameters
    ----------
    env:
        The underlying :class:`~envs.multi_battalion_env.MultiBattalionEnv`
        instance.
    options:
        Option vocabulary — a list of :class:`~envs.options.Option` objects.
        Index ``i`` corresponds to macro-action ``i``.  When ``None``,
        :func:`~envs.options.make_default_options` is used to build the
        standard six-element vocabulary.
    """

    metadata: dict = {"render_modes": [], "name": "smdp_multi_battalion_v0"}

    def __init__(
        self,
        env: MultiBattalionEnv,
        options: Optional[list[Option]] = None,
    ) -> None:
        if options is None:
            options = make_default_options()

        if len(options) == 0:
            raise ValueError(
                "options must contain at least one Option; received an empty list."
            )

        self._env = env
        self._options: list[Option] = list(options)
        self.n_options: int = len(self._options)

        # ------------------------------------------------------------------
        # PettingZoo required attributes
        # ------------------------------------------------------------------
        self.possible_agents: list[str] = list(env.possible_agents)
        self.agents: list[str] = []
        self.render_mode: Optional[str] = getattr(env, "render_mode", None)

        # ------------------------------------------------------------------
        # Spaces
        # Observation space is the same as the underlying environment;
        # action space is Discrete(n_options).
        # ------------------------------------------------------------------
        self._obs_spaces: dict[str, spaces.Space] = {
            a: env.observation_space(a) for a in self.possible_agents
        }
        self._act_space: spaces.Discrete = spaces.Discrete(self.n_options)

        # ------------------------------------------------------------------
        # Episode-level counters
        # ------------------------------------------------------------------
        self._macro_steps: int = 0
        self._primitive_steps: int = 0

        # Last observations keyed by agent (populated by reset / step)
        self._last_obs: dict[str, np.ndarray] = {}

    # ------------------------------------------------------------------
    # PettingZoo API: spaces
    # ------------------------------------------------------------------

    def observation_space(self, agent: str) -> spaces.Space:
        """Return the observation space for *agent* (same as underlying env)."""
        return self._obs_spaces[agent]

    def action_space(self, agent: str) -> spaces.Discrete:
        """Return the macro-action space for *agent* — ``Discrete(n_options)``."""
        return self._act_space

    # ------------------------------------------------------------------
    # Temporal abstraction property
    # ------------------------------------------------------------------

    @property
    def temporal_abstraction_ratio(self) -> float:
        """Macro-steps / primitive-steps for the current episode.

        Returns ``0.0`` before the first macro-step completes.
        """
        if self._primitive_steps == 0:
            return 0.0
        return self._macro_steps / self._primitive_steps

    # ------------------------------------------------------------------
    # PettingZoo API: reset
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> tuple[dict[str, np.ndarray], dict[str, dict]]:
        """Reset the environment and return initial observations.

        Parameters
        ----------
        seed:
            RNG seed forwarded to the underlying environment.
        options:
            Currently unused; accepted for API compatibility.

        Returns
        -------
        observations : dict[agent_id, np.ndarray]
        infos        : dict[agent_id, dict]
        """
        obs, infos = self._env.reset(seed=seed, options=options)
        self.agents = list(self._env.agents)
        self._macro_steps = 0
        self._primitive_steps = 0
        self._last_obs = {a: o.copy() for a, o in obs.items()}
        return obs, infos

    # ------------------------------------------------------------------
    # PettingZoo API: step
    # ------------------------------------------------------------------

    def step(
        self,
        macro_actions: dict[str, int],
    ) -> tuple[
        dict[str, np.ndarray],
        dict[str, float],
        dict[str, bool],
        dict[str, bool],
        dict[str, dict],
    ]:
        """Execute one macro-step.

        Runs each agent's chosen option for multiple primitive steps until
        **all** active options have terminated (option termination condition
        fires, hard time-limit reached, or the underlying env terminates the
        agent).

        Parameters
        ----------
        macro_actions:
            Dict mapping each live agent ID to a macro-action index
            ``[0, n_options)``.  Missing agents default to index ``0``
            (``advance_sector``).

        Returns
        -------
        observations, aggregate_rewards, terminated, truncated, infos
            Keyed by agent IDs that were alive at the *start* of this
            macro-step, matching the PettingZoo convention.
        """
        if not self.agents:
            return {}, {}, {}, {}, {}

        current_agents: list[str] = list(self.agents)

        # ------------------------------------------------------------------
        # Initialise per-agent tracking for this macro-step
        # ------------------------------------------------------------------
        selected_options: dict[str, Option] = {}
        option_steps: dict[str, int] = {}
        aggregate_rewards: dict[str, float] = {a: 0.0 for a in current_agents}
        option_done: dict[str, bool] = {a: False for a in current_agents}

        for agent in current_agents:
            # Missing agents default to option index 0 (e.g., "advance_sector").
            if agent not in macro_actions:
                idx = 0
            else:
                raw_idx = macro_actions[agent]
                idx = int(raw_idx)
                if idx < 0 or idx >= self.n_options:
                    raise ValueError(
                        f"Invalid macro-action index {idx!r} for agent {agent!r}; "
                        f"expected integer in [0, {self.n_options - 1}] or omit "
                        f"the agent key to use the default option 0."
                    )
            selected_options[agent] = self._options[idx]
            option_steps[agent] = 0

        # Initialise final output dicts with safe defaults
        final_obs: dict[str, np.ndarray] = {
            a: self._last_obs[a].copy() for a in current_agents
        }
        final_terminated: dict[str, bool] = {a: False for a in current_agents}
        final_truncated: dict[str, bool] = {a: False for a in current_agents}
        final_infos: dict[str, dict] = {a: {} for a in current_agents}

        # ------------------------------------------------------------------
        # Inner primitive-step loop
        # Run until all current_agents have finished their options OR
        # the underlying environment has no more live agents.
        # ------------------------------------------------------------------
        while any(not option_done[a] for a in current_agents):
            # Exit if the underlying env has exhausted all agents
            if not self._env.agents:
                for agent in current_agents:
                    option_done[agent] = True
                break

            # Build primitive actions for all env-live agents
            prim_actions: dict[str, np.ndarray] = {}
            for agent in self._env.agents:
                if agent in current_agents and not option_done[agent]:
                    prim_actions[agent] = selected_options[agent].get_action(
                        self._last_obs[agent]
                    )
                else:
                    # No-op for agents whose option has already terminated
                    prim_actions[agent] = np.zeros(_PRIM_ACTION_DIM, dtype=np.float32)

            # Primitive step
            obs, rewards, terminated, truncated, infos = self._env.step(prim_actions)
            self._primitive_steps += 1

            # ----------------------------------------------------------------
            # Process results for each agent alive at macro-step start
            # ----------------------------------------------------------------
            for agent in current_agents:
                if option_done[agent]:
                    continue

                # Accumulate reward
                if agent in rewards:
                    aggregate_rewards[agent] += float(rewards[agent])

                # Update latest observation
                if agent in obs:
                    self._last_obs[agent] = obs[agent].copy()
                    final_obs[agent] = obs[agent].copy()

                # Update info (keep the latest primitive info)
                if agent in infos:
                    final_infos[agent] = dict(infos[agent])

                # Check underlying-env termination first
                env_terminated = terminated.get(agent, False)
                env_truncated = truncated.get(agent, False)
                if env_terminated or env_truncated:
                    final_terminated[agent] = bool(env_terminated)
                    final_truncated[agent] = bool(env_truncated)
                    option_done[agent] = True
                else:
                    # Increment option step counter and check option termination
                    option_steps[agent] += 1
                    if selected_options[agent].should_terminate(
                        self._last_obs[agent], option_steps[agent]
                    ):
                        option_done[agent] = True

        # ------------------------------------------------------------------
        # Post-loop: handle agents that the env removed without explicit
        # terminated/truncated signals (edge case when the underlying env
        # empties mid-option and the agent was never seen in a terminated
        # dict during this macro-step).  Treating them as truncated preserves
        # PettingZoo semantics — the agent was alive at macro-step start but
        # absent at the end because the episode was cut short by the env.
        # ------------------------------------------------------------------
        for agent in current_agents:
            if (
                agent not in self._env.agents
                and not final_terminated[agent]
                and not final_truncated[agent]
            ):
                final_truncated[agent] = True

        self._macro_steps += 1

        # Update wrapper-level agent list from underlying env
        self.agents = list(self._env.agents)

        # ------------------------------------------------------------------
        # Attach temporal abstraction metadata to every agent's info dict
        # ------------------------------------------------------------------
        ta_ratio = self.temporal_abstraction_ratio
        for agent in current_agents:
            final_infos[agent]["temporal_abstraction"] = {
                "macro_steps": self._macro_steps,
                "primitive_steps": self._primitive_steps,
                "ratio": ta_ratio,
                "option_name": selected_options[agent].name,
                "option_steps": option_steps.get(agent, 0),
            }

        return final_obs, aggregate_rewards, final_terminated, final_truncated, final_infos

    # ------------------------------------------------------------------
    # PettingZoo API: state (delegated)
    # ------------------------------------------------------------------

    def state(self) -> np.ndarray:
        """Return the global state tensor from the underlying environment."""
        return self._env.state()

    # ------------------------------------------------------------------
    # PettingZoo API: render / close (delegated)
    # ------------------------------------------------------------------

    def render(self) -> None:
        """Delegate rendering to the underlying environment."""
        return self._env.render()

    def close(self) -> None:
        """Delegate resource cleanup to the underlying environment."""
        return self._env.close()

temporal_abstraction_ratio property

Macro-steps / primitive-steps for the current episode.

Returns 0.0 before the first macro-step completes.
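As a minimal illustration (the counter values below are hypothetical), the ratio is simply the quotient of the two episode counters, with a zero guard before the first primitive step has run:

```python
# Hypothetical counters: three macro-steps that together spanned
# 24 primitive environment steps.
macro_steps, primitive_steps = 3, 24

# Mirrors temporal_abstraction_ratio: 0.0 before any primitive step.
ratio = macro_steps / primitive_steps if primitive_steps else 0.0
print(ratio)  # 0.125
```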

action_space(agent)

Return the macro-action space for agent: Discrete(n_options).

Source code in envs/smdp_wrapper.py
def action_space(self, agent: str) -> spaces.Discrete:
    """Return the macro-action space for *agent* — ``Discrete(n_options)``."""
    return self._act_space

close()

Delegate resource cleanup to the underlying environment.

Source code in envs/smdp_wrapper.py
def close(self) -> None:
    """Delegate resource cleanup to the underlying environment."""
    return self._env.close()

observation_space(agent)

Return the observation space for agent (same as underlying env).

Source code in envs/smdp_wrapper.py
def observation_space(self, agent: str) -> spaces.Space:
    """Return the observation space for *agent* (same as underlying env)."""
    return self._obs_spaces[agent]

render()

Delegate rendering to the underlying environment.

Source code in envs/smdp_wrapper.py
def render(self) -> None:
    """Delegate rendering to the underlying environment."""
    return self._env.render()

reset(seed=None, options=None)

Reset the environment and return initial observations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `seed` | `Optional[int]` | RNG seed forwarded to the underlying environment. | `None` |
| `options` | `Optional[dict]` | Currently unused; accepted for API compatibility. | `None` |

Returns:

| Name | Type |
| --- | --- |
| `observations` | `dict[agent_id, ndarray]` |
| `infos` | `dict[agent_id, dict]` |

Source code in envs/smdp_wrapper.py
def reset(
    self,
    seed: Optional[int] = None,
    options: Optional[dict] = None,
) -> tuple[dict[str, np.ndarray], dict[str, dict]]:
    """Reset the environment and return initial observations.

    Parameters
    ----------
    seed:
        RNG seed forwarded to the underlying environment.
    options:
        Currently unused; accepted for API compatibility.

    Returns
    -------
    observations : dict[agent_id, np.ndarray]
    infos        : dict[agent_id, dict]
    """
    obs, infos = self._env.reset(seed=seed, options=options)
    self.agents = list(self._env.agents)
    self._macro_steps = 0
    self._primitive_steps = 0
    self._last_obs = {a: o.copy() for a, o in obs.items()}
    return obs, infos
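The defensive `.copy()` on the last line of `reset` matters: the wrapper caches observations in `_last_obs` so options can act on them between primitive steps, and copying shields that cache from any in-place mutation of the arrays the underlying environment hands back. A small stand-alone sketch (the agent ID is made up for illustration):

```python
import numpy as np

# Env-returned observation buffer (hypothetical agent "blue_0").
obs = {"blue_0": np.zeros(3, dtype=np.float32)}

# Same copy pattern as reset(): cache defensive copies per agent.
last_obs = {a: o.copy() for a, o in obs.items()}

obs["blue_0"][0] = 1.0           # env mutates its buffer in place
print(last_obs["blue_0"][0])     # 0.0 -- the cached copy is unaffected
```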

state()

Return the global state tensor from the underlying environment.

Source code in envs/smdp_wrapper.py
def state(self) -> np.ndarray:
    """Return the global state tensor from the underlying environment."""
    return self._env.state()

step(macro_actions)

Execute one macro-step.

Runs each agent's chosen option for multiple primitive steps until all active options have terminated (option termination condition fires, hard time-limit reached, or the underlying env terminates the agent).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `macro_actions` | `dict[str, int]` | Dict mapping each live agent ID to a macro-action index `[0, n_options)`. Missing agents default to index `0` (`advance_sector`). | *required* |

Returns:

| Type | Description |
| --- | --- |
| `(observations, aggregate_rewards, terminated, truncated, infos)` | Keyed by agent IDs that were alive at the *start* of this macro-step, matching the PettingZoo convention. |

Source code in envs/smdp_wrapper.py
def step(
    self,
    macro_actions: dict[str, int],
) -> tuple[
    dict[str, np.ndarray],
    dict[str, float],
    dict[str, bool],
    dict[str, bool],
    dict[str, dict],
]:
    """Execute one macro-step.

    Runs each agent's chosen option for multiple primitive steps until
    **all** active options have terminated (option termination condition
    fires, hard time-limit reached, or the underlying env terminates the
    agent).

    Parameters
    ----------
    macro_actions:
        Dict mapping each live agent ID to a macro-action index
        ``[0, n_options)``.  Missing agents default to index ``0``
        (``advance_sector``).

    Returns
    -------
    observations, aggregate_rewards, terminated, truncated, infos
        Keyed by agent IDs that were alive at the *start* of this
        macro-step, matching the PettingZoo convention.
    """
    if not self.agents:
        return {}, {}, {}, {}, {}

    current_agents: list[str] = list(self.agents)

    # ------------------------------------------------------------------
    # Initialise per-agent tracking for this macro-step
    # ------------------------------------------------------------------
    selected_options: dict[str, Option] = {}
    option_steps: dict[str, int] = {}
    aggregate_rewards: dict[str, float] = {a: 0.0 for a in current_agents}
    option_done: dict[str, bool] = {a: False for a in current_agents}

    for agent in current_agents:
        # Missing agents default to option index 0 (e.g., "advance_sector").
        if agent not in macro_actions:
            idx = 0
        else:
            raw_idx = macro_actions[agent]
            idx = int(raw_idx)
            if idx < 0 or idx >= self.n_options:
                raise ValueError(
                    f"Invalid macro-action index {idx!r} for agent {agent!r}; "
                    f"expected integer in [0, {self.n_options - 1}] or omit "
                    f"the agent key to use the default option 0."
                )
        selected_options[agent] = self._options[idx]
        option_steps[agent] = 0

    # Initialise final output dicts with safe defaults
    final_obs: dict[str, np.ndarray] = {
        a: self._last_obs[a].copy() for a in current_agents
    }
    final_terminated: dict[str, bool] = {a: False for a in current_agents}
    final_truncated: dict[str, bool] = {a: False for a in current_agents}
    final_infos: dict[str, dict] = {a: {} for a in current_agents}

    # ------------------------------------------------------------------
    # Inner primitive-step loop
    # Run until all current_agents have finished their options OR
    # the underlying environment has no more live agents.
    # ------------------------------------------------------------------
    while any(not option_done[a] for a in current_agents):
        # Exit if the underlying env has exhausted all agents
        if not self._env.agents:
            for agent in current_agents:
                option_done[agent] = True
            break

        # Build primitive actions for all env-live agents
        prim_actions: dict[str, np.ndarray] = {}
        for agent in self._env.agents:
            if agent in current_agents and not option_done[agent]:
                prim_actions[agent] = selected_options[agent].get_action(
                    self._last_obs[agent]
                )
            else:
                # No-op for agents whose option has already terminated
                prim_actions[agent] = np.zeros(_PRIM_ACTION_DIM, dtype=np.float32)

        # Primitive step
        obs, rewards, terminated, truncated, infos = self._env.step(prim_actions)
        self._primitive_steps += 1

        # ----------------------------------------------------------------
        # Process results for each agent alive at macro-step start
        # ----------------------------------------------------------------
        for agent in current_agents:
            if option_done[agent]:
                continue

            # Accumulate reward
            if agent in rewards:
                aggregate_rewards[agent] += float(rewards[agent])

            # Update latest observation
            if agent in obs:
                self._last_obs[agent] = obs[agent].copy()
                final_obs[agent] = obs[agent].copy()

            # Update info (keep the latest primitive info)
            if agent in infos:
                final_infos[agent] = dict(infos[agent])

            # Check underlying-env termination first
            env_terminated = terminated.get(agent, False)
            env_truncated = truncated.get(agent, False)
            if env_terminated or env_truncated:
                final_terminated[agent] = bool(env_terminated)
                final_truncated[agent] = bool(env_truncated)
                option_done[agent] = True
            else:
                # Increment option step counter and check option termination
                option_steps[agent] += 1
                if selected_options[agent].should_terminate(
                    self._last_obs[agent], option_steps[agent]
                ):
                    option_done[agent] = True

    # ------------------------------------------------------------------
    # Post-loop: handle agents that the env removed without explicit
    # terminated/truncated signals (edge case when the underlying env
    # empties mid-option and the agent was never seen in a terminated
    # dict during this macro-step).  Treating them as truncated preserves
    # PettingZoo semantics — the agent was alive at macro-step start but
    # absent at the end because the episode was cut short by the env.
    # ------------------------------------------------------------------
    for agent in current_agents:
        if (
            agent not in self._env.agents
            and not final_terminated[agent]
            and not final_truncated[agent]
        ):
            final_truncated[agent] = True

    self._macro_steps += 1

    # Update wrapper-level agent list from underlying env
    self.agents = list(self._env.agents)

    # ------------------------------------------------------------------
    # Attach temporal abstraction metadata to every agent's info dict
    # ------------------------------------------------------------------
    ta_ratio = self.temporal_abstraction_ratio
    for agent in current_agents:
        final_infos[agent]["temporal_abstraction"] = {
            "macro_steps": self._macro_steps,
            "primitive_steps": self._primitive_steps,
            "ratio": ta_ratio,
            "option_name": selected_options[agent].name,
            "option_steps": option_steps.get(agent, 0),
        }

    return final_obs, aggregate_rewards, final_terminated, final_truncated, final_infos
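The index-handling rules above (missing agent falls back to option `0`, out-of-range index raises `ValueError`) can be exercised in isolation. This stand-alone helper mirrors the validation logic at the top of `step`; the agent IDs and `n_options` value are chosen purely for illustration:

```python
def resolve_macro_action(macro_actions: dict, agent: str, n_options: int) -> int:
    """Mirror of the wrapper's macro-action index handling in step()."""
    if agent not in macro_actions:
        return 0                     # missing agents fall back to option 0
    idx = int(macro_actions[agent])
    if idx < 0 or idx >= n_options:
        raise ValueError(
            f"Invalid macro-action index {idx!r} for agent {agent!r}; "
            f"expected integer in [0, {n_options - 1}]."
        )
    return idx

print(resolve_macro_action({"blue_0": 2}, "blue_0", 4))  # 2
print(resolve_macro_action({"blue_0": 2}, "red_0", 4))   # 0 (missing agent)
```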