zero_to_fp32.py
32.5 KB · 761 lines · python Raw
1 #!/usr/bin/env python
2
3 # Copyright (c) Microsoft Corporation.
4 # SPDX-License-Identifier: Apache-2.0
5
6 # DeepSpeed Team
7
# This script extracts fp32 consolidated weights from ZeRO stage 1, 2 and 3 DeepSpeed checkpoints. It gets
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application.
12 #
13 # example:
14 # python zero_to_fp32.py . output_dir/
15 # or
16 # python zero_to_fp32.py . output_dir/ --safe_serialization
17
18 import argparse
19 import torch
20 import glob
21 import math
22 import os
23 import re
24 import gc
25 import json
26 import numpy as np
27 from tqdm import tqdm
28 from collections import OrderedDict
29 from dataclasses import dataclass
30
31 # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32 # DeepSpeed data structures it has to be available in the current python environment.
33 from deepspeed.utils import logger
34 from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35 FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36 FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
38
@dataclass
class zero_model_state:
    """Subset of one ``*_model_states.pt`` file that is needed to rebuild the fp32 weights.

    Instances are created by ``parse_model_states`` (one per model-states file).
    """
    # buffer name -> buffer tensor, already converted to fp32
    buffers: dict
    # list of ``{param_name: shape}`` dicts (consumers iterate it group by group)
    param_shapes: list
    # list of ``[alias_name, source_name]`` pairs; the alias is pointed at the source's tensor
    shared_params: list
    # value stored under ``DS_VERSION`` or ``None`` when absent
    # NOTE(review): it is printed as ``deepspeed=={ds_version}``, so presumably a version string - confirm
    ds_version: int
    # ``{param_name: shape}`` for frozen parameters, or ``None`` when the checkpoint has none
    frozen_param_shapes: dict
    # ``{param_name: tensor}`` fragments of the frozen parameters held by this file, or ``None``
    frozen_param_fragments: dict
47
48
# every checkpoint file is mapped onto the cpu when it is loaded
device = torch.device("cpu")

# verbosity switch: any truthy value makes the helpers below print extra diagnostics
# (the command-line entry point overwrites it with the value of ``--debug``)
debug = 0
53
54
def atoi(text):
    """Return ``int(text)`` when ``text`` consists only of digits, otherwise return ``text`` unchanged."""
    if text.isdigit():
        return int(text)
    return text
57
58
def natural_keys(text):
    '''
    Sort key that orders strings the way a human would ("f2" before "f10"):
    alist.sort(key=natural_keys)

    The text is split into digit and non-digit runs; digit runs become ints, everything
    else stays a string.
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    '''
    pieces = re.split(r'(\d+)', text)
    return [int(piece) if piece.isdigit() else piece for piece in pieces]
66
67
def get_model_state_file(checkpoint_dir, zero_stage):
    """Return the path of the rank-0 model states file inside ``checkpoint_dir``.

    Args:
        - ``checkpoint_dir``: directory that holds the checkpoint files
        - ``zero_stage``: ZeRO stage of the checkpoint; stages up to 2 and stage 3 use different file names

    Raises:
        - ``FileNotFoundError``: ``checkpoint_dir`` is not a directory or the expected file is missing
        - ``ValueError``: ``zero_stage`` is greater than 3
    """
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")

    # there should be only one file
    if zero_stage <= 2:
        file_name = "mp_rank_00_model_states.pt"
    elif zero_stage == 3:
        file_name = "zero_pp_rank_0_mp_rank_00_model_states.pt"
    else:
        # previously this fell through and crashed with an UnboundLocalError below
        raise ValueError(f"unknown zero stage {zero_stage}")

    file = os.path.join(checkpoint_dir, file_name)
    if not os.path.exists(file):
        raise FileNotFoundError(f"can't find model states file at '{file}'")

    return file
82
83
def get_checkpoint_files(checkpoint_dir, glob_pattern):
    """Return the files in ``checkpoint_dir`` matching ``glob_pattern``, in natural (human) order.

    Raises ``FileNotFoundError`` when nothing matches.
    """
    # XXX: need to test that this simple glob rule works for multi-node setup too
    matches = glob.glob(os.path.join(checkpoint_dir, glob_pattern))
    matches.sort(key=natural_keys)

    if not matches:
        raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")

    return matches
