File size: 72,865 Bytes
2010c83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
import gc
import io
import logging
import pickle
import shutil
import traceback
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, field, replace
from functools import reduce
from multiprocessing import shared_memory
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, cast

import numpy as np
import torch
import torch.distributed.checkpoint as dist_cp
import torch.multiprocessing as mp
from packaging import version
from torch.distributed import _remote_device
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.checkpoint.filesystem import WriteResult, _StorageInfo
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.checkpoint.planner import LoadItemType, ReadItem
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.api import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
)
from torch.futures import Future

try:
    from torch.distributed.fsdp.flat_param import FlatParamHandle  # type: ignore
except ModuleNotFoundError:
    from torch.distributed.fsdp._flat_param import FlatParamHandle  # type: ignore

from . import util

from .aliases import PathOrStr
from .config import BaseConfig, ShardedCheckpointerType, TrainConfig
from .exceptions import OLMoCheckpointError
from .optim import Optimizer, fix_optim_state_dict
from .safetensors_util import safetensors_file_to_state_dict
from .torch_util import (
    barrier,
    gc_cuda,
    get_fs_local_rank,
    get_global_rank,
    get_world_size,
)
from .util import (
    _get_s3_client,
    default_thread_count,
    dir_is_empty,
    get_bytes_range,
    get_progress_bar,
    resource_path,
    upload,
    wait_for,
)

# Public API of this module; everything else is considered internal.
__all__ = [
    "save_fsdp_model_and_optim_state",
    "load_fsdp_model_and_optim_state",
    "load_fsdp_optim_state",
    "save_state_dict",
    "load_state_dict",
    "load_model_state",
    "RemoteFileSystemWriter",
    "RemoteFileSystemReader",
    "Checkpointer",
    "FullCheckpointer",
    "TorchNewStyleShardedCheckpointer",
    "TorchLegacyShardedCheckpointer",
    "LocalShardedCheckpointer",
    "build_sharded_checkpointer",
]


# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Name of the sub-directory (both locally and under any remote ``upload_to`` prefix) that holds the
# ``torch.distributed.checkpoint`` files for the combined model + optimizer state.
MODEL_AND_OPTIM_FOLDER = "model_and_optim"


def save_fsdp_model_and_optim_state(
    checkpoint_dir: PathOrStr,
    fsdp_model: FSDP,
    optim: Optimizer,
    *,
    upload_to: Optional[str] = None,
    save_overwrite: bool = False,
):
    """
    Write an FSDP model and its optimizer as a sharded checkpoint using :module:`torch.distributed.checkpoint`.

    Intended for distributed training: every rank has to call this collectively, since it synchronizes
    with barriers and saves through ``dist_cp.save_state_dict``.

    :param checkpoint_dir: Local directory; the files land in its ``MODEL_AND_OPTIM_FOLDER`` sub-directory.
    :param fsdp_model: The FSDP-wrapped model.
    :param optim: The optimizer that belongs to ``fsdp_model``.
    :param upload_to: Optional remote "directory"; when set, the written files are uploaded under
        ``<upload_to>/MODEL_AND_OPTIM_FOLDER`` as well.
    :param save_overwrite: Remove/overwrite an existing checkpoint instead of failing.

    :raises FileExistsError: If the model and optim folder inside ``checkpoint_dir`` is non-empty and
        ``save_overwrite=False``.
    """
    target_dir = Path(checkpoint_dir) / MODEL_AND_OPTIM_FOLDER

    if not save_overwrite:
        if not dir_is_empty(target_dir):
            raise FileExistsError(target_dir)
    elif get_fs_local_rank() == 0:
        # Only one process per filesystem clears out the previous checkpoint.
        shutil.rmtree(target_dir, ignore_errors=True)

    # First barrier: the old directory is gone for everyone before it is recreated.
    # Second barrier: the directory exists before any rank starts writing into it.
    barrier()
    if get_fs_local_rank() == 0:
        target_dir.mkdir(parents=True, exist_ok=True)
    barrier()

    with FSDP.state_dict_type(
        fsdp_model,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
        optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
    ):
        model_shards = fsdp_model.state_dict()
        optim_shards = FSDP.optim_state_dict(fsdp_model, optim)
        writer = RemoteFileSystemWriter(
            target_dir,
            upload_to=None if upload_to is None else f"{upload_to.rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}",
            save_overwrite=save_overwrite,
        )
        dist_cp.save_state_dict({"model": model_shards, "optim": optim_shards}, writer)


def load_fsdp_model_and_optim_state(
    checkpoint_dir: PathOrStr,
    fsdp_model: FSDP,
    optim: Optimizer,
    *,
    local_cache: Optional[PathOrStr] = None,
    load_optimizer_state: bool = True,
):
    """
    Restore an FSDP model (and optionally its optimizer) from a checkpoint written by
    :func:`save_fsdp_model_and_optim_state()`, using :module:`torch.distributed.checkpoint`.
    All ranks must call this during distributed training.

    :param checkpoint_dir: Local or remote checkpoint directory to read from.
    :param fsdp_model: The FSDP model; its state is loaded in place.
    :param optim: The optimizer of ``fsdp_model``; its state is loaded in place.
    :param local_cache: Optional local copy of ``checkpoint_dir``. Files found there are preferred over
        reading from a remote ``checkpoint_dir``.
    :param load_optimizer_state: Pass ``False`` to restore only the model.

    :raises FileNotFoundError: If ``checkpoint_dir`` holds no model and optimizer checkpoint.
    """
    remote_dir = f"{str(checkpoint_dir).rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}"
    cache_root = None if local_cache is None else Path(local_cache)

    def new_reader() -> "RemoteFileSystemReader":
        # Each load call gets its own reader instance.
        return RemoteFileSystemReader(
            remote_dir,
            local_cache=None if cache_root is None else cache_root / MODEL_AND_OPTIM_FOLDER,
        )

    with FSDP.state_dict_type(
        fsdp_model,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
        optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
    ):
        # The model's own sharded state dict serves as the template that gets filled in place.
        log.info("Loading model state...")
        model_state = {"model": fsdp_model.state_dict()}
        dist_cp.load_state_dict(model_state, new_reader())
        fsdp_model.load_state_dict(model_state["model"])

        if load_optimizer_state:
            log.info("Loading sharded optimizer state...")
            optim_state = load_sharded_optimizer_state_dict(
                model_state_dict=model_state["model"],
                optimizer_key="optim",
                storage_reader=new_reader(),
            )
            # Release the model state before materializing the optimizer state.
            del model_state
            gc_cuda()
            load_fsdp_optim_state(fsdp_model, optim, optim_state["optim"])


def load_fsdp_optim_state(fsdp_model: FSDP, optim: Optimizer, optim_state: Dict[str, Any]):
    """
    Flatten a sharded optimizer state dict for ``fsdp_model`` and load it into ``optim``.

    :param fsdp_model: The FSDP model that ``optim`` optimizes.
    :param optim: The optimizer that receives the state.
    :param optim_state: The sharded optimizer state, e.g. the ``"optim"`` entry returned by
        :func:`load_sharded_optimizer_state_dict`.
    """
    log.info("Flattening sharded optimizer state...")
    # NOTE: `FSDP.optim_state_dict_to_load()` swapped its argument order between torch 2.0 and 2.1.
    uses_legacy_signature = version.parse(torch.__version__) < version.parse("2.1.0")
    if uses_legacy_signature:
        flattened_osd = FSDP.optim_state_dict_to_load(optim_state, fsdp_model, optim)  # type: ignore
    else:
        flattened_osd = FSDP.optim_state_dict_to_load(fsdp_model, optim, optim_state)  # type: ignore
    del optim_state
    gc.collect()

    log.info("Loading flattened optimizer state...")
    # `Optimizer.load_state_dict()` deep-copies the whole state dict, so moving every tensor to the CPU
    # first keeps that copy from occupying GPU memory.
    for param_state in flattened_osd["state"].values():
        for key, value in param_state.items():
            if isinstance(value, torch.Tensor):
                param_state[key] = value.to(device="cpu")
    gc_cuda()
    optim.load_state_dict(fix_optim_state_dict(optim, flattened_osd))


def save_state_dict(
    checkpoint_dir: PathOrStr,
    fname: str,
    state_dict: Dict[str, Any],
    *,
    upload_to: Optional[str] = None,
    save_overwrite: bool = False,
    synchronize: bool = True,
):
    """
    Write ``state_dict`` with :func:`torch.save()` to ``checkpoint_dir / fname``, optionally uploading it.

    Usable with or without distributed training; when several ranks call it, each should pass its own
    ``fname``.

    :param checkpoint_dir: Directory to write into.
    :param fname: Target path, relative to ``checkpoint_dir``; parent directories are created as needed.
    :param state_dict: The state dict to save.
    :param upload_to: Optional remote "directory"; the file is uploaded to ``<upload_to>/<fname>``.
    :param save_overwrite: Replace an existing file (locally and remotely) instead of failing.
    :param synchronize: When ``False``, skip the barriers. Use this if only one rank calls this function.

    :raises FileExistsError: If ``checkpoint_dir / fname`` already exists and ``save_overwrite=False``.
    """
    target_path = Path(checkpoint_dir) / fname

    if not save_overwrite:
        if target_path.is_file():
            raise FileExistsError(target_path)
    else:
        target_path.unlink(missing_ok=True)

    def sync() -> None:
        # Barriers are only meaningful when every rank participates.
        if synchronize:
            barrier()

    sync()
    target_path.parent.mkdir(parents=True, exist_ok=True)
    sync()
    torch.save(state_dict, target_path)

    if upload_to is None:
        return
    upload_target = f"{upload_to.rstrip('/')}/{fname}"
    log.info(f"Uploading {target_path} to {upload_target}...")
    upload(target_path, upload_target, save_overwrite=save_overwrite)


