From aedb8d1b222383faedb293927ba8a79f03197346 Mon Sep 17 00:00:00 2001 From: Pavel Belevich Date: Wed, 28 May 2025 11:55:40 -0400 Subject: [PATCH 1/2] Fix gather_state_dict_fast --- colossalai/checkpoint_io/utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 4b36dbe002bb..0fbd65d868fb 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1132,18 +1132,20 @@ def gather_state_dict_fast( if rank == dst: returned_state_dict = state_dict.copy() dist.gather_object(metadata, all_meta_data, dst=dist.get_global_rank(group, rank), group=group) + ks, ops = [], [] for i, target_metadata in enumerate(all_meta_data): if i == dst: continue - ops = [] for k, shape, dtype in target_metadata: buffer = torch.empty(shape, dtype=dtype, device=get_current_device()) returned_state_dict[k] = buffer + ks.append(k) ops.append(dist.P2POp(dist.irecv, buffer, dist.get_global_rank(group, i), group)) - reqs = dist.batch_isend_irecv(ops) - for req, (k, *_) in zip(reqs, target_metadata): - req.wait() - returned_state_dict[k] = returned_state_dict[k].to(device) + reqs = dist.batch_isend_irecv(ops) + for req in reqs: # len(reqs) maybe be different from len(ops) because of coalescing + req.wait() + for k in ks: + returned_state_dict[k] = returned_state_dict[k].to(device) return returned_state_dict else: dist.gather_object(metadata, dst=dist.get_global_rank(group, dst), group=group) From dd562d28a45a25c9f1e98a7c9b51a76449e5dba0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 May 2025 16:06:50 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/checkpoint_io/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/checkpoint_io/utils.py 
b/colossalai/checkpoint_io/utils.py index 0fbd65d868fb..ea97da1baeb1 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1142,7 +1142,7 @@ def gather_state_dict_fast( ks.append(k) ops.append(dist.P2POp(dist.irecv, buffer, dist.get_global_rank(group, i), group)) reqs = dist.batch_isend_irecv(ops) - for req in reqs: # len(reqs) maybe be different from len(ops) because of coalescing + for req in reqs: # len(reqs) may be different from len(ops) because of coalescing req.wait() for k in ks: returned_state_dict[k] = returned_state_dict[k].to(device)