92
93
def get_optim_files(checkpoint_dir):
    """Return the naturally sorted ``*_optim_states.pt`` files found in ``checkpoint_dir``."""
    optim_pattern = "*_optim_states.pt"
    return get_checkpoint_files(checkpoint_dir, optim_pattern)
96
97
def get_model_state_files(checkpoint_dir):
    """Return the naturally sorted ``*_model_states.pt`` files found in ``checkpoint_dir``."""
    model_pattern = "*_model_states.pt"
    return get_checkpoint_files(checkpoint_dir, model_pattern)
100
101
def parse_model_states(files):
    """Load every model states file and keep only what is needed to rebuild the fp32 weights.

    Args:
        - ``files``: paths of the ``*_model_states.pt`` files to read

    Returns:
        - list with one ``zero_model_state`` per entry of ``files``, in the same order

    Raises:
        - ``ValueError``: a file has no ``BUFFER_NAMES`` entry, i.e. it is not a model state checkpoint
    """
    zero_model_states = []
    for file in files:
        # NOTE(security): ``weights_only=False`` unpickles arbitrary objects, so only run this on
        # checkpoints that come from a trusted source.
        state_dict = torch.load(file, map_location=device, weights_only=False)

        if BUFFER_NAMES not in state_dict:
            raise ValueError(f"{file} is not a model state checkpoint")
        buffer_names = state_dict[BUFFER_NAMES]
        if debug:
            print("Found buffers:", buffer_names)

        # recover just the buffers while restoring them to fp32 if they were saved in fp16
        buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
        param_shapes = state_dict[PARAM_SHAPES]

        # frozen parameters are optional, both entries are ``None`` when the checkpoint has none
        frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
        if debug and frozen_param_shapes is not None:
            print(f"Found frozen_param_shapes: {frozen_param_shapes}")

        # handle shared params: keep them as [alias, source] pairs
        shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]

        ds_version = state_dict.get(DS_VERSION, None)

        frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)

        zero_model_states.append(
            zero_model_state(buffers=buffers,
                             param_shapes=param_shapes,
                             shared_params=shared_params,
                             ds_version=ds_version,
                             frozen_param_shapes=frozen_param_shapes,
                             frozen_param_fragments=frozen_param_fragments))

    return zero_model_states
146
147
def parse_optim_states(files, ds_checkpoint_dir):
    """Load the optimizer state shards and pull out the fp32 master weight partitions.

    Args:
        - ``files``: paths of the ``*_optim_states.pt`` files, one per data-parallel rank
        - ``ds_checkpoint_dir``: checkpoint folder, only used in the error message

    Returns:
        - ``(zero_stage, world_size, fp32_flat_groups)`` where ``fp32_flat_groups`` holds one entry
          per file, in the order of ``files``

    Raises:
        - ``ValueError``: the first file is not a zero checkpoint, the number of files does not
          match the recorded partition count, or the zero stage is unknown
    """
    num_files = len(files)
    shards = []
    for shard_path in tqdm(files, desc='Loading checkpoint shards'):
        shard = torch.load(shard_path, map_location=device, mmap=True, weights_only=False)
        # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
        # and also handle the case where it was already removed by another helper script
        shard["optimizer_state_dict"].pop("optimizer_state_dict", None)
        shards.append(shard)

    first_optim_state = shards[0][OPTIMIZER_STATE_DICT]
    if ZERO_STAGE not in first_optim_state:
        raise ValueError(f"{files[0]} is not a zero checkpoint")
    zero_stage = first_optim_state[ZERO_STAGE]
    world_size = first_optim_state[PARTITION_COUNT]

    # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
    # parameters can be different from data parallelism for non-expert parameters. So we can just
    # use the max of the partition_count to get the dp world_size.
    if type(world_size) is list:
        world_size = max(world_size)

    if world_size != num_files:
        raise ValueError(
            f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {num_files} files. "
            "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
        )

    # the groups are named differently in each stage
    if zero_stage <= 2:
        fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
    elif zero_stage == 3:
        fp32_groups_key = FP32_FLAT_GROUPS
    else:
        raise ValueError(f"unknown zero stage {zero_stage}")

    fp32_flat_groups = [shard[OPTIMIZER_STATE_DICT][fp32_groups_key] for shard in shards]
    return zero_stage, world_size, fp32_flat_groups