def load_state_dict(
    checkpoint_dir: PathOrStr,
    fname: str,
    *,
    local_cache: Optional[PathOrStr] = None,
    map_location: Optional[str] = None,
):
    """
    Load a regular state dict stored as ``fname`` inside ``checkpoint_dir``.

    If ``fname`` ends in ``.pt``, a sibling ``.safetensors`` file with the same stem is tried first and
    loaded via :func:`safetensors_file_to_state_dict()`. Otherwise, or if that attempt raises
    :class:`FileNotFoundError`, the file itself is read with :func:`torch.load()`. Works with or without
    distributed training.

    :param checkpoint_dir: Local or remote checkpoint directory.
    :param fname: Path of the file relative to ``checkpoint_dir``.
    :param local_cache: Optional local copy of ``checkpoint_dir`` to look in before going remote.
    :param map_location: Forwarded to the loader to control where tensors are placed (e.g. ``"cpu"``).

    :raises FileNotFoundError: If ``fname`` exists neither in ``checkpoint_dir`` nor in the local cache.
    """
    base_dir = str(checkpoint_dir).rstrip("/")

    if fname.endswith(".pt"):
        safetensors_name = fname[: -len("pt")] + "safetensors"
        try:
            safetensors_path = resource_path(base_dir, safetensors_name, local_cache=local_cache)
            return safetensors_file_to_state_dict(safetensors_path, map_location=map_location)
        except FileNotFoundError:
            # No safetensors variant available; fall back to the original file.
            pass

    return torch.load(resource_path(base_dir, fname, local_cache=local_cache), map_location=map_location)


def load_model_state(checkpoint_dir: PathOrStr, model: torch.nn.Module):
    """
    Fill ``model`` with the weights from a checkpoint produced by :func:`save_fsdp_model_and_optim_state()`.

    ``model`` must be a plain module, not wrapped in FSDP. Loading runs with ``no_dist=True``, so no
    collective communication is involved.
    """
    wrapped_state = {"model": model.state_dict()}
    reader = RemoteFileSystemReader(f"{str(checkpoint_dir).rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}")
    dist_cp.load_state_dict(wrapped_state, reader, no_dist=True)
    model.load_state_dict(wrapped_state["model"])


class RemoteFileSystemWriter(dist_cp.FileSystemWriter):
    """
    A :class:`~torch.distributed.checkpoint.FileSystemWriter` that, when ``upload_to`` is given, also pushes
    every file it writes (including ``.metadata``) to that remote "directory", e.g. an ``s3://`` or
    ``r2://`` prefix.
    """

    def __init__(
        self,
        path: PathOrStr,
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: Optional[int] = None,
        per_thread_copy_ahead: int = 10_000_000,
        upload_to: Optional[str] = None,
        save_overwrite: bool = False,
    ) -> None:
        if thread_count is not None and thread_count <= 0:
            raise ValueError("thread count must be at least 1")
        # NOTE: fall back to a single thread rather than `default_thread_count()`; uploading big
        # checkpoint files from several threads makes boto3 fail in weird ways.
        effective_thread_count = thread_count if thread_count else 1
        super().__init__(
            path,
            single_file_per_rank=single_file_per_rank,
            sync_files=sync_files,
            thread_count=effective_thread_count,
            per_thread_copy_ahead=per_thread_copy_ahead,
        )
        self.upload_to = upload_to.rstrip("/") if upload_to is not None else None
        self.save_overwrite = save_overwrite

    def write_data(
        self,
        plan: dist_cp.SavePlan,
        planner: dist_cp.SavePlanner,
    ) -> Future[List[WriteResult]]:
        """Write this rank's data locally, then (if configured) upload every file that was written."""
        fut = super().write_data(plan, planner)
        if self.upload_to is None:
            return fut
        written_files = {result.storage_data.relative_path for result in fut.wait()}
        self._upload_files(written_files)
        return fut

    def _upload_files(self, relative_paths: Set[str]) -> None:
        """Upload each file under ``self.path`` to the same relative location under ``self.upload_to``."""
        remote_root = cast(str, self.upload_to)

        # Create the global S3 client up front to work around a threading issue in boto.
        for scheme in ("s3", "r2"):
            if remote_root.startswith(f"{scheme}://"):
                _get_s3_client(scheme)
                break

        with ThreadPoolExecutor(max_workers=self.thread_count) as pool:
            pending = []
            for relative_path in relative_paths:
                source = self.path / relative_path
                target = f"{remote_root}/{relative_path}"
                log.info(f"Uploading {source} to {target}...")
                pending.append(pool.submit(upload, source, target, save_overwrite=self.save_overwrite))
            for done in as_completed(pending):
                try:
                    done.result()
                except BaseException:
                    # NOTE: the original error may not be picklable, which would cause a different failure
                    # later when PyTorch reduces it across ranks, so re-raise as a simple picklable type.
                    raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        """Finish the save via the base class, then upload the ``.metadata`` file it produced."""
        super().finish(metadata, results)
        if self.upload_to is None:
            return
        metadata_source = self.path / ".metadata"
        metadata_target = f"{self.upload_to}/.metadata"
        log.info(f"Uploading {metadata_source} to {metadata_target}...")
        upload(metadata_source, metadata_target, save_overwrite=self.save_overwrite)


class RemoteFileSystemReader(dist_cp.StorageReader):
    """
    A :class:`~torch.distributed.checkpoint.StorageReader` modeled on
    :class:`~torch.distributed.checkpoint.FileSystemReader` that reads checkpoint data from cloud storage or
    from a local directory. When a local cache is given, files present there are read from the cache instead.
    """

    def __init__(
        self, path: PathOrStr, *, local_cache: Optional[PathOrStr] = None, thread_count: Optional[int] = None
    ):
        """
        :param path: Local or remote checkpoint directory.
        :param local_cache: Optional local copy of ``path`` that is consulted first for each file.
        :param thread_count: Number of threads used to fetch data; defaults to ``default_thread_count()``.

        :raises ValueError: If ``thread_count`` is given and is not positive.
        """
        super().__init__()
        if thread_count is not None and thread_count <= 0:
            raise ValueError("thread count must be at least 1")
        # Always a ``str`` without a trailing slash, so relative paths can be joined with "/".
        self.path = str(path).rstrip("/")
        self.cache = None if local_cache is None else Path(local_cache)
        self.thread_count = thread_count or default_thread_count()
        self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict()
        self._metadata: Optional[Metadata] = None

    def _get_bytes(self, relative_path: str, offset: int, length: int) -> bytes:
        """Return ``length`` bytes starting at ``offset`` of ``relative_path``, preferring the local cache."""
        if self.cache is not None and (path := self.cache / relative_path).is_file():
            return get_bytes_range(path, offset, length)
        else:
            return get_bytes_range(f"{self.path}/{relative_path}", offset, length)

    def _get_content_for_read(self, read_item: ReadItem) -> Tuple[ReadItem, bytes]:
        """Fetch the raw bytes backing ``read_item`` according to the storage info from the metadata."""
        sinfo = self.storage_data[read_item.storage_index]
        content = self._get_bytes(sinfo.relative_path, sinfo.offset, sinfo.length)
        return (read_item, content)

    def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlanner) -> Future[None]:
        """
        Fetch every item of ``plan`` concurrently, then hand the results to ``planner``: byte items go to
        ``planner.load_bytes()``, tensors are copied into the planner's target tensors.
        """
        # Create the global S3 client up front to work around a threading issue in boto.
        # (``self.path`` is always a ``str``; see ``__init__``.)
        if self.path.startswith("s3://"):
            _get_s3_client("s3")
        elif self.path.startswith("r2://"):
            _get_s3_client("r2")

        with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
            read_item_content_futures = [
                executor.submit(self._get_content_for_read, read_item) for read_item in plan.items
            ]
            read_item_content_results = []
            for f in as_completed(read_item_content_futures):
                try:
                    read_item_content_results.append(f.result())
                except BaseException:
                    # NOTE: we might get an error here that can't be pickled, which causes a different failure
                    # later when PyTorch tries to reduce that error across ranks. So here we just make
                    # sure we're raising a simple error type that can be pickled.
                    raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")

        # Modified from `FileSystemReader.read_data()`
        for read_item, content in read_item_content_results:
            # A fresh BytesIO is already positioned at 0, so no explicit seek is needed.
            buffer = io.BytesIO(content)
            if read_item.type == LoadItemType.BYTE_IO:
                planner.load_bytes(read_item, buffer)
            else:
                tensor = cast(torch.Tensor, torch.load(buffer, map_location="cpu"))
                tensor = narrow_tensor_by_index(tensor, read_item.storage_offsets, read_item.lengths)
                target_tensor = planner.resolve_tensor(read_item).detach()

                assert (
                    target_tensor.size() == tensor.size()
                ), f"req {read_item.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
                target_tensor.copy_(tensor)
                planner.commit_tensor(read_item, target_tensor)

        fut: Future = Future()
        fut.set_result(None)
        return fut

    def read_metadata(self) -> Metadata:
        """Load (once) and return the checkpoint's ``.metadata`` file."""
        if self._metadata is None:
            # NOTE(security): ``.metadata`` is unpickled, which can execute arbitrary code. Only point this
            # reader at checkpoints from trusted sources.
            with resource_path(self.path, ".metadata", local_cache=self.cache).open("rb") as metadata_file:
                self._metadata = pickle.load(metadata_file)
        return self._metadata

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        """Remember the per-item storage locations recorded in ``metadata``."""
        del is_coordinator
        self.storage_data = metadata.storage_data
        assert self.storage_data is not None

    def prepare_local_plan(self, plan: dist_cp.LoadPlan) -> dist_cp.LoadPlan:
        # No storage-specific adjustments are needed.
        return plan

    def prepare_global_plan(self, global_plan: List[dist_cp.LoadPlan]) -> List[dist_cp.LoadPlan]:
        # No storage-specific adjustments are needed.
        return global_plan


