Skip to content

Commit e78fe82

Browse files
committed
update async handler
1 parent 780040e commit e78fe82

File tree

7 files changed

+265
-226
lines changed

7 files changed

+265
-226
lines changed
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Asynchronous unified checkpoint handler."""
15+
16+
import multiprocessing
17+
import os
18+
import time
19+
from multiprocessing import shared_memory
20+
21+
import paddle
22+
import paddle.distributed as dist
23+
24+
from paddlenlp.transformers.utils import is_safetensors_available
25+
from paddlenlp.utils.log import logger
26+
27+
if is_safetensors_available():
28+
from safetensors.numpy import save_file as safe_save_file
29+
30+
from .shared_memory_utils import (
31+
_read_state_dict_from_shm,
32+
_traverse_copy_to_shm,
33+
create_meta_dict,
34+
)
35+
36+
37+
class AsyncCheckpointHander:
    """Save unified-checkpoint state dicts either synchronously or via
    long-lived background processes.

    NOTE(review): the class name keeps the original "Hander" spelling because
    external callers may already reference it; renaming would break them.

    When ``"async_save"`` is present in ``args.unified_checkpoint_config``,
    each of the three state-dict types (``model_weight``, ``master_weight``,
    ``optimizer_weight``) gets its own shared-memory buffer and its own child
    process. The parent copies a state dict into shared memory, publishes the
    target paths, and raises a flag; the child serializes the buffer with
    safetensors, drops a ``.done`` signal file, and clears the flag.

    Flag protocol (one ``multiprocessing.Array("i", 1)`` per type):
        0  -> worker idle, buffer free
        1  -> save requested, worker owns the buffer
        -1 -> stop request, worker exits its loop
    """

    def __init__(self, args):
        # Mainly for asynchronous saving.
        self.args = args
        # -1 marks a non-distributed (single-card) run.
        self.global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1

        # Shared-memory buffers and their metadata; created lazily on the
        # first async save of each state-dict type.
        self._shm_model_weight = None
        self._shm_master_weight = None
        self._shm_optimizer_weight = None
        self._meta_dict_model = None
        self._meta_dict_master_weight = None
        self._meta_dict_optim = None
        # Long-lived worker processes, also started lazily.
        self._process_model_weight = None
        self._process_master_weight = None
        self._process_optimizer_weight = None
        self._lock = None
        self._shared_save_model_flag = None
        self._shared_save_master_weight_flag = None
        self._shared_save_optimizer_flag = None

        if "async_save" in self.args.unified_checkpoint_config:
            self._lock = multiprocessing.Lock()
            # Fixed-size byte arrays used to hand target paths to the workers.
            self._shared_save_model_path = multiprocessing.Array("c", 100000)
            self._shared_save_model_signal_path = multiprocessing.Array("c", 100000)
            self._shared_save_master_weight_path = multiprocessing.Array("c", 100000)
            self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000)
            self._shared_save_optimizer_path = multiprocessing.Array("c", 100000)
            self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000)
            self._shared_save_model_flag = multiprocessing.Array("i", 1)
            self._shared_save_master_weight_flag = multiprocessing.Array("i", 1)
            self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)

    def _spawn_save_process(
        self, meta_dict, shm_name, shared_save_flag, shared_save_path, shared_save_signal_path, worker_type
    ):
        """Start and return the long-lived worker process for one shared-memory buffer."""
        process = multiprocessing.Process(
            target=self._save_file_async_in_process,
            args=(
                meta_dict,
                shm_name,
                shared_save_flag,
                shared_save_path,
                shared_save_signal_path,
                self._lock,
                worker_type,
                self.global_rank,
            ),
        )
        process.start()
        return process

    def _file_save_async_or_sync(
        self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight"
    ):
        """Write ``state_dict`` to ``path``, either now (sync) or via a worker process.

        Args:
            state_dict: mapping of names to ``paddle.Tensor`` / numpy arrays.
            path: destination safetensors file.
            signal_path: directory where the worker drops its ``.done`` file
                (async mode only).
            is_sync: when True, convert and save in the current process.
            state_dict_type: one of ``"model_weight"``, ``"master_weight"``,
                ``"optimizer_weight"``.

        Raises:
            ValueError: if ``state_dict_type`` is not one of the three types.
            RuntimeError: if the worker process died while a save was pending.
        """
        if is_sync:
            # Synchronous path: convert tensors to numpy in place, write now.
            for k in list(state_dict.keys()):
                if isinstance(state_dict[k], paddle.Tensor):
                    state_dict[k] = state_dict.pop(k).cpu().numpy()
            safe_save_file(state_dict, path, metadata={"format": "np"})
            return

        # Async path: pick (and lazily create) the per-type shared memory,
        # metadata, flag, path arrays and worker process.
        if state_dict_type == "model_weight":
            if self._shm_model_weight is None:
                self._meta_dict_model, buffer_size = create_meta_dict(state_dict)
                self._shm_model_weight = shared_memory.SharedMemory(create=True, size=buffer_size)
            shm_state_dict = self._shm_model_weight
            meta_dict = self._meta_dict_model
            shared_save_flag = self._shared_save_model_flag
            shared_save_path = self._shared_save_model_path
            shared_save_signal_path = self._shared_save_model_signal_path
            if self._process_model_weight is None:
                self._process_model_weight = self._spawn_save_process(
                    meta_dict,
                    shm_state_dict.name,
                    shared_save_flag,
                    shared_save_path,
                    shared_save_signal_path,
                    state_dict_type,
                )
            process = self._process_model_weight
        elif state_dict_type == "master_weight":
            if self._shm_master_weight is None:
                self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict)
                self._shm_master_weight = shared_memory.SharedMemory(create=True, size=buffer_size)
            shm_state_dict = self._shm_master_weight
            meta_dict = self._meta_dict_master_weight
            shared_save_flag = self._shared_save_master_weight_flag
            shared_save_path = self._shared_save_master_weight_path
            shared_save_signal_path = self._shared_save_master_weight_signal_path
            if self._process_master_weight is None:
                # When model-weight saving is skipped, master weights stand in
                # for the model weights, so the worker emits the
                # "model_weight" signal file instead of "master_weight".
                worker_type = (
                    "model_weight"
                    if "skip_save_model_weight" in self.args.unified_checkpoint_config
                    else state_dict_type
                )
                self._process_master_weight = self._spawn_save_process(
                    meta_dict,
                    shm_state_dict.name,
                    shared_save_flag,
                    shared_save_path,
                    shared_save_signal_path,
                    worker_type,
                )
            process = self._process_master_weight
        elif state_dict_type == "optimizer_weight":
            if self._shm_optimizer_weight is None:
                self._meta_dict_optim, buffer_size = create_meta_dict(state_dict)
                self._shm_optimizer_weight = shared_memory.SharedMemory(create=True, size=buffer_size)
            shm_state_dict = self._shm_optimizer_weight
            meta_dict = self._meta_dict_optim
            shared_save_flag = self._shared_save_optimizer_flag
            shared_save_path = self._shared_save_optimizer_path
            shared_save_signal_path = self._shared_save_optimizer_signal_path
            if self._process_optimizer_weight is None:
                self._process_optimizer_weight = self._spawn_save_process(
                    meta_dict,
                    shm_state_dict.name,
                    shared_save_flag,
                    shared_save_path,
                    shared_save_signal_path,
                    state_dict_type,
                )
            process = self._process_optimizer_weight
        else:
            # Previously an unknown type fell through and crashed with
            # UnboundLocalError; fail fast with a clear message instead.
            raise ValueError(f"Unsupported state_dict_type: {state_dict_type}")

        while True:  # wait until no process is saving.
            flag_value = shared_save_flag[0]
            if flag_value == 0:
                break
            if not process.is_alive():
                raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")
            time.sleep(0.5)
            logger.info(f"Wait for the previous save process to finish saving {state_dict_type}")
        # Publish the target paths, copy tensors into shared memory, then
        # raise the flag to hand the buffer over to the worker.
        self._reset_and_update(shared_save_path, path)
        self._reset_and_update(shared_save_signal_path, signal_path)
        _traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf)
        with self._lock:
            shared_save_flag[0] = 1

    def _save_file_async_in_process(
        self,
        meta_dict,
        shm_name,
        shared_save_flag,
        shared_save_path,
        shared_save_signal_path,
        lock,
        state_dict_type,
        global_rank,
    ):
        """Worker loop executed in a child process.

        Polls ``shared_save_flag`` every 0.5 s; on 1 it serializes the shared
        memory to the published path and writes a ``.done`` signal file, on -1
        it exits. Runs until the parent requests a stop.
        """
        shm = shared_memory.SharedMemory(name=shm_name)
        while True:
            flag_value = shared_save_flag[0]  # if process uses `spawn`, cannot read this value.
            if flag_value == -1:  # stop process
                break
            if flag_value == 1:  # need to save
                path = shared_save_path[:].decode("utf-8").rstrip("\x00")
                signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00")
                logger.info(f"Start to async save {path}")
                state_dict = _read_state_dict_from_shm(meta_dict, shm)  # numpy array
                safe_save_file(state_dict, path, {"format": "np"})
                # Release the (potentially large) numpy views before sleeping.
                del state_dict
                saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
                paddle.save(global_rank, saved_signal_path)
                with lock:
                    shared_save_flag[0] = 0
            # Sleep unconditionally: the original `continue` on flag == 0
            # skipped the sleep and busy-spun an idle worker at 100% CPU.
            time.sleep(0.5)
        shm.close()

    def _reset_and_update(self, shared_array, new_value):
        """Zero ``shared_array`` and store ``new_value`` (UTF-8) at its start."""
        # clear array
        for i in range(len(shared_array)):
            shared_array[i] = b"\0"
        # update array
        encoded_value = new_value.encode("utf-8")
        shared_array[: len(encoded_value)] = encoded_value

    def _stop_worker(self, shared_save_flag, process, state_dict_type):
        """Wait for any in-flight save to finish, then ask the worker to exit.

        Raises:
            RuntimeError: if the worker died while a save was still pending.
        """
        if shared_save_flag is None:
            return
        while shared_save_flag[0] > 0:  # async process is saving
            if not process.is_alive():
                raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")
            time.sleep(0.5)
        shared_save_flag[0] = -1

    def _release_shm(self, shm):
        """Close and unlink one shared-memory block; no-op for None."""
        if shm is not None:
            shm.close()
            shm.unlink()

    def unlink_shared_memory(self):
        """Stop all worker processes, free shared memory, and barrier-sync ranks."""
        if not ("async_save" in self.args.unified_checkpoint_config):
            return

        self._stop_worker(self._shared_save_model_flag, self._process_model_weight, "model_weight")
        self._stop_worker(self._shared_save_master_weight_flag, self._process_master_weight, "master_weight")
        self._stop_worker(self._shared_save_optimizer_flag, self._process_optimizer_weight, "optimizer_weight")

        self._release_shm(self._shm_model_weight)
        self._shm_model_weight = None
        self._release_shm(self._shm_master_weight)
        self._shm_master_weight = None
        self._release_shm(self._shm_optimizer_weight)
        self._shm_optimizer_weight = None

        if paddle.distributed.get_world_size() > 1:
            dist.barrier()

paddlenlp/trainer/unified_checkpoint/check_uc.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
update_master_weight_status,
4343
)
4444

45+
__all__ = ["check_unified_checkpoint", "check_unified_optimizer"]
46+
4547

4648
def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False):
4749
index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False)

paddlenlp/trainer/unified_checkpoint/uc_dynamic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@
5555
update_master_weight_status,
5656
)
5757

58+
__all__ = ["load_unified_checkpoint_dynamically", "load_unified_optimizer_dynamically"]
59+
5860

5961
def create_send_table(file_keyname_mappings, file_machine_mappings):
6062
send_table = {}
@@ -258,7 +260,7 @@ def distributed_send_recv(
258260
return state_dict
259261

260262

261-
def load_uc_dynamically(args, model, resume_from_checkpoint, safe_serialization=False):
263+
def load_unified_checkpoint_dynamically(args, model, resume_from_checkpoint, safe_serialization=False):
262264
index_filename = select_model_weight_index(model, resume_from_checkpoint, safe_serialization, local=False)
263265
index_filename = os.path.join(resume_from_checkpoint, index_filename)
264266

paddlenlp/trainer/unified_checkpoint/uc_locally_load.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
update_master_weight_status,
5252
)
5353

54+
__all__ = ["load_unified_checkpoint_locally", "load_unified_optimizer_locally"]
55+
5456

5557
def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False):
5658
"""

paddlenlp/trainer/unified_checkpoint/uc_sharding_v2.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
mapping_optimizer_tp_actions,
3838
)
3939

40+
__all__ = ["gather_splited_param_for_optimizer", "load_unified_optimizer_split_param"]
41+
4042

4143
def merge_splited_param(
4244
state_dict, partial_tensor_list, param_shape_info, send_table, recv_table, is_master_weights=False

paddlenlp/trainer/unified_checkpoint/uc_single_card.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@
5454
save_model_config,
5555
)
5656

57+
__all__ = [
58+
"load_single_card_checkpoint",
59+
"load_single_card_optimizer",
60+
"save_single_card_checkpoint",
61+
"save_single_card_optimizer",
62+
]
63+
5764

5865
def save_file_sync(state_dict, path):
5966
for k in list(state_dict.keys()):

0 commit comments

Comments
 (0)