186
187
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
    """
    Rebuild the fp32 state_dict from the files of one DeepSpeed checkpoint folder.

    Args:
        - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
        - ``exclude_frozen_parameters``: when true, frozen parameters are left out of the result

    Returns:
        - the state_dict produced by the stage specific merge routine

    """
    print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")

    zero_stage, world_size, fp32_flat_groups = parse_optim_states(get_optim_files(ds_checkpoint_dir),
                                                                  ds_checkpoint_dir)
    print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")

    zero_model_states = parse_model_states(get_model_state_files(ds_checkpoint_dir))
    print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')

    # stages 1 and 2 share one layout, stage 3 has its own
    if zero_stage <= 2:
        merge = _get_fp32_state_dict_from_zero2_checkpoint
    elif zero_stage == 3:
        merge = _get_fp32_state_dict_from_zero3_checkpoint
    else:
        return None

    return merge(world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters)
213
214
def _zero2_merge_frozen_params(state_dict, zero_model_states):
    """Copy the frozen parameters recorded by the first model state into ``state_dict``.

    Only ``zero_model_states[0]`` is consulted and each fragment is stored as is under its
    parameter name. Nothing happens when no frozen parameter shapes were recorded.
    """
    first_state = zero_model_states[0]
    frozen_shapes = first_state.frozen_param_shapes
    if frozen_shapes is None or len(frozen_shapes) == 0:
        return

    fragments = first_state.frozen_param_fragments

    wanted_params = len(frozen_shapes)
    wanted_numel = sum(s.numel() for s in frozen_shapes.values())
    if debug:
        print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {wanted_numel}')

    avail_numel = sum(fragment.numel() for fragment in fragments.values())
    print(f'Frozen params: Have {avail_numel} numels to process.')
    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in frozen_shapes.items():
        numel = shape.numel()
        total_params += 1
        total_numel += numel

        state_dict[name] = fragments[name]

        if debug:
            print(f"{name} full shape: {shape} unpartitioned numel {numel} ")

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
246
247 def _has_callable(obj, fn):
248 attr = getattr(obj, fn, None)
249 return callable(attr)
250
251
def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    """Rebuild the trainable parameters of a ZeRO stage <= 2 checkpoint into ``state_dict``.

    Reconstruction protocol: for every param group the partitions of all ranks are concatenated
    into one flat vector; the parameters of that group are then cut out of the vector one after
    another, in the order of the recorded shapes, and viewed with their full shape.

    Raises ``ValueError`` when, after alignment, the consumed element count of a group differs
    from the size of its flat vector.
    """
    param_shapes = zero_model_states[0].param_shapes
    num_param_groups = len(fp32_flat_groups[0])

    if debug:
        for rank in range(world_size):
            for group_idx in range(num_param_groups):
                print(f"{FP32_FLAT_GROUPS}[{rank}][{group_idx}].shape={fp32_flat_groups[rank][group_idx].shape}")

    # XXX: memory usage doubles here (zero2)
    merged_groups = [
        torch.cat([rank_groups[group_idx] for rank_groups in fp32_flat_groups], 0)
        for group_idx in range(num_param_groups)
    ]
    avail_numel = sum(flat_vector.numel() for flat_vector in merged_groups)

    if debug:
        wanted_params = sum(len(shapes) for shapes in param_shapes)
        wanted_numel = sum(sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes)
        # not asserting if there is a mismatch due to possible padding
        print(f"Have {avail_numel} numels to process.")
        print(f"Need {wanted_numel} numels in {wanted_params} params.")

    # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
    # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
    # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
    # live optimizer object, so we are checking that the numbers are within the right range
    align_to = 2 * world_size

    def zero2_align(x):
        return align_to * math.ceil(x / align_to)

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    total_numel = 0
    total_params = 0
    for shapes, flat_vector in zip(param_shapes, merged_groups):
        consumed = 0
        group_numel = flat_vector.numel()
        for name, shape in shapes.items():
            # shapes may or may not offer numel(); fall back to multiplying the dimensions
            if _has_callable(shape, 'numel'):
                numel = shape.numel()
            else:
                numel = math.prod(shape)
            total_numel += numel
            total_params += 1

            if debug:
                print(f"{name} full shape: {shape} unpartitioned numel {numel} ")
            state_dict[name] = flat_vector.narrow(0, consumed, numel).view(shape)
            consumed += numel

        if debug:
            print(f"original offset={consumed}, avail_numel={group_numel}")

        consumed = zero2_align(consumed)
        group_numel = zero2_align(group_numel)

        if debug:
            print(f"aligned offset={consumed}, avail_numel={group_numel}")

        # Sanity check
        if consumed != group_numel:
            raise ValueError(f"consumed {consumed} numels out of {group_numel} - something is wrong")

    print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