class Checkpointer(metaclass=ABCMeta):
    """
    Abstract base for the checkpoint save/restore strategies.

    Concrete subclasses implement :meth:`save_checkpoint` and :meth:`restore_checkpoint`;
    :meth:`unshard_checkpoint` is optional and raises ``NotImplementedError`` unless overridden.
    """

    def __init__(self, cfg: TrainConfig, thread_count: Optional[int] = None):
        self.cfg = cfg
        # Any falsy ``thread_count`` (``None`` or ``0``) falls back to ``default_thread_count()``.
        self.thread_count = thread_count if thread_count else default_thread_count()

    @abstractmethod
    def save_checkpoint(
        self,
        dir: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        train_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        """
        Save model, optimizer and trainer state to ``dir`` (optionally also uploading to ``upload_to``).
        """
        raise NotImplementedError

    @abstractmethod
    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        """
        Restores a checkpoint to the model and optimizer. Returns the remaining trainer state.
        """
        raise NotImplementedError

    def unshard_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        load_trainer_state: bool = True,
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """
        Unshard a checkpoint into ``(model_state, optim_state, trainer_state)``.

        Deliberately not abstract: subclasses that cannot unshard simply inherit this
        ``NotImplementedError``.
        """
        del load_path, local_cache, load_optimizer_state, load_trainer_state, device
        raise NotImplementedError

    @contextmanager
    def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]:
        """
        Yield a temporary ``<dir>-tmp`` directory to write a checkpoint into, then move it onto ``dir``.

        Raises ``FileExistsError`` if ``dir`` is non-empty and ``cfg.save_overwrite`` is off. File-system
        local rank 0 does all deletions/creations/renames; ``barrier()`` calls keep ranks in step.
        """
        final_dir = Path(dir)

        # Refuse to clobber an existing, non-empty checkpoint unless overwriting is enabled.
        if not dir_is_empty(final_dir):
            if not self.cfg.save_overwrite:
                raise FileExistsError(final_dir)
            if get_fs_local_rank() == 0:
                shutil.rmtree(final_dir, ignore_errors=True)
        # ``final_dir`` is intentionally not created: the temp directory gets renamed onto it at the end.
        barrier()

        # The temp directory is disposable, so anything left over from a previous attempt is wiped.
        tmp_dir = final_dir.with_name(f"{final_dir.name}-tmp")
        if get_fs_local_rank() == 0:
            shutil.rmtree(tmp_dir, ignore_errors=True)
            tmp_dir.mkdir(exist_ok=True, parents=True)
        barrier()

        # The caller (``save_checkpoint``) writes everything into the temp directory.
        yield tmp_dir

        barrier()

        # Everything succeeded: promote the temp directory to the real checkpoint directory.
        if get_fs_local_rank() == 0:
            try:
                tmp_dir.replace(final_dir)
            except FileNotFoundError:
                # Another file-system-local rank 0 already did the rename; this happens when nodes share
                # an NFS drive for checkpoints but otherwise have separate file systems.
                if not final_dir.exists():
                    raise

        # On a shared NFS drive the rename may not be visible on every rank immediately, so all ranks
        # wait until the final directory shows up.
        wait_for(final_dir.exists, "Waiting for checkpoint directory", timeout=10.0)
        barrier()

    def _save_config(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
        """
        On global rank 0 only, write ``config.yaml`` into ``dir`` and, if ``upload_to`` is set, upload it
        to ``{upload_to}/config.yaml``.
        """
        if get_global_rank() != 0:
            return
        log.info("Saving config...")
        config_path = Path(dir) / "config.yaml"
        self.cfg.save(config_path)
        if upload_to is None:
            return
        upload_target = f"{upload_to}/config.yaml"
        log.info(f"Uploading {config_path} to {upload_target}")
        upload(config_path, upload_target, save_overwrite=self.cfg.save_overwrite)


class FullCheckpointer(Checkpointer):
    """
    A :class:`Checkpointer` that saves a single full model and optimizer state dictionary.

    Files written into the checkpoint directory: ``model.pt``, ``optim.pt``, ``train.pt`` (all written by
    global rank 0) and ``config.yaml``.
    """

    def save_checkpoint(
        self,
        dir: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        trainer_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        """
        Gather the full model and optimizer state and write them, plus ``trainer_state`` and the config,
        into ``dir``.

        State dicts are gathered with ``rank0_only=True`` / ``offload_to_cpu=True``, so only global rank 0
        writes files. Every rank is expected to call this, since the FSDP state-dict calls and the
        ``barrier()`` calls run on all ranks.
        """
        with self._temporary_wd(dir) as checkpoint_dir:
            with FSDP.state_dict_type(
                fsdp_model,
                state_dict_type=StateDictType.FULL_STATE_DICT,
                state_dict_config=FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
                optim_state_dict_config=FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
            ):
                # We'll write the model and optimizer state dicts individually to reduce (CPU) memory consumption.
                # First the model state.
                model_state_dict = fsdp_model.state_dict()
                if get_global_rank() == 0:
                    log.info("Saving model state...")
                    save_state_dict(
                        checkpoint_dir,
                        "model.pt",
                        model_state_dict,
                        upload_to=upload_to,
                        save_overwrite=self.cfg.save_overwrite,
                        synchronize=False,
                    )
                del model_state_dict
                barrier()

                # Then the optimizer state.
                optim_state_dict = FSDP.optim_state_dict(fsdp_model, optim)
                if get_global_rank() == 0:
                    log.info("Saving optim state...")
                    save_state_dict(
                        checkpoint_dir,
                        "optim.pt",
                        optim_state_dict,
                        upload_to=upload_to,
                        save_overwrite=self.cfg.save_overwrite,
                        synchronize=False,
                    )
                del optim_state_dict
                barrier()

            # Save trainer state.
            if get_global_rank() == 0:
                log.info("Saving trainer state...")
                save_state_dict(
                    checkpoint_dir,
                    "train.pt",
                    trainer_state,
                    upload_to=upload_to,
                    save_overwrite=self.cfg.save_overwrite,
                    synchronize=False,
                )
            # Save config.
            self._save_config(checkpoint_dir, upload_to=upload_to)

    @staticmethod
    def _iter_flat_params(fsdp_model: FSDP) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield ``(module_name, param)`` for every entry of ``module.params`` of every FSDP module in the model."""
        for module_name, module in fsdp_model.named_modules():
            if isinstance(module, FSDP):
                for param in module.params:
                    yield module_name, param

    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        """
        Load ``model.pt`` (and ``optim.pt`` if ``load_optimizer_state``) into this rank's FSDP shards and
        return the trainer state from ``train.pt`` (falling back to the legacy ``other.pt``).

        Raises ``ValueError`` if any flat param still contains NaN after restoring, which indicates that
        some parameter was not found in / not copied from the checkpoint.
        """
        with FSDP.state_dict_type(
            fsdp_model,
            state_dict_type=StateDictType.FULL_STATE_DICT,
            state_dict_config=FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
            optim_state_dict_config=FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True),
        ):
            with torch.no_grad():
                # Fill everything with NaN so we can check afterwards that every parameter has been restored.
                for _, param in self._iter_flat_params(fsdp_model):
                    param.fill_(torch.nan)

                # Restore params from the checkpoint.
                state_dict_to_load = load_state_dict(
                    load_path, "model.pt", local_cache=local_cache, map_location="cpu"
                )
                (
                    state_dict_to_load,
                    og_keys_to_new,
                ) = fsdp_model._fsdp_wrapped_module._make_state_dict_compatible(state_dict_to_load)

                for module_name, param in self._iter_flat_params(fsdp_model):
                    assert param._is_flat_param
                    # Copy only the slice of each original parameter that lives in this rank's shard.
                    for fqn, spi in zip(param._fqns, param._shard_param_infos):
                        if not spi.in_shard:
                            continue
                        key = f"{module_name}.{fqn}".replace("_fsdp_wrapped_module.", "").lstrip(".")
                        t = state_dict_to_load[key].flatten()
                        # ``intra_param_end_idx`` is inclusive, hence the ``+ 1``.
                        param[spi.offset_in_shard : spi.offset_in_shard + spi.numel_in_shard].copy_(
                            t[spi.intra_param_start_idx : spi.intra_param_end_idx + 1]
                        )
                # The full model state dict is no longer needed; free it before the (full) optimizer state
                # is loaded to keep peak CPU memory down.
                del state_dict_to_load

                # Make sure that every parameter has been restored.
                for module_name, param in self._iter_flat_params(fsdp_model):
                    if torch.isnan(param).any():
                        raise ValueError(
                            f"Module '{module_name}' contains NaNs, this is likely a bug restoring from full checkpoints"
                        )

            # Load optimizer state.
            if load_optimizer_state:
                optim_state_dict_to_load = load_state_dict(
                    load_path, "optim.pt", local_cache=local_cache, map_location="cpu"
                )
                optim_state_dict_to_load = self._make_optim_state_dict_compatible(
                    optim_state_dict_to_load,
                    og_keys_to_new,
                )
                load_fsdp_optim_state(fsdp_model, optim, optim_state_dict_to_load)
                del optim_state_dict_to_load

            # Load other state.
            try:
                trainer_state = load_state_dict(load_path, "train.pt", local_cache=local_cache)
            except FileNotFoundError:
                # for backwards compatibility
                trainer_state = load_state_dict(load_path, "other.pt", local_cache=local_cache)
        barrier()
        return trainer_state

    def _make_optim_state_dict_compatible(
        self, optim_state_dict: Dict[str, Any], og_keys_to_new: Dict[str, Set[str]]
    ) -> Dict[str, Any]:
        """
        Rewrite a full optimizer state dict (in place) so its parameter keys match the current model.

        Steps: convert integer param ids to FQNs (using each group's ``param_names``), strip the
        ``_fsdp_wrapped_module.`` prefix, then rename/duplicate entries according to ``og_keys_to_new``
        (original key -> set of new keys). Returns the same dict object.
        """
        prefix = "_fsdp_wrapped_module."
        param_groups = optim_state_dict["param_groups"]
        state = optim_state_dict["state"]

        # This state dict comes in two forms: one where the state keys are integers and one where the
        # keys are fully qualified parameter names. The latter case is easier to deal with here so we
        # first transform the integer key form into the FQN key form. Look at the first non-empty group,
        # so an empty leading param group doesn't raise ``IndexError``.
        first_param = next((group["params"][0] for group in param_groups if group["params"]), None)
        if isinstance(first_param, int):
            id_to_fqn: Dict[int, str] = {}
            for group in param_groups:
                new_param_names = []
                for fqn, param_id in zip(group["param_names"], group["params"]):
                    fqn = fqn.replace(prefix, "")
                    id_to_fqn[param_id] = fqn
                    new_param_names.append(fqn)
                group["param_names"] = new_param_names
                group["params"] = new_param_names
            for param_id in list(state.keys()):
                state[id_to_fqn[param_id]] = state.pop(param_id)
        else:
            # Otherwise we still want to clean up the param names to remove the "_fsdp_wrapped_module." prefix.
            for group in param_groups:
                group["param_names"] = [fqn.replace(prefix, "") for fqn in group["param_names"]]
                group["params"] = [fqn.replace(prefix, "") for fqn in group["params"]]
                assert group["param_names"] == group["params"]
            for key in list(state.keys()):
                state[key.replace(prefix, "")] = state.pop(key)

        # Now rename parameters according to `og_keys_to_new`. First fix param names in the state; every
        # new key after the first gets its own deep copy so the entries don't alias each other.
        for og_key, new_keys in og_keys_to_new.items():
            og_state = state.pop(og_key, None)
            if og_state is None:
                continue
            for i, new_key in enumerate(new_keys):
                if i == len(new_keys) - 1:
                    state[new_key] = og_state
                else:
                    state[new_key] = deepcopy(og_state)

        # Then fix param names in the param groups. ``new_keys`` is a set, and string hashing (and thus set
        # iteration order) can differ between processes, so sort to get the same order on every rank.
        for group in param_groups:
            new_names = [new_key for og_key in group["params"] for new_key in sorted(og_keys_to_new[og_key])]
            group["params"] = new_names
            group["param_names"] = new_names

        return optim_state_dict

    def load_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]]]:
        """
        Load the full model state (and, if requested, optimizer state) without touching any FSDP model.

        Tensors are mapped to ``device`` (CPU by default). Returns ``(model_state, optim_state)`` where
        ``optim_state`` is ``None`` when ``load_optimizer_state`` is false.
        """
        device = device if device is not None else torch.device("cpu")
        model_state = load_state_dict(load_path, "model.pt", local_cache=local_cache, map_location=device)  # type: ignore
        optim_state = None
        if load_optimizer_state:
            optim_state = load_state_dict(load_path, "optim.pt", local_cache=local_cache, map_location=device)  # type: ignore
        return model_state, optim_state


class TorchNewStyleShardedCheckpointer(Checkpointer):
    """
    A sharded :class:`Checkpointer` that uses PyTorch's new distributed checkpointing functionality.

    Model/optimizer state goes through ``save_fsdp_model_and_optim_state`` /
    ``load_fsdp_model_and_optim_state``; each rank's trainer state is a separate ``train/rank{N}.pt`` file.
    """

    def save_checkpoint(
        self,
        dir: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        trainer_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        """Write model/optimizer state, this rank's trainer state and the config into ``dir``."""
        with self._temporary_wd(dir) as checkpoint_dir:
            overwrite = self.cfg.save_overwrite

            # Sharded model and optimizer state.
            save_fsdp_model_and_optim_state(
                checkpoint_dir, fsdp_model, optim, upload_to=upload_to, save_overwrite=overwrite
            )

            # Per-rank trainer state.
            log.info("Saving trainer state...")
            rank_train_file = f"train/rank{get_global_rank()}.pt"
            save_state_dict(
                checkpoint_dir, rank_train_file, trainer_state, upload_to=upload_to, save_overwrite=overwrite
            )

            # Config (global rank 0 only, handled inside the helper).
            self._save_config(checkpoint_dir, upload_to=upload_to)

    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        """
        Load model (and optionally optimizer) state in place, then return this rank's trainer state.

        If ``train/rank{N}.pt`` is missing (e.g. restoring with a different world size), rank 0's trainer
        state is used instead.
        """
        log.info("Loading model and optimizer state...")
        load_fsdp_model_and_optim_state(
            load_path, fsdp_model, optim, local_cache=local_cache, load_optimizer_state=load_optimizer_state
        )

        log.info("Loading trainer state...")
        rank_train_file = f"train/rank{get_global_rank()}.pt"
        try:
            trainer_state = load_state_dict(load_path, rank_train_file, local_cache=local_cache)
        except FileNotFoundError:
            trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)
        barrier()
        return trainer_state


class TorchLegacyShardedCheckpointer(Checkpointer):
    """
    A sharded :class:`Checkpointer` that just uses `torch.save()` with extra logic for handling FSDP model
    and optim state.

    The world size must be kept consistent when using this checkpointer.

    Each rank writes ``rank{N}.pt`` containing ``{"model": ..., "optim": ..., **trainer_state}`` where the
    model/optim entries are FSDP sharded state dicts. :meth:`unshard_checkpoint` reassembles these files
    into full tensors in a single process, staging shard data in POSIX shared memory.
    """

    def save_checkpoint(
        self,
        dir: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        trainer_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        """Save this rank's sharded model/optim state plus ``trainer_state`` to ``rank{N}.pt``, then the config."""
        with self._temporary_wd(dir) as checkpoint_dir:
            with FSDP.state_dict_type(
                fsdp_model,
                state_dict_type=StateDictType.SHARDED_STATE_DICT,
                state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
                optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
            ):
                state_dict = {
                    "model": fsdp_model.state_dict(),
                    "optim": FSDP.optim_state_dict(fsdp_model, optim),
                    **trainer_state,
                }
                save_state_dict(
                    checkpoint_dir,
                    f"rank{get_global_rank()}.pt",
                    state_dict,
                    upload_to=upload_to,
                    save_overwrite=self.cfg.save_overwrite,
                )

            # Save config.
            self._save_config(checkpoint_dir, upload_to=upload_to)

    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        """
        Load this rank's ``rank{N}.pt`` into the model (and the optimizer if ``load_optimizer_state``) and
        return the remaining entries (everything except ``model``/``optim``) as the trainer state.
        """
        with FSDP.state_dict_type(
            fsdp_model,
            state_dict_type=StateDictType.SHARDED_STATE_DICT,
            state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
            optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
        ):
            # Deserialize state dict.
            state_dict = load_state_dict(
                load_path, f"rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
            )

            # Load model and optimizer state.
            log.info("Loading model state...")
            fsdp_model.load_state_dict(state_dict["model"])
            del state_dict["model"]
            if load_optimizer_state:
                log.info("Loading optimizer state...")
                load_fsdp_optim_state(fsdp_model, optim, state_dict["optim"])
            del state_dict["optim"]

        barrier()
        return state_dict

    def unshard_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        load_trainer_state: bool = True,
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """
        Reassemble all ``rank*.pt`` files in the local directory ``load_path`` into full tensors on
        ``device`` (CPU by default). The ``rng`` entry is dropped from the trainer state.
        """
        assert local_cache is None, "this method currently only supports local files"
        full_state_dict = self._unshard(load_path, device or torch.device("cpu"), skip_keys={"rng"})
        model_state = full_state_dict.pop("model")
        optim_state = full_state_dict.pop("optim")
        return (
            model_state,
            optim_state if load_optimizer_state else None,
            full_state_dict if load_trainer_state else None,
        )

    @staticmethod
    def _shard_numel(shard_md: ShardMetadata) -> int:
        """Number of elements in a shard: the product of its ``shard_sizes`` (1 for a 0-dim shard)."""
        return reduce((lambda x, y: x * y), shard_md.shard_sizes, 1)  # type: ignore[attr-defined]

    @staticmethod
    def _shared_memory_name(key: Tuple, rank: int) -> str:
        """
        Name of the shared-memory segment holding ``rank``'s data for the tensor at ``key``.

        Every key component goes through ``str()`` so integer components (list indices, integer dict keys)
        can't break the join. Both the writer and the reader use this to agree on the name.
        """
        return "-".join(str(part) for part in key + (rank,))

    def _copy_sharded_tensors_to_shared_mem(self, state: Dict, world_size: int, rank: int, key: Optional[Tuple]):
        """Recursively walk ``state`` and copy every ``ShardedTensor`` found into shared memory."""
        key = tuple() if key is None else key
        if isinstance(state, (list, tuple, set)):
            for i, sub_state in enumerate(state):
                self._copy_sharded_tensors_to_shared_mem(sub_state, world_size, rank, key + (i,))
        elif isinstance(state, dict):
            for name in state.keys():
                self._copy_sharded_tensors_to_shared_mem(state[name], world_size, rank, key + (name,))
        elif isinstance(state, ShardedTensor):
            self._copy_sharded_tensor_to_shared_mem(state, world_size, rank, key)

    def _get_shard_placement_and_rank_sizes(
        self, shards_metadata: List[ShardMetadata], world_size: int
    ) -> Tuple[Dict[ShardMetadata, Tuple[int, int]], List[int]]:
        """
        Map each shard to ``(rank, offset within that rank's flat buffer)`` and compute the total number of
        elements each rank holds. Raises ``RuntimeError`` if a shard is placed on a rank >= ``world_size``.
        """
        rank_sizes = [0 for _ in range(world_size)]
        shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
        for shard_md in shards_metadata:
            shard_rank = cast(_remote_device, shard_md.placement).rank()
            assert shard_rank is not None
            if shard_rank >= world_size:
                raise RuntimeError(f"Shard rank {shard_rank} exceeds world size {world_size}")

            shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
            rank_sizes[shard_rank] += self._shard_numel(shard_md)

        return shard_placement, rank_sizes

    def _copy_sharded_tensor_to_shared_mem(
        self, sharded_tensor: ShardedTensor, world_size: int, rank: int, key: Tuple
    ) -> Any:
        """
        Copy ``rank``'s local shards of ``sharded_tensor`` into a newly created fp32 shared-memory segment.

        Nothing is created when the rank holds no data. The segment is closed but intentionally not
        unlinked here: the unsharding process attaches to it by name and unlinks it.
        """
        metadata = sharded_tensor.metadata()
        shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
            metadata.shards_metadata, world_size
        )

        rank_size = rank_sizes[rank]
        assert rank_size >= 0
        if rank_size == 0:
            return

        assert metadata.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
        numpy_type = np.float32

        shm = shared_memory.SharedMemory(
            create=True, size=rank_size * np.dtype(numpy_type).itemsize, name=self._shared_memory_name(key, rank)
        )
        try:
            np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)
            for local_shard in sharded_tensor.local_shards():
                shard_rank = cast(_remote_device, local_shard.metadata.placement).rank()
                assert shard_rank == rank

                src = local_shard.tensor.flatten()
                shard_offset = shard_placement[local_shard.metadata][1]
                np_arr[shard_offset : shard_offset + src.numel()] = src.numpy()
            # Release our view of the buffer before detaching from the segment.
            del np_arr
        finally:
            shm.close()

    def _copy_sharded_data_to_shared_mem(self, world_size: int, shard_filepath: Path):
        """Worker entry point: load one ``rank{N}.pt`` file and copy its sharded tensors to shared memory."""
        shard_number = int(shard_filepath.name[4:-3])
        log.info("Starting unsharding shard number %d to shared memory", shard_number)

        with self._patch_sharded_tensor_load():
            shard = torch.load(shard_filepath, map_location="cpu")
            log.debug("Done loading shard number %d", shard_number)

        self._copy_sharded_tensors_to_shared_mem(
            shard, world_size, shard_number, (str(shard_filepath.parent).replace("/", "_"),)
        )
        log.info("Done unsharding shard number %d to shared memory", shard_number)

    def _unshard_using_sharded_mem(
        self, state: Any, world_size: int, device: torch.device, shard_dir: PathOrStr
    ) -> Any:
        """Unshard ``state`` using the segments whose key prefix is derived from ``shard_dir``."""
        return self._unshard_state_using_shared_mem(state, world_size, device, (str(shard_dir).replace("/", "_"),))

    def _unshard_state_using_shared_mem(
        self, state: Any, world_size: int, device: torch.device, key: Tuple
    ) -> Any:
        """Recursively rebuild ``state``, replacing each ``ShardedTensor`` with its full tensor on ``device``."""
        if isinstance(state, (list, tuple, set)):
            return state.__class__(
                self._unshard_state_using_shared_mem(sub_state, world_size, device, key + (i,))
                for i, sub_state in enumerate(state)
            )
        elif isinstance(state, dict):
            return {
                name: self._unshard_state_using_shared_mem(state[name], world_size, device, key + (name,))
                for name in state.keys()
            }
        elif isinstance(state, ShardedTensor):
            return self._unshard_tensor_using_shared_mem(state, world_size, device, key)
        elif isinstance(state, torch.Tensor):
            return state.to(device=device)
        else:
            return state

    def _unshard_tensor_using_shared_mem(
        self, sharded_tensor: ShardedTensor, world_size: int, device: torch.device, key: Tuple
    ) -> torch.Tensor:
        """
        Reassemble ``sharded_tensor`` into a full tensor on ``device`` from the per-rank shared-memory
        segments, then close and unlink those segments.
        """
        metadata = sharded_tensor.metadata()
        shard_placement, rank_sizes = self._get_shard_placement_and_rank_sizes(
            metadata.shards_metadata, world_size
        )

        assert metadata.tensor_properties.dtype == torch.float32, "Expected sharded tensor to be fp32"
        numpy_type = np.float32

        out = torch.empty(metadata.size, dtype=metadata.tensor_properties.dtype, device=device)
        dims = len(metadata.size)

        # A rank may hold several shards of the same tensor, so attach to each rank's segment at most once
        # and only unlink after every shard has been copied out.
        segments: Dict[int, shared_memory.SharedMemory] = {}
        try:
            for shard_md, (rank, rank_offset) in shard_placement.items():
                rank_size = rank_sizes[rank]
                assert rank_size >= 0
                if rank_size == 0:
                    # No segment is ever created for a rank without data, so there is nothing to attach to.
                    continue

                shm = segments.get(rank)
                if shm is None:
                    shm = shared_memory.SharedMemory(name=self._shared_memory_name(key, rank))
                    segments[rank] = shm

                np_arr = np.ndarray((rank_size,), dtype=numpy_type, buffer=shm.buf)
                shard_numel = self._shard_numel(shard_md)
                tensor = torch.from_numpy(np_arr)[rank_offset : rank_offset + shard_numel]
                tensor = tensor.view(shard_md.shard_sizes)

                out_narrow_view = out
                for dim in range(dims):
                    out_narrow_view = out_narrow_view.narrow(
                        dim,
                        shard_md.shard_offsets[dim],
                        shard_md.shard_sizes[dim],
                    )
                out_narrow_view.copy_(tensor)

                # Drop views into the shared buffer before the segment gets closed.
                del tensor, np_arr
        finally:
            for shm in segments.values():
                shm.close()
                shm.unlink()

        return out

    @contextmanager
    def _patch_sharded_tensor_load(self):
        """
        Monkeypatch for torch's ShardedTensor, so we can unpickle without having torch.distributed set up.
        """

        def _rebuild_from_type_v2_monkey(func, new_type, args, state):
            ret = func(*args)
            if type(ret) is not new_type:
                ret = ret.as_subclass(new_type)

            # Shortcut the construction of ShardedTensor
            # This is in the top 5 of my worst hacks.
            if isinstance(ret, ShardedTensor):
                ret._local_shards, ret._metadata, _, ret._sharding_spec, ret._init_rrefs = state
                return ret

            # The rest of this function ought to be in the top 5 of somebody else's worst hacks.
            # Tensor does define __setstate__ even though it doesn't define
            # __getstate__. So only use __setstate__ if it is NOT the one defined
            # on Tensor
            if getattr(ret.__class__, "__setstate__", torch.Tensor.__setstate__) is not torch.Tensor.__setstate__:
                ret.__setstate__(state)
            else:
                ret = torch._utils._set_obj_state(ret, state)
            return ret

        original_rebuild_from_type_v2 = torch._tensor._rebuild_from_type_v2
        try:
            torch._tensor._rebuild_from_type_v2 = _rebuild_from_type_v2_monkey
            yield
        finally:
            torch._tensor._rebuild_from_type_v2 = original_rebuild_from_type_v2

    def _unshard(self, input_dir: PathOrStr, device: torch.device, skip_keys: Optional[Set[str]] = None):
        """
        The current unsharding implementation consists of:

        1. Loading each shard on a separate process and copying their sharded tensors to shared memory.
        2. Loading 1 shard on the main process as a base unsharded object.
        3. Using the sharded tensors in shared memory to populate the base unsharded object.

        This implementation replaced a prior implementation that instead loaded
        all shards using threads, because that implementation turned out to
        be extremely slow (e.g. 6+ hours) sometimes when the world size was 1024.
        The current implementation is slower than the old one in many scenarios,
        but is significantly faster in the above mentioned case (e.g. 30 minutes)
        if there are enough CPUs.

        Top-level keys in ``skip_keys`` are dropped from the result.
        """

        input_dir = Path(input_dir)
        skip_keys = skip_keys or set()

        shard_filepaths = list(input_dir.glob("rank*.pt"))
        world_size = len(shard_filepaths)
        if world_size == 0:
            raise RuntimeError("No shards found for unsharding")

        log.info("Number of shards: %d", world_size)
        shard_size_gb = shard_filepaths[0].stat().st_size / (1024 * 1024 * 1024)
        min_ram_required_estimate_gb = shard_size_gb * world_size
        log.info(
            "Shards are %.2fGB each, at least %.2fGB RAM is required", shard_size_gb, min_ram_required_estimate_gb
        )

        # Validate shard file names before spawning any worker processes.
        for shard_filepath in shard_filepaths:
            shard_rank = int(shard_filepath.name[4:-3])
            if shard_rank >= world_size:
                raise RuntimeError(
                    f"Shard rank {shard_rank} of file {shard_filepath} exceeds world size {world_size}"
                )

        log.info("Copying sharded tensors to shared memory using multiple processes")
        # Copy sharded data to shared memory using multiple processes, so this process can load
        # from memory rather than disk. We spawn a new process instead of forking since shared memory
        # appears to get deleted when forked processes end for some reason.
        # The context manager guarantees the pool is shut down even if a worker fails.
        with ProcessPoolExecutor(
            mp_context=mp.get_context("spawn"), initializer=util.prepare_cli_environment
        ) as executor:
            futures = [
                executor.submit(self._copy_sharded_data_to_shared_mem, world_size, shard_filepath)
                for shard_filepath in shard_filepaths
            ]
            for f in as_completed(futures):
                f.result()

        log.info("Loading a shard on the main process to be unsharded state")
        with self._patch_sharded_tensor_load():
            state = torch.load(shard_filepaths[0], map_location="cpu")

        for key in skip_keys:
            if key in state:
                del state[key]

        log.info("Unsharding from %d shards ...", world_size)
        return self._unshard_using_sharded_mem(state, world_size, device, input_dir)


@dataclass
class _LocalShardedCheckpointerMetadata(BaseConfig):
    """
    Small metadata record used by the local sharded checkpointer; it currently holds only the world size.
    """

    # Defaults to ``get_world_size()``, evaluated each time an instance is created without an explicit value.
    world_size: int = field(default_factory=get_world_size)


@dataclass
class _FlatParamShard:
    full_shape: torch.Size
    shard_offsets: Tuple[int, int]
    shard_data: Optional[torch.Tensor]

    def copy_into(self, full_tensor: torch.Tensor) -> None:
        assert self.shard_data is not None
        full_tensor_shard_view = full_tensor.view(-1)[self.shard_offsets[0] : self.shard_offsets[1] + 1]
        assert self.shard_data.shape == full_tensor_shard_view.shape
        full_tensor_shard_view.copy_(self.shard_data)


class LocalShardedCheckpointer(Checkpointer):
    """
    A sharded :class:`Checkpointer` that directly saves the local FSDP flat params data.
    The optimizer state is saved directly with `torch.save()` without reformatting via FSDP methods.

    The world size must be kept consistent when using this checkpointer. However, you can easily
    reconstruct a full unsharded model and/or optimizer state dictionary from a single Python process
    using :meth:`unshard_checkpoint()` (no distributed initialization required).

    Checkpoint layout written by :meth:`save_checkpoint()`:
    ``model/rank{r}.pt``, ``optim/rank{r}.pt``, ``train/rank{r}.pt`` (one of each per rank),
    plus ``metadata.yaml`` and the config.
    """

    # These correspond to metadata attributes on `torch.distributed.fsdp.flat_param.FlatParameter`.
    _FLAT_PARAM_METADATA_TO_SAVE = (
        "_fqns",
        "_shard_param_offsets",
        "_shard_indices",
        "_numels",
        "_numels_with_padding",
        "_shapes",
        "_shard_numel_padded",
        "_shard_param_infos",
    )

    def _fsdp_modules(self, fsdp_model: FSDP) -> List[Tuple[str, FSDP]]:
        """
        Returns a list of FSDP modules with their FQN, in ``named_modules()`` order.

        The order matters: saved module data is matched back to modules positionally on load.
        """
        return [(name, module) for name, module in fsdp_model.named_modules() if isinstance(module, FSDP)]

    def _prepare_fsdp_model(self, fsdp_model: FSDP) -> None:
        """Synchronize CUDA (if available) and lazily initialize FSDP state before touching flat params."""
        from torch.distributed.fsdp._runtime_utils import _lazy_init

        # TODO (epwalsh): I'm not sure if this is necessary, but this is what PyTorch does before saving/loading
        # an FSDP state dict through the built-in methods.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        _lazy_init(fsdp_model, fsdp_model)

    def _fsdp_handles(self, fsdp_model: FSDP) -> List[FlatParamHandle]:
        """
        Returns the flat param handles managed directly by ``fsdp_model``.

        Where the handle(s) live is a private FSDP detail that differs across torch versions.

        :raises NotImplementedError: For torch versions whose FSDP internals haven't been verified.
        """
        torch_version = version.parse(torch.__version__)
        if torch_version < version.parse("2.1.0"):
            return fsdp_model._handles  # type: ignore
        elif torch_version < version.parse("2.3.0"):
            # Handle could be None if the FSDP wrapper doesn't manage any parameters.
            if hasattr(fsdp_model, "_handle") and fsdp_model._handle is not None:
                return [fsdp_model._handle]  # type: ignore
            else:
                return []
        else:
            # Need to verify FSDP internals with newer versions.
            raise NotImplementedError(
                f"{self.__class__.__name__} has not been verified with torch {torch.__version__}"
            )

    @torch.no_grad()
    def _get_flat_param_state_to_save(self, fsdp_model: FSDP) -> Dict[str, Any]:
        """
        Collect this rank's local flat param data plus selected flat param metadata.

        :returns: ``{"modules": [{"name": <fsdp module fqn>, "handles": [{"flat_param.data": tensor,
            "flat_param.<attr>": value, ...}, ...]}, ...]}``, where ``<attr>`` ranges over whichever of
            :attr:`_FLAT_PARAM_METADATA_TO_SAVE` exist on the flat param.
        """
        self._prepare_fsdp_model(fsdp_model)
        module_data = []
        for module_fqn, fsdp_module in self._fsdp_modules(fsdp_model):
            handle_data = []
            for handle in self._fsdp_handles(fsdp_module):
                data: Dict[str, Any] = {}
                # This is a `FlatParameter` instance.
                # See `torch.distributed.fsdp.flat_param` for the API.
                flat_param = handle.flat_param
                data["flat_param.data"] = flat_param.detach()
                for key in self._FLAT_PARAM_METADATA_TO_SAVE:
                    if hasattr(flat_param, key):
                        data[f"flat_param.{key}"] = getattr(flat_param, key)
                handle_data.append(data)
            module_data.append({"handles": handle_data, "name": module_fqn})
        return {"modules": module_data}

    @torch.no_grad()
    def _load_flat_param_state(self, fsdp_model: FSDP, model_state: Dict[str, Any]):
        """
        Load the state produced from `self._get_flat_param_state_to_save()`.

        Modules and handles are matched positionally, and the saved flat param metadata must equal
        the live metadata, i.e. the model must be sharded exactly as it was when saved.
        """
        self._prepare_fsdp_model(fsdp_model)
        fsdp_modules = self._fsdp_modules(fsdp_model)
        assert len(model_state["modules"]) == len(
            fsdp_modules
        ), "number of FSDP modules in the model does not match the checkpoint"
        for (module_fqn, fsdp_module), module_data in zip(fsdp_modules, model_state["modules"]):
            handles = self._fsdp_handles(fsdp_module)
            assert len(handles) == len(
                module_data["handles"]
            ), f"number of flat param handles for FSDP module '{module_fqn}' does not match the checkpoint"
            for handle, data in zip(handles, module_data["handles"]):
                flat_param = handle.flat_param
                # Make sure metadata matches.
                for key in self._FLAT_PARAM_METADATA_TO_SAVE:
                    if hasattr(flat_param, key):
                        assert (
                            getattr(flat_param, key) == data[f"flat_param.{key}"]
                        ), f"flat param metadata '{key}' for FSDP module '{module_fqn}' does not match the checkpoint"
                # Load the flat sharded data.
                flat_param.copy_(data["flat_param.data"])

    def _save_metadata(self, dir: PathOrStr, *, upload_to: Optional[str] = None) -> None:
        """Write ``metadata.yaml`` on each filesystem's local rank 0, uploading it from global rank 0."""
        if get_fs_local_rank() == 0:
            log.info("Saving metadata...")
            metadata = _LocalShardedCheckpointerMetadata()
            metadata.save(metadata_path := Path(dir) / "metadata.yaml")
            if upload_to is not None and get_global_rank() == 0:
                upload_target = f"{upload_to}/metadata.yaml"
                log.info(f"Uploading {metadata_path} to {upload_target}")
                upload(metadata_path, upload_target, save_overwrite=self.cfg.save_overwrite)

    def _load_metadata(
        self, load_path: PathOrStr, *, local_cache: Optional[PathOrStr] = None
    ) -> _LocalShardedCheckpointerMetadata:
        """Load ``metadata.yaml`` from the checkpoint at ``load_path``."""
        metadata_path = resource_path(load_path, "metadata.yaml", local_cache=local_cache)
        return _LocalShardedCheckpointerMetadata.load(metadata_path)

    def save_checkpoint(
        self,
        dir: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        trainer_state: Dict[str, Any],
        *,
        upload_to: Optional[str] = None,
    ) -> None:
        """Save this rank's model, optimizer, and trainer state, followed by metadata and config."""
        with self._temporary_wd(dir) as checkpoint_dir:
            # Gather local FSDP flat params data to save.
            # We also save some flat param metadata like the corresponding fully qualified names (fqns)
            # of each original parameter so we can validate that the sharding is the same when loading
            # one of these checkpoints.
            log.info("Saving local FSDP flat params data...")
            save_state_dict(
                checkpoint_dir,
                f"model/rank{get_global_rank()}.pt",
                self._get_flat_param_state_to_save(fsdp_model),
                upload_to=upload_to,
                save_overwrite=self.cfg.save_overwrite,
            )

            # Save optimizer state.
            log.info("Saving local optimizer state...")
            save_state_dict(
                checkpoint_dir,
                f"optim/rank{get_global_rank()}.pt",
                optim.state_dict(),
                upload_to=upload_to,
                save_overwrite=self.cfg.save_overwrite,
            )

            # Save trainer state.
            log.info("Saving trainer state...")
            save_state_dict(
                checkpoint_dir,
                f"train/rank{get_global_rank()}.pt",
                trainer_state,
                upload_to=upload_to,
                save_overwrite=self.cfg.save_overwrite,
            )

            # Save metadata.
            self._save_metadata(checkpoint_dir, upload_to=upload_to)

            # Save config. We do this last b/c the presence of a config in a remote checkpoint
            # "directory" indicates that the folder is valid, as opposed to a partially
            # uploaded checkpoint directory that failed before completing.
            self._save_config(checkpoint_dir, upload_to=upload_to)

    def restore_checkpoint(
        self,
        load_path: PathOrStr,
        fsdp_model: FSDP,
        optim: Optimizer,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
    ) -> Dict[str, Any]:
        """
        Restore this rank's model (and optionally optimizer) state in place and return its trainer state.

        :raises RuntimeError: If the checkpoint was saved with a different world size.
        """
        # Load metadata and make sure checkpoint is compatible.
        metadata = self._load_metadata(load_path, local_cache=local_cache)
        world_size = get_world_size()
        if metadata.world_size != world_size:
            # Explicit check (not `assert`) so it isn't stripped under `python -O`: each rank loads
            # exactly the flat param shard it saved, so the world size must match.
            raise RuntimeError(
                f"Checkpoint at {load_path} was saved with world size {metadata.world_size}, "
                f"but the current world size is {world_size}. The world size must be kept "
                f"consistent when using {self.__class__.__name__}."
            )

        # Load local FSDP flat param data.
        log.info("Loading local FSDP flat params data...")
        model_state = load_state_dict(
            load_path, f"model/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
        )
        self._load_flat_param_state(fsdp_model, model_state)
        del model_state

        # Load local optim state.
        if load_optimizer_state:
            log.info("Loading local optimizer state...")
            optim_state = load_state_dict(
                load_path, f"optim/rank{get_global_rank()}.pt", local_cache=local_cache, map_location="cpu"
            )
            # HACK/TODO (epwalsh): When we use adaptive clipping we track the 'grad_norm_exp_avg' for every param
            # in every rank, and keep this in the optimizer state. But this causes issues when loading the
            # state since torch sees the state is non-empty for some params which would normally be empty,
            # and then assumes it should have all of the other state tensors for that param, which it doesn't.
            # So for now we just remove 'grad_norm_exp_avg' everywhere from the state, which resets that metric.
            # Not the end of the world but there's probably a better way around this without resetting
            # the metric.
            for param_id in list(optim_state["state"].keys()):
                state = optim_state["state"][param_id]
                if "grad_norm_exp_avg" in state:
                    del state["grad_norm_exp_avg"]
                if len(state) == 0:
                    del optim_state["state"][param_id]
            optim.load_state_dict(optim_state)
            del optim_state

        # Load local trainer state.
        log.info("Loading local trainer state...")
        trainer_state = load_state_dict(load_path, f"train/rank{get_global_rank()}.pt", local_cache=local_cache)
        barrier()
        return trainer_state

    def _iter_flat_param_shards(
        self, model_state: Dict[str, Any]
    ) -> Generator[Tuple[str, _FlatParamShard], None, None]:
        """
        Yield ``(root_fqn, shard)`` for every original parameter present in one rank's saved model state.

        ``root_fqn`` is the parameter's FQN with ``_fsdp_wrapped_module.`` removed. Both the
        torch <=2.0.1 (``_shard_indices`` / ``_shard_param_offsets``) and torch >=2.1.0
        (``_shard_param_infos``) metadata formats are supported.
        """
        for module_data in model_state["modules"]:
            module_prefix = module_data["name"].replace("_fsdp_wrapped_module.", "")
            for handle in module_data["handles"]:
                flat_data = handle["flat_param.data"]
                if (num_padding := handle["flat_param._shard_numel_padded"]) > 0:
                    # If there's padding in the flat param it should be on the right.
                    assert (flat_data[-num_padding:] == 0).all()
                # NOTE: this changes depending on the torch version, but we don't do a version
                # check since we might be trying to unshard an old checkpoint that was stored
                # with a different torch version than we're currently running with.
                if "flat_param._shard_indices" in handle:
                    # torch <=2.0.1
                    param_start = handle["flat_param._shard_indices"][0]
                    current_flat_index = 0
                    for relative_fqn, full_shape, (offset_start, offset_end) in zip(
                        handle["flat_param._fqns"][param_start:],
                        handle["flat_param._shapes"][param_start:],
                        handle["flat_param._shard_param_offsets"],
                    ):
                        root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
                        # Offsets are inclusive on both ends.
                        numel_shard = offset_end - offset_start + 1
                        flat_param_shard = _FlatParamShard(
                            full_shape=full_shape,
                            shard_offsets=(offset_start, offset_end),
                            shard_data=flat_data[current_flat_index : current_flat_index + numel_shard],
                        )
                        current_flat_index += numel_shard
                        yield root_fqn, flat_param_shard
                else:
                    # torch >=2.1.0
                    for relative_fqn, full_shape, shard_param_info in zip(
                        handle["flat_param._fqns"],
                        handle["flat_param._shapes"],
                        handle["flat_param._shard_param_infos"],
                    ):
                        if not shard_param_info.in_shard:
                            continue
                        root_fqn = relative_fqn if not module_prefix else f"{module_prefix}.{relative_fqn}"
                        flat_param_shard = _FlatParamShard(
                            full_shape=full_shape,
                            shard_offsets=(
                                shard_param_info.intra_param_start_idx,
                                shard_param_info.intra_param_end_idx,
                            ),
                            shard_data=flat_data[
                                shard_param_info.offset_in_shard : shard_param_info.offset_in_shard
                                + shard_param_info.numel_in_shard
                            ],
                        )
                        yield root_fqn, flat_param_shard

    def unshard_checkpoint(
        self,
        load_path: PathOrStr,
        *,
        local_cache: Optional[PathOrStr] = None,
        load_optimizer_state: bool = True,
        load_trainer_state: bool = True,
        device: Optional[torch.device] = None,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """
        Reconstruct the full (unsharded) model state, and optionally the full optimizer state and
        the trainer state, from a checkpoint written by :meth:`save_checkpoint()`.

        :param load_path: The checkpoint directory.
        :param local_cache: Optional local cache, passed through to ``resource_path()``.
        :param load_optimizer_state: Whether to also reconstruct the full optimizer state.
        :param load_trainer_state: Whether to also load the trainer state (taken from rank 0).
        :param device: Device to materialize the full tensors on. Defaults to CPU.

        :returns: ``(model_state, optim_state, trainer_state)``. ``model_state`` maps parameter FQNs
            (with ``_fsdp_wrapped_module.`` removed) to full tensors. ``optim_state`` is ``None`` when
            ``load_optimizer_state=False``; ``trainer_state`` is ``None`` when ``load_trainer_state=False``.

        :raises ValueError: If a materialized model parameter still contains NaNs (NaN is used as
            the sentinel for elements no shard covered).
        """
        device = device or torch.device("cpu")
        metadata = self._load_metadata(load_path, local_cache=local_cache)

        # Gather model state paths, potentially downloading them.
        log.info("Gathering model state dicts...")
        model_state_paths = self._gather_state_dict_paths(
            load_path, "model", metadata.world_size, local_cache=local_cache
        )

        # Load model state dicts one-by-one, materializing and populating the full parameters as we go.
        log.info("Materializing full parameters...")
        full_model_state: Dict[str, torch.Tensor] = {}
        # We keep a copy of the flat param metadata minus the actual tensors so we can reconstruct
        # the full optimizer state below without having to reload the model state dicts.
        flat_params_data: Dict[int, Dict[str, _FlatParamShard]] = defaultdict(dict)
        for rank, path in enumerate(model_state_paths):
            log.info(f"Loading shards from rank {rank}...")
            model_state = torch.load(path, map_location="cpu")
            for root_fqn, flat_param_shard in self._iter_flat_param_shards(model_state):
                if root_fqn not in full_model_state:
                    log.info(
                        f"Materializing full parameter '{root_fqn}' with shape {flat_param_shard.full_shape}..."
                    )
                    assert flat_param_shard.shard_data is not None
                    full_model_state[root_fqn] = torch.empty(
                        flat_param_shard.full_shape, dtype=flat_param_shard.shard_data.dtype, device=device
                    )
                    # Fill with NaNs so we can validate that the whole parameter has been populated
                    # afterwards.
                    full_model_state[root_fqn].fill_(torch.nan)
                # Copy over the local shard to the relevant part of the full parameter.
                full_param = full_model_state[root_fqn]
                log.info(f"Loading rank {rank} shard for '{root_fqn}'...")
                flat_param_shard.copy_into(full_param)
                flat_params_data[rank][root_fqn] = replace(flat_param_shard, shard_data=None)
            # Release this rank's shards before the next rank is loaded, so the previous rank's
            # model state isn't kept alive alongside the next one.
            del model_state

        log.info("Validating full parameters...")
        for key, tensor in full_model_state.items():
            if torch.isnan(tensor).any():
                raise ValueError(f"Parameter '{key}' contains NaNs, this is likely a bug with the unsharder")

        trainer_state: Optional[Dict[str, Any]] = None
        if load_trainer_state:
            trainer_state = load_state_dict(load_path, "train/rank0.pt", local_cache=local_cache)

        if not load_optimizer_state:
            return full_model_state, None, trainer_state

        log.info("Gathering optim state dicts...")
        optim_state_paths = self._gather_state_dict_paths(
            load_path, "optim", metadata.world_size, local_cache=local_cache
        )

        log.info("Materializing full optim state...")
        full_optim_state: Dict[str, Any] = {"state": defaultdict(dict)}
        id_to_fqn: Dict[int, str] = {}
        for rank, path in enumerate(optim_state_paths):
            log.info(f"Loading sharded optim state from rank {rank}...")
            optim_state = torch.load(path, map_location="cpu")

            # Initialize param groups.
            # We assume parameter groups are the same across all ranks.
            # The only thing that differs across ranks is the state for each local sharded param.
            if "param_groups" not in full_optim_state:
                full_optim_state["param_groups"] = optim_state["param_groups"]
            else:
                assert full_optim_state["param_groups"] == optim_state["param_groups"]

            # Map optimizer param IDs to parameter FQNs (with FSDP wrapper prefixes removed, to match
            # the keys of `flat_params_data`).
            if not id_to_fqn:
                for group in full_optim_state["param_groups"]:
                    for fqn, param_id in zip(group["param_names"], group["params"]):
                        id_to_fqn[param_id] = fqn.replace("_fsdp_wrapped_module.", "")

            # Iterate over local shard state and copy into the full state.
            for param_id, shard_state in optim_state["state"].items():
                fqn = id_to_fqn[param_id]
                flat_param_shard = flat_params_data[rank].get(fqn)  # type: ignore[assignment]
                full_state = full_optim_state["state"][param_id]
                for key, shard_value in shard_state.items():
                    assert isinstance(shard_value, torch.Tensor)
                    if shard_value.shape == torch.Size([]):
                        # Add singleton tensors directly to full state. These should be the same across
                        # all ranks.
                        assert key in ("step", "grad_norm_exp_avg")  # sanity check
                        if key not in full_state:
                            full_state[key] = shard_value.to(device)
                        else:
                            assert full_state[key] == shard_value
                    else:
                        # Otherwise we have a sharded param state.
                        # If the corresponding full param state hasn't been materialized yet, do so now.
                        assert flat_param_shard is not None, f"missing flat_params_data for {fqn} from rank {rank}"
                        if key not in full_state:
                            log.info(
                                f"Materializing full state '{key}' for '{fqn}' with shape {flat_param_shard.full_shape}..."
                            )
                            full_state[key] = torch.empty(
                                flat_param_shard.full_shape, dtype=shard_value.dtype, device=device
                            )
                        full_state_value = full_state[key]

                        # Copy over the local shard state to the relevant part of the full parameter state.
                        log.info(f"Loading rank {rank} shard state of '{key}' for '{fqn}'...")
                        replace(flat_param_shard, shard_data=shard_value).copy_into(full_state_value)
            # Release this rank's optimizer state before the next rank is loaded.
            del optim_state

        # Lastly, clean up the parameter names in param groups.
        for group in full_optim_state["param_groups"]:
            group["param_names"] = [n.replace("_fsdp_wrapped_module.", "") for n in group["param_names"]]

        return full_optim_state and (full_model_state, full_optim_state, trainer_state)

    def _get_state_dict_path(
        self,
        load_path: PathOrStr,
        state_dict_type: str,
        rank: int,
        *,
        local_cache: Optional[PathOrStr] = None,
        progress=None,
    ) -> Tuple[int, Path]:
        """
        Resolve ``{state_dict_type}/rank{rank}.pt`` under ``load_path``.

        The rank is returned alongside the path so results can be re-ordered after being
        collected out of order from a thread pool.
        """
        fname = f"{state_dict_type}/rank{rank}.pt"
        return rank, resource_path(str(load_path).rstrip("/"), fname, local_cache=local_cache, progress=progress)

    def _gather_state_dict_paths(
        self,
        load_path: PathOrStr,
        state_dict_type: str,
        world_size: int,
        *,
        local_cache: Optional[PathOrStr] = None,
    ) -> List[Path]:
        """Resolve the per-rank ``state_dict_type`` files concurrently and return their paths ordered by rank."""
        progress = get_progress_bar()
        with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
            futures = []
            for rank in range(world_size):
                future = executor.submit(
                    self._get_state_dict_path,
                    load_path,
                    state_dict_type,
                    rank,
                    local_cache=local_cache,
                    progress=progress,
                )
                futures.append(future)

            results: Dict[int, Path] = {}
            for future in as_completed(futures):
                rank, path = future.result()
                results[rank] = path

        return [results[rank] for rank in range(world_size)]


def build_sharded_checkpointer(
    cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None
) -> Checkpointer:
    """
    Construct the sharded checkpointer implementation selected by ``name``.

    :param cfg: The training config, passed to the checkpointer's constructor.
    :param name: Which implementation to build. Falls back to ``cfg.sharded_checkpointer``
        when not given (or falsy).

    :raises NotImplementedError: If the requested type has no implementation.
    """
    checkpointer_type = name or cfg.sharded_checkpointer
    # Scanned in order with `==`, mirroring a plain if/elif comparison chain.
    implementations = (
        (ShardedCheckpointerType.torch_new, TorchNewStyleShardedCheckpointer),
        (ShardedCheckpointerType.torch_legacy, TorchLegacyShardedCheckpointer),
        (ShardedCheckpointerType.local, LocalShardedCheckpointer),
    )
    for candidate_type, checkpointer_cls in implementations:
        if checkpointer_type == candidate_type:
            return checkpointer_cls(cfg)
    raise NotImplementedError(checkpointer_type)