
Commit c9d5673

DesmonDay and gongel authored
[Unified Checkpoint] Add split param and refactor code (#9240)
* [Unified checkpoint] update optimizer async save signal
* update paddlepaddle
* split param
* add save for split param
* fix save split_param
* add load uc split_param
* update uc files
* update uc files
* update split_param loading
* mkdir unified_checkpoint directory
* rename file
* update async handler
* update files

---------

Co-authored-by: gongenlei <gongenlei@baidu.com>
1 parent 81f5ab5 commit c9d5673

15 files changed: 3,236 additions and 2,575 deletions

paddlenlp/trainer/plugins/unified_checkpoint.py

Lines changed: 0 additions & 2569 deletions
This file was deleted.

paddlenlp/trainer/trainer.py

Lines changed: 1 addition & 5 deletions
@@ -113,7 +113,6 @@
 from .argparser import strtobool
 from .integrations import get_reporting_integration_callbacks
 from .plugins.timer import RuntimeTimer, get_timers, set_timers
-from .plugins.unified_checkpoint import UnifiedCheckpointHandler
 from .trainer_callback import (
     CallbackHandler,
     DefaultFlowCallback,
@@ -144,6 +143,7 @@
     speed_metrics,
 )
 from .training_args import TrainingArguments
+from .unified_checkpoint import UnifiedCheckpointHandler
 from .utils import reshard as reshard_util
 from .utils.async_save import AsyncSaver
 from .utils.helper import (  # nested_truncate,
@@ -598,7 +598,6 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
         if use_unified_checkpoint:
             self.unified_checkpoint_handler.load_unified_checkpoint(
                 self.model,
-                self.optimizer,
                 resume_from_checkpoint,
             )
             logger.info(f"Loading model from {resume_from_checkpoint} using unified checkpoint.")
@@ -1241,7 +1240,6 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
             if self.args.unified_checkpoint:
                 self.unified_checkpoint_handler.load_unified_checkpoint(
                     self.model,
-                    self.optimizer,
                     self.state.best_model_checkpoint,
                 )
                 if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
@@ -1289,7 +1287,6 @@ def _load_best_model_from_peft_checkpoint(self):
         if self.args.unified_checkpoint:
             self.unified_checkpoint_handler.load_unified_checkpoint(
                 self.model,
-                self.optimizer,
                 self.state.best_model_checkpoint,
             )
             if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
@@ -2775,7 +2772,6 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             opt_state_dict = None
         else:
             opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
-                args=self.args,
                 model=self.model,
                 optimizer=self.optimizer,
                 resume_from_checkpoint=checkpoint,
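After this change the optimizer is no longer threaded through load_unified_checkpoint, and load_unified_optimizer no longer takes args; a minimal sketch of the resulting call shapes from inside the Trainer (argument names mirror the diff above, nothing beyond it is implied):

# Sketch of the call shapes implied by the trainer.py diff; `self` stands for the Trainer.
# Model weights: the optimizer argument is gone.
self.unified_checkpoint_handler.load_unified_checkpoint(
    self.model,
    resume_from_checkpoint,
)

# Optimizer state: `args` is no longer passed explicitly.
opt_state_dict = self.unified_checkpoint_handler.load_unified_optimizer(
    model=self.model,
    optimizer=self.optimizer,
    resume_from_checkpoint=checkpoint,
)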

paddlenlp/trainer/training_args.py

Lines changed: 6 additions & 0 deletions
@@ -1402,6 +1402,12 @@ def is_segment_parallel_supported():
                         f"but got logging_steps={self.logging_steps}."
                     )

+                if "split_param" in sharding_parallel_config:
+                    assert self.sharding == [ShardingOption.SHARD_OP], "Only sharding stage1 support split_param."
+                    assert (
+                        self.amp_master_grad
+                    ), "If `split_param` in sharding_parallel_config, `amp_master_grad` must be True."
+
                 fleet.init(is_collective=True, strategy=strategy)
                 logger.info(strategy)
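The new checks only allow split_param under sharding stage1 with amp_master_grad enabled; a minimal configuration sketch that satisfies them (field names follow PaddleNLP's TrainingArguments; the concrete values are illustrative, not taken from this commit):

# Hypothetical configuration sketch for the new split_param constraints.
from paddlenlp.trainer import TrainingArguments

training_args = TrainingArguments(
    output_dir="./checkpoints",              # illustrative path
    sharding="stage1",                       # split_param is asserted to require sharding stage1
    sharding_parallel_config="split_param",  # opt in to the split-param path added here
    amp_master_grad=True,                    # asserted to be True whenever split_param is set
    unified_checkpoint=True,
    unified_checkpoint_config="async_save",  # optional: exercises the async handler below
)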

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .unified_checkpoint import UnifiedCheckpointHandler
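This re-export, together with the trainer.py change above, moves the handler's import path out of the plugins module and into the new unified_checkpoint package. A hedged illustration of the before/after import (the absolute package location is inferred from the relative imports in trainer.py and is not stated explicitly in this commit):

# Import path change implied by this commit (illustrative).
# Before: the handler lived in the plugins module.
# from paddlenlp.trainer.plugins.unified_checkpoint import UnifiedCheckpointHandler
# After: it is re-exported from the new unified_checkpoint package.
from paddlenlp.trainer.unified_checkpoint import UnifiedCheckpointHandler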
Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Asynchronous unified checkpoint handler."""
+
+import multiprocessing
+import os
+import time
+from multiprocessing import shared_memory
+
+import paddle
+import paddle.distributed as dist
+
+from paddlenlp.transformers.utils import is_safetensors_available
+from paddlenlp.utils.log import logger
+
+if is_safetensors_available():
+    from safetensors.numpy import save_file as safe_save_file
+
+from .shared_memory_utils import (
+    _read_state_dict_from_shm,
+    _traverse_copy_to_shm,
+    create_meta_dict,
+)
+
+__all__ = ["AsyncCheckpointHandler"]
+
+
+class AsyncCheckpointHandler:
+    def __init__(self, args):
+        # Mainly for asynchronous saving.
+        self.args = args
+        self.global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
+
+        self._shm_model_weight = None
+        self._shm_master_weight = None
+        self._shm_optimizer_weight = None
+        self._meta_dict_model = None
+        self._meta_dict_master_weight = None
+        self._meta_dict_optim = None
+        self._process_model_weight = None
+        self._process_master_weight = None
+        self._process_optimizer_weight = None
+        self._lock = None
+        self._shared_save_model_flag = None
+        self._shared_save_master_weight_flag = None
+        self._shared_save_optimizer_flag = None
+
+        if "async_save" in self.args.unified_checkpoint_config:
+            self._lock = multiprocessing.Lock()
+            self._shared_save_model_path = multiprocessing.Array("c", 100000)
+            self._shared_save_model_signal_path = multiprocessing.Array("c", 100000)
+            self._shared_save_master_weight_path = multiprocessing.Array("c", 100000)
+            self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000)
+            self._shared_save_optimizer_path = multiprocessing.Array("c", 100000)
+            self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000)
+            self._shared_save_model_flag = multiprocessing.Array("i", 1)
+            self._shared_save_master_weight_flag = multiprocessing.Array("i", 1)
+            self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)
+
+    def _file_save_async_or_sync(
+        self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight"
+    ):
+        if is_sync:
+            for k in list(state_dict.keys()):
+                if isinstance(state_dict[k], paddle.Tensor):
+                    state_dict[k] = state_dict.pop(k).cpu().numpy()
+            safe_save_file(state_dict, path, metadata={"format": "np"})
+        else:
+            if state_dict_type == "model_weight":
+                if self._shm_model_weight is None:
+                    self._meta_dict_model, buffer_size = create_meta_dict(state_dict)
+                    self._shm_model_weight = shared_memory.SharedMemory(create=True, size=buffer_size)
+                shm_state_dict = self._shm_model_weight
+                meta_dict = self._meta_dict_model
+                shared_save_flag = self._shared_save_model_flag
+                shared_save_path = self._shared_save_model_path
+                shared_save_signal_path = self._shared_save_model_signal_path
+                if self._process_model_weight is None:
+                    self._process_model_weight = multiprocessing.Process(
+                        target=self._save_file_async_in_process,
+                        args=(
+                            meta_dict,
+                            self._shm_model_weight.name,
+                            self._shared_save_model_flag,
+                            self._shared_save_model_path,
+                            self._shared_save_model_signal_path,
+                            self._lock,
+                            state_dict_type,
+                            self.global_rank,
+                        ),
+                    )
+                    self._process_model_weight.start()
+                process = self._process_model_weight
+            elif state_dict_type == "master_weight":
+                if self._shm_master_weight is None:
+                    self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict)
+                    self._shm_master_weight = shared_memory.SharedMemory(create=True, size=buffer_size)
+                shm_state_dict = self._shm_master_weight
+                meta_dict = self._meta_dict_master_weight
+                shared_save_flag = self._shared_save_master_weight_flag
+                shared_save_path = self._shared_save_master_weight_path
+                shared_save_signal_path = self._shared_save_master_weight_signal_path
+                if self._process_master_weight is None:
+                    self._process_master_weight = multiprocessing.Process(
+                        target=self._save_file_async_in_process,
+                        args=(
+                            meta_dict,
+                            self._shm_master_weight.name,
+                            self._shared_save_master_weight_flag,
+                            self._shared_save_master_weight_path,
+                            self._shared_save_master_weight_signal_path,
+                            self._lock,
+                            "model_weight"
+                            if "skip_save_model_weight" in self.args.unified_checkpoint_config
+                            else state_dict_type,
+                            self.global_rank,
+                        ),
+                    )
+                    self._process_master_weight.start()
+                process = self._process_master_weight
+            elif state_dict_type == "optimizer_weight":
+                if self._shm_optimizer_weight is None:
+                    self._meta_dict_optim, buffer_size = create_meta_dict(state_dict)
+                    self._shm_optimizer_weight = shared_memory.SharedMemory(create=True, size=buffer_size)
+                shm_state_dict = self._shm_optimizer_weight
+                meta_dict = self._meta_dict_optim
+                shared_save_flag = self._shared_save_optimizer_flag
+                shared_save_path = self._shared_save_optimizer_path
+                shared_save_signal_path = self._shared_save_optimizer_signal_path
+                if self._process_optimizer_weight is None:
+                    self._process_optimizer_weight = multiprocessing.Process(
+                        target=self._save_file_async_in_process,
+                        args=(
+                            meta_dict,
+                            self._shm_optimizer_weight.name,
+                            self._shared_save_optimizer_flag,
+                            self._shared_save_optimizer_path,
+                            self._shared_save_optimizer_signal_path,
+                            self._lock,
+                            state_dict_type,
+                            self.global_rank,
+                        ),
+                    )
+                    self._process_optimizer_weight.start()
+                process = self._process_optimizer_weight
+
+            while True:  # wait until no process is saving.
+                flag_value = shared_save_flag[0]
+                if flag_value == 0:
+                    break
+                if not process.is_alive():
+                    raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")
+                time.sleep(0.5)
+                logger.info(f"Wait for the previous save process to finish saving {state_dict_type}")
+            # only save model weight or save master weight, we enter this loop.
+            self._reset_and_update(shared_save_path, path)
+            self._reset_and_update(shared_save_signal_path, signal_path)
+            _traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf)
+            with self._lock:
+                shared_save_flag[0] = 1
+
+    def _save_file_async_in_process(
+        self,
+        meta_dict,
+        shm_name,
+        shared_save_flag,
+        shared_save_path,
+        shared_save_signal_path,
+        lock,
+        state_dict_type,
+        global_rank,
+    ):
+        shm = shared_memory.SharedMemory(name=shm_name)
+        while True:
+            flag_value = shared_save_flag[0]  # if process uses `spawn`, cannot read this value.
+            if flag_value == -1:  # stop process
+                break
+            if flag_value == 0:  # nothing to save
+                continue
+            if flag_value == 1:  # need to save
+                path = shared_save_path[:].decode("utf-8").rstrip("\x00")
+                signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00")
+                logger.info(f"Start to async save {path}")
+                state_dict = _read_state_dict_from_shm(meta_dict, shm)  # numpy array
+                safe_save_file(state_dict, path, {"format": "np"})
+                del state_dict
+                saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
+                paddle.save(global_rank, saved_signal_path)
+                with lock:
+                    shared_save_flag[0] = 0
+            time.sleep(0.5)
+        shm.close()
+
+    def _reset_and_update(self, shared_array, new_value):
+        # clear array
+        for i in range(len(shared_array)):
+            shared_array[i] = b"\0"
+        # update array
+        encoded_value = new_value.encode("utf-8")
+        shared_array[: len(encoded_value)] = encoded_value
+
+    def unlink_shared_memory(self):
+        if not ("async_save" in self.args.unified_checkpoint_config):
+            return
+
+        if self._shared_save_model_flag is not None:
+            while self._shared_save_model_flag[0] > 0:  # async process is saving
+                if not self._process_model_weight.is_alive():
+                    raise RuntimeError("The process that saves model_weight has been killed unexpectedly.")
+                time.sleep(0.5)
+            self._shared_save_model_flag[0] = -1
+        if self._shared_save_master_weight_flag is not None:
+            while self._shared_save_master_weight_flag[0] > 0:
+                if not self._process_master_weight.is_alive():
+                    raise RuntimeError("The process that saves master_weight has been killed unexpectedly.")
+                time.sleep(0.5)
+            self._shared_save_master_weight_flag[0] = -1
+        if self._shared_save_optimizer_flag is not None:
+            while self._shared_save_optimizer_flag[0] > 0:
+                if not self._process_optimizer_weight.is_alive():
+                    raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.")
+                time.sleep(0.5)
+            self._shared_save_optimizer_flag[0] = -1
+
+        if self._shm_model_weight is not None:
+            self._shm_model_weight.close()
+            self._shm_model_weight.unlink()
+            self._shm_model_weight = None
+        if self._shm_master_weight is not None:
+            self._shm_master_weight.close()
+            self._shm_master_weight.unlink()
+            self._shm_master_weight = None
+        if self._shm_optimizer_weight is not None:
+            self._shm_optimizer_weight.close()
+            self._shm_optimizer_weight.unlink()
+            self._shm_optimizer_weight = None
+
+        if paddle.distributed.get_world_size() > 1:
+            dist.barrier()
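The handler coordinates through a per-state-dict shared flag (0 = idle, 1 = save requested, -1 = shut down): the trainer copies the state dict into shared memory and sets the flag to 1, and a long-lived worker process writes the safetensors file and drops a `.{state_dict_type}.done.{rank}` marker under the signal path. A hedged sketch of how a caller might drive it outside the trainer (the paths and the direct use of the underscore-prefixed method are illustrative only):

# Illustrative driver for AsyncCheckpointHandler; not part of this commit.
# Assumes `args` is a TrainingArguments with "async_save" in unified_checkpoint_config
# and `model` is a paddle.nn.Layer.
handler = AsyncCheckpointHandler(args)

handler._file_save_async_or_sync(
    model.state_dict(),
    path="checkpoints/model.safetensors",  # illustrative output file
    signal_path="checkpoints",             # directory for the ".done" marker file
    is_sync=False,                         # hand the copy off to the background process
    state_dict_type="model_weight",
)

# At the end of training: wait for in-flight saves, tell the workers to stop (flag = -1),
# and release the shared-memory segments.
handler.unlink_shared_memory()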