324
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                               exclude_frozen_parameters):
    """Assemble the fp32 state_dict of a ZeRO stage <= 2 checkpoint.

    The result is filled in this order: buffers, frozen parameters (unless excluded), trainable
    parameters, and finally the shared parameter aliases.
    """
    rank0_state = zero_model_states[0]

    # buffers
    state_dict = OrderedDict()
    state_dict.update(rank0_state.buffers)
    if debug:
        print(f"added {len(rank0_state.buffers)} buffers")

    if not exclude_frozen_parameters:
        _zero2_merge_frozen_params(state_dict, zero_model_states)

    _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # recover shared parameters: the alias points at the very same tensor as its source
    for alias, source in rank0_state.shared_params:
        if source in state_dict:
            state_dict[alias] = state_dict[source]

    return state_dict
346
347
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    """Return how a parameter of ``unpartitioned_numel`` elements is split over ``world_size`` ranks.

    Returns:
        - ``(partitioned_numel, padding_numel)``: the per-rank element count (the ceiling of
          ``unpartitioned_numel / world_size``) and the number of padding elements needed to make
          the total a multiple of ``world_size``
    """
    # divmod keeps everything in integer arithmetic; ``math.ceil(a / b)`` goes through a float
    # and loses precision once the operands no longer fit into 53 bits
    full_chunks, remainder = divmod(unpartitioned_numel, world_size)
    if remainder:
        return full_chunks + 1, world_size - remainder
    return full_chunks, 0
