Skip to content

Commit 1f70803

Browse files
authored
Add max_shard_size arg (#6835)
* add max_shard_size arg * rm max_shard_size * move to 1024 --------- Co-authored-by: daisiming <daisiming@baidu.com>
1 parent be39ed1 commit 1f70803

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

paddlenlp/transformers/model_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ def layer_prefix(key):
456456

457457
def shard_checkpoint(
458458
state_dict: Dict[str, paddle.Tensor],
459-
max_shard_size: Union[int, str] = "10GB",
459+
max_shard_size: Union[int, str] = "1024GB",
460460
weights_name: str = PADDLE_WEIGHTS_NAME,
461461
shard_format="naive",
462462
):
@@ -466,8 +466,8 @@ def shard_checkpoint(
466466
467467
The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
468468
optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
469-
limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
470-
[6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
469+
limit is 1024GB and we have weights of sizes [600GB, 600GB, 200GB, 600GB, 200GB, 200GB] they will get sharded as [600GB], [600+200GB],
470+
[600+200+200GB] and not [600+200+200GB], [600+200GB], [600GB].
471471
472472
<Tip warning={true}>
473473
@@ -478,7 +478,7 @@ def shard_checkpoint(
478478
479479
Args:
480480
state_dict (`Dict[str, paddle.Tensor]`): The state dictionary of a model to save.
481-
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
481+
max_shard_size (`int` or `str`, *optional*, defaults to `"1024GB"`):
482482
The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
483483
(like `"5MB"`).
484484
weights_name (`str`, *optional*, defaults to `"model_state.pdparams"`):
@@ -2122,7 +2122,7 @@ def save_pretrained(
21222122
is_main_process: bool = True,
21232123
state_dict: Optional[dict] = None,
21242124
save_function: Callable = paddle.save,
2125-
max_shard_size: Union[int, str] = "10GB",
2125+
max_shard_size: Union[int, str] = "1024GB",
21262126
safe_serialization: bool = False,
21272127
variant: Optional[str] = None,
21282128
*args,

0 commit comments

Comments
 (0)