353
354
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
    """Rebuild the frozen parameters of a ZeRO stage 3 checkpoint into ``state_dict``.

    For every frozen parameter the fragments of all model states are concatenated in order, the
    trailing padding is cut off and the result is viewed with the full shape. Nothing happens
    when the first model state recorded no frozen parameter shapes.
    """
    frozen_shapes = zero_model_states[0].frozen_param_shapes
    if frozen_shapes is None or len(frozen_shapes) == 0:
        return

    if debug:
        for rank in range(world_size):
            rank_numel = sum(frag.numel() for frag in zero_model_states[rank].frozen_param_fragments.values())
            print(f'rank {rank}: {FROZEN_PARAM_SHAPES}.numel = {rank_numel}')

    wanted_params = len(frozen_shapes)
    wanted_numel = sum(shape.numel() for shape in frozen_shapes.values())
    rank0_numel = sum(frag.numel() for frag in zero_model_states[0].frozen_param_fragments.values())
    avail_numel = rank0_numel * world_size
    print(f'Frozen params: Have {avail_numel} numels to process.')
    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in frozen_shapes.items():
        numel = shape.numel()
        total_params += 1
        total_numel += numel

        # stitch the per-rank fragments together, then drop the padding at the end
        fragments = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
        merged = torch.cat(fragments, 0)
        state_dict[name] = merged.narrow(0, 0, numel).view(shape)

        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(numel, world_size)

        if debug:
            print(
                f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
390
class GatheredTensor:
    """
    A pseudo tensor that collects partitioned weights.
    It is more memory efficient when there are multiple groups.

    Nothing is copied on construction; the real tensor is only built by ``contiguous()``.

    Args:
        - ``flat_groups``: indexed as ``flat_groups[rank][group]``, each entry a flat tensor
        - ``flat_groups_offset``: cumulative group sizes starting with 0, so group ``g`` covers the
          element range ``[flat_groups_offset[g], flat_groups_offset[g + 1])`` of one rank
        - ``offset``: position of this parameter's partition inside a rank's flat element range
        - ``partitioned_numel``: number of elements every rank holds for this parameter
        - ``shape``: full shape of the parameter (must offer ``numel()``)
    """

    def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
        self.flat_groups = flat_groups
        self.flat_groups_offset = flat_groups_offset
        self.offset = offset
        self.partitioned_numel = partitioned_numel
        self.shape = shape
        self.dtype = self.flat_groups[0][0].dtype

    def _group_slices(self):
        """Return ``(group_id, start, stop)`` for every group the partition overlaps.

        ``start``/``stop`` are relative to the beginning of the group. Raises ``ValueError`` when
        the partition is not fully covered by the groups.
        """
        begin = int(self.offset)
        end_idx = begin + int(self.partitioned_numel)
        slices = []
        covered = 0
        for group_id in range(len(self.flat_groups_offset) - 1):
            group_begin = int(self.flat_groups_offset[group_id])
            group_end = int(self.flat_groups_offset[group_id + 1])
            # clamp the wanted range to this group; an empty result means no overlap
            lo = max(begin, group_begin)
            hi = min(end_idx, group_end)
            if lo < hi:
                slices.append((group_id, lo - group_begin, hi - group_begin))
                covered += hi - lo
        if covered != end_idx - begin:
            raise ValueError(f"partition [{begin}, {end_idx}) is not covered by the flat groups "
                             f"(found {covered} of {end_idx - begin} elements)")
        return slices

    def contiguous(self):
        """
        Merge partitioned weights from flat_groups into a single tensor.

        The per-rank partitions are concatenated in rank order, the padding at the end is cut
        off and the result is returned as a contiguous tensor of ``self.shape``.
        """
        # the slices only depend on the offsets, so they are the same for every rank
        group_slices = self._group_slices()

        pad_flat_param_chunks = [
            flat_groups_at_rank_i[group_id][start:stop] for flat_groups_at_rank_i in self.flat_groups
            for group_id, start, stop in group_slices
        ]

        # collect weights from all ranks
        if pad_flat_param_chunks:
            pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
        else:
            # zero-sized parameter: there is nothing to gather
            pad_flat_param = torch.empty(0, dtype=self.dtype)
        param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
        return param
435
436
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    """Register the trainable parameters of a ZeRO stage 3 checkpoint in ``state_dict``.

    Every parameter is stored as a lazy ``GatheredTensor``; the actual gathering of the per-rank
    partitions only happens when ``contiguous()`` is called on it.

    Args:
        - ``state_dict``: dict that receives one entry per trainable parameter
        - ``world_size``: number of ranks the parameters were partitioned over
        - ``fp32_flat_groups``: indexed as ``fp32_flat_groups[rank][group]``, the flat fp32 partitions
        - ``zero_model_states``: parsed model states; the shapes of the first one are used

    Raises:
        - ``ValueError``: the consumed element count does not match the available one
    """
    param_shapes = zero_model_states[0].param_shapes
    avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size

    # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
    # param, re-consolidating each param, while dealing with padding if any

    # merge list of dicts, preserving order
    param_shapes = {k: v for d in param_shapes for k, v in d.items()}

    if debug:
        # every rank holds a list of flat tensors (one per group), so report them one by one
        for i in range(world_size):
            for j, flat_group in enumerate(fp32_flat_groups[i]):
                print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={flat_group.shape}")

        wanted_params = len(param_shapes)
        wanted_numel = sum(shape.numel() for shape in param_shapes.values())
        # not asserting if there is a mismatch due to possible padding
        print(f"Trainable params: Have {avail_numel} numels to process.")
        print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    offset = 0
    total_numel = 0
    total_params = 0
    # element position at which each group starts inside one rank, plus the total as last entry
    flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
    for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel
        total_params += 1
        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

        # memory efficient tensor
        tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
        state_dict[name] = tensor
        offset += partitioned_numel

    # ``offset`` counted the elements of one rank, scale it to all ranks for the comparison
    offset *= world_size

    # Sanity check
    if offset != avail_numel:
        raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
489
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                               exclude_frozen_parameters):
    """Assemble the fp32 state_dict of a ZeRO stage 3 checkpoint.

    The result is filled in this order: buffers, frozen parameters (unless excluded), trainable
    parameters, and finally the shared parameter aliases.
    """
    rank0_state = zero_model_states[0]

    # buffers
    state_dict = OrderedDict()
    state_dict.update(rank0_state.buffers)
    if debug:
        print(f"added {len(rank0_state.buffers)} buffers")

    if not exclude_frozen_parameters:
        _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)

    _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # recover shared parameters: the alias points at the very same object as its source
    for alias, source in rank0_state.shared_params:
        if source in state_dict:
            state_dict[alias] = state_dict[source]

    return state_dict
511
512
def to_torch_tensor(state_dict, return_empty_tensor=False):
    """
    Convert state_dict of GatheredTensor to torch tensor

    Entries that are the very same object (shared tensors) are converted once and the result is
    reused for all of their names. With ``return_empty_tensor`` set, uninitialized tensors of the
    matching shape and dtype are created instead of gathering the data.
    """
    converted_by_id = {}
    torch_state_dict = {}
    for name, tensor in state_dict.items():
        key = id(tensor)
        if key not in converted_by_id:
            if return_empty_tensor:
                converted_by_id[key] = torch.empty(tensor.shape, dtype=tensor.dtype)
            else:
                converted_by_id[key] = tensor.contiguous()
        # shared tensors end up pointing at the first conversion
        torch_state_dict[name] = converted_by_id[key]
    return torch_state_dict
531
532
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
                                             tag=None,
                                             exclude_frozen_parameters=False,
                                             lazy_mode=False):
    """
    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
    ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
    via a model hub.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
        - ``exclude_frozen_parameters``: exclude frozen parameters
        - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pseudo tensors instead of torch tensors, which is more memory efficient.
          Convert a pseudo tensor to a torch tensor with ``.contiguous()``

    Returns:
        - pytorch ``state_dict``

    Raises:
        - ``ValueError``: no ``tag`` was given and ``checkpoint_dir`` has no ``latest`` file
        - ``FileNotFoundError``: the folder ``checkpoint_dir/tag`` does not exist

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        # do the training and checkpoint saving
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
        model = model.cpu() # move to cpu
        model.load_state_dict(state_dict)
        # submit to model hub or save the model to share with others

    In this example the ``model`` will no longer be usable in the deepspeed context of the same
    application. i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.

    If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.

    Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
    You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
    the checkpoint. Or you can load state_dict in lazy mode ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
        for name, lazy_tensor in state_dict.items():
            tensor = lazy_tensor.contiguous()  # to cpu
            print(name, tensor)
            # del tensor to release memory if it is no longer in use
    """
    if tag is None:
        # fall back to the tag recorded in the 'latest' file
        latest_path = os.path.join(checkpoint_dir, 'latest')
        if not os.path.isfile(latest_path):
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        with open(latest_path, 'r') as fd:
            tag = fd.read().strip()

    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)

    if not os.path.isdir(ds_checkpoint_dir):
        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")

    state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
    if not lazy_mode:
        return to_torch_tensor(state_dict)
    return state_dict
596
597
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
                                               output_dir,
                                               max_shard_size="5GB",
                                               safe_serialization=False,
                                               tag=None,
                                               exclude_frozen_parameters=False):
    """
    Convert ZeRO 2 or 3 checkpoint into fp32 consolidated ``state_dict`` file(s) that can be
    loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.

    The weights are written shard by shard and each shard is released from memory right after it
    was saved. When more than one shard is produced an index json file is written next to them.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``output_dir``: directory to the pytorch fp32 state_dict output files
        - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB. ``None`` writes a single file
        - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
        - ``exclude_frozen_parameters``: exclude frozen parameters
    """
    use_sharding = max_shard_size is not None

    # Dependency pre-check
    if safe_serialization:
        try:
            from safetensors.torch import save_file
        except ImportError:
            print('If you want to use `safe_serialization`, please `pip install safetensors`')
            raise
    if use_sharding:
        try:
            from huggingface_hub import split_torch_state_dict_into_shards
        except ImportError:
            print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
            raise

    # Convert zero checkpoint to state_dict (lazily, nothing is gathered yet)
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
                                                          tag,
                                                          exclude_frozen_parameters,
                                                          lazy_mode=True)

    # Shard the model if it is too big.
    weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
    if not use_sharding:
        from collections import namedtuple
        StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
        state_dict_split = StateDictSplit(is_sharded=False,
                                          filename_to_tensors={weights_name: list(state_dict.keys())})
    else:
        filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
        # an memory-efficient approach for sharding: plan the shards on empty placeholder tensors
        placeholders = to_torch_tensor(state_dict, return_empty_tensor=True)
        state_dict_split = split_torch_state_dict_into_shards(placeholders,
                                                              filename_pattern=filename_pattern,
                                                              max_shard_size=max_shard_size)

    # Save the model by shard
    os.makedirs(output_dir, exist_ok=True)
    shard_plan = state_dict_split.filename_to_tensors.items()
    for shard_file, tensor_names in tqdm(shard_plan, desc="Saving checkpoint shards"):
        # only the tensors of the current shard are materialized
        shard_state_dict = to_torch_tensor({tensor_name: state_dict[tensor_name] for tensor_name in tensor_names})
        output_path = os.path.join(output_dir, shard_file)
        if safe_serialization:
            save_file(shard_state_dict, output_path, metadata={"format": "pt"})
        else:
            torch.save(shard_state_dict, output_path)
        # release the memory of current shard
        for tensor_name in list(shard_state_dict.keys()):
            del state_dict[tensor_name]
            del shard_state_dict[tensor_name]
        del shard_state_dict
        gc.collect()

    # Save index if sharded
    if state_dict_split.is_sharded:
        index_name = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")
681
682
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
    """
    1. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
    2. Put the provided model to cpu
    3. Load the ``state_dict`` into the provided model (with ``strict=False``)

    Args:
        - ``model``: the model object to update
        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``

    Returns:
        - ``model``: modified model

    Make sure you have plenty of CPU memory available before you call this function. If you don't
    have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
    conveniently placed for you in the checkpoint folder.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
        # submit to model hub or save the model to share with others

    Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
    of the same application. i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.

    """
    logger.info("Extracting fp32 weights")
    fp32_weights = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    logger.info("Overwriting model with fp32 weights")
    cpu_model = model.cpu()
    cpu_model.load_state_dict(fp32_weights, strict=False)

    return cpu_model
720
721
if __name__ == "__main__":
    # Command-line entry point: parse the arguments, switch on the module-level debug output if
    # requested and run the conversion.
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint_dir",
                        type=str,
                        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
    parser.add_argument("output_dir",
                        type=str,
                        help="directory to the pytorch fp32 state_dict output files "
                        "(e.g. path/checkpoint-12-output/)")
    # NOTE: adjacent string literals are joined without a separator, so every fragment of a help
    # text has to end with the space that separates it from the next one
    parser.add_argument(
        "--max_shard_size",
        type=str,
        default="5GB",
        help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size "
        "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`). "
        "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances "
        "without CPU OOM issues.")
    parser.add_argument(
        "--safe_serialization",
        default=False,
        action='store_true',
        help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
    parser.add_argument("-t",
                        "--tag",
                        type=str,
                        default=None,
                        help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
    parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
    args = parser.parse_args()

    # this runs at module level, so it rebinds the global ``debug`` flag the helpers read
    debug = args.debug

    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
                                               args.output_dir,
                                               max_shard_size=args.max_shard_size,
                                               safe_serialization=args.safe_serialization,
                                               tag=args.tag,
                                               exclude_frozen_parameters=args.exclude_frozen_parameters)
761