Skip to content

Commit 730a762

Browse files
authored
[LLM] support flash device on static model (#9619) (#9787)
* [LLM] support flash device on static model * [LLM] adapt pdc sdk
1 parent fb3e4c0 commit 730a762

File tree

2 files changed

+156
-9
lines changed

2 files changed

+156
-9
lines changed

paddlenlp/utils/downloader.py

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from collections import OrderedDict
2525
from typing import Optional, Union
2626

27+
import paddle.distributed as dist
2728
import requests
2829
from filelock import FileLock
2930
from huggingface_hub import get_hf_file_metadata, hf_hub_url
@@ -33,7 +34,13 @@
3334
from .env import DOWNLOAD_SERVER, FAILED_STATUS, SUCCESS_STATUS
3435
from .fault_tolerance import PDC_DOWNLOAD_ERROR
3536
from .log import logger
36-
from .pdc_sdk import PDCErrorCode, PDCErrorMessageMap, pdc_tool
37+
from .pdc_sdk import (
38+
FLASH_DEVICE,
39+
PDCErrorCode,
40+
PDCErrorMessageMap,
41+
pdc_flash_device_available,
42+
pdc_tool,
43+
)
3744

3845
__all__ = ["get_weights_path_from_url"]
3946

@@ -487,7 +494,7 @@ def download_from_pdc(remote_path, local_path, timeout):
487494
"""
488495

489496
try:
490-
base_dir, _ = os.path.split(os.path.normpath(remote_path))
497+
base_dir, _ = os.path.split(os.path.normpath(local_path))
491498
if not os.path.exists(base_dir) and base_dir != "":
492499
os.makedirs(base_dir, exist_ok=True)
493500
except Exception as e:
@@ -505,3 +512,81 @@ def download_from_pdc(remote_path, local_path, timeout):
505512
raise RuntimeError(
506513
f"{PDC_DOWNLOAD_ERROR}; Error occurred when trying to download object from PDC, remote_path: {remote_path}, local_path: {local_path}, timeout: {timeout}; error details: {PDCErrorMessageMap[result]}"
507514
)
515+
516+
517+
def get_static_model_on_pdc(remote_path, local_path, timeout, enable_flash_device=False):
    """
    Get static model from PDC. Use flash device if possible.
    This function has to be called after distributed env is initialized in distributed mode.

    Only one process per node does the work (flash check, download, backup): the one whose
    ``FLAGS_selected_gpus`` starts with device 0, or the sole process when the world size is 1.
    Every other process waits on ``dist.barrier()`` and then simply picks the path to load from.

    Args:
        remote_path (`str`):
            remote path url for download
        local_path (`str`):
            local path to place downloaded object
        timeout (`int`):
            max wait time for download
        enable_flash_device (`bool`):
            Whether to use flash device
    Returns:
        str: path to load static model (the flash device copy when it is enabled and
        present, otherwise ``local_path``)
    Raises:
        RuntimeError: if the parent directory of ``local_path`` cannot be prepared, or
            propagated from ``download_from_pdc`` when the download fails.
        AssertionError: if ``local_path`` has no usable final path component.
    """
    try:
        base_dir, target_dir = os.path.split(os.path.normpath(local_path))
        if base_dir != "" and not os.path.exists(base_dir):
            os.makedirs(base_dir, exist_ok=True)
    except Exception as e:
        raise RuntimeError(f"{PDC_DOWNLOAD_ERROR}; Failed to parse checkpoint path, details: {e}") from e

    # Raised explicitly (not via `assert`) so the check is not stripped under `python -O`.
    # "" and ".." are rejected as well: they would make flash_path resolve to the flash
    # device root (or its parent) instead of a per-model directory.
    if target_dir in ("", ".", ".."):
        raise AssertionError(f"{PDC_DOWNLOAD_ERROR}, illegal local_path: {local_path}.")

    flash_path = os.path.join(FLASH_DEVICE, target_dir)
    persistent_path = local_path

    # FLAGS_selected_gpus may hold a comma-separated list; only the first id is used here.
    # NOTE(review): presumably this is the node-local device index set by the launcher - confirm.
    device_id = int(os.getenv("FLAGS_selected_gpus", "0").split(",")[0])
    world_size = dist.get_world_size()

    # Followers only exist in a real multi-process run. A single process pinned to a
    # non-zero device has nobody to wait for and must do the download itself.
    if device_id != 0 and world_size > 1:
        logger.info("Waiting local process 0...")
        dist.barrier()
        return flash_path if (enable_flash_device and os.path.exists(flash_path)) else persistent_path

    # step 1: load from flash device if possible
    need_download_from_remote = True
    need_backup_to_flash = False
    if enable_flash_device and pdc_flash_device_available():
        logger.info(f"flash device is available, checking status on {flash_path}...")
        if os.path.exists(flash_path) and pdc_tool.pdc_flash_do_check(flash_path) == PDCErrorCode.Success:
            # a verified copy on the flash device makes the remote download unnecessary
            need_download_from_remote = False
            logger.info("Static model checked successfully on flash device, ready to load...")
        else:
            logger.warning(
                "flash device is available but no valid static model found on flash device, need to download from remote."
            )
            need_backup_to_flash = True
    else:
        logger.info("Flash device is not enabled or available, will download static model from remote.")

    # step 2: download from remote if necessary
    if need_download_from_remote:
        logger.info("Beging download static model from remote...")
        download_from_pdc(remote_path, persistent_path, timeout)
        logger.info(f"downloaded static model from remote, path:{persistent_path}")

    # step 3: backup to flash device if flash device is available
    if enable_flash_device and need_backup_to_flash:
        result = pdc_tool.pdc_backup_to_flash_device(persistent_path, flash_path)
        if result == PDCErrorCode.Success:
            logger.info(f"Backup static model to flash device {flash_path} successfully.")
        else:
            # A failed backup is non-fatal (the persistent copy is still usable), so an
            # error code without a mapped message must not raise KeyError here.
            logger.error(
                f"Backup static model to flash device failed, error details: {PDCErrorMessageMap.get(result, result)}."
            )

    # step 4: return flash path if available, otherwise return persistent path
    if world_size > 1:
        logger.info("Local node process done, waiting other nodes...")
        dist.barrier()
    if enable_flash_device and os.path.exists(flash_path):
        logger.info(f"static model is ready on flash device, path: {flash_path}")
        return flash_path
    else:
        logger.info(f"static model is only ready on persistent storage, path: {persistent_path}")
        return persistent_path

paddlenlp/utils/pdc_sdk.py

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
import json
1616
import os
1717
import queue
18+
import shutil
1819
import subprocess
1920
import threading
2021
import time
22+
from distutils.dir_util import copy_tree
2123
from enum import Enum
2224
from typing import List
2325

@@ -28,6 +30,13 @@
2830
TRAIN_CONFIG = "/root/paddlejob/workspace/env_run/longjob/train.conf"
2931
TAR_BIN = "tar"
3032

33+
FLASH_DEVICE = os.getenv("PDC_FLASH_DEVICE", "/shared/dev/shm/flash")
34+
35+
36+
def pdc_flash_device_available():
    """Return True when the ``FLASH_DEVICE`` path exists on this host, False otherwise."""
    # TODO(@gexiao): need better check
    flash_root_exists = os.path.exists(FLASH_DEVICE)
    return flash_root_exists
39+
3140

3241
class PDCErrorCode(Enum):
3342
"""Error Code For PDCTools usage"""
@@ -48,6 +57,7 @@ class PDCErrorCode(Enum):
4857
InvalidArgument = 1503
4958
CommandTimeout = 1504
5059
CheckSumCommandFail = 1505
60+
CopyTreeFailed = 1506
5161

5262
UnknownError = 1999
5363

@@ -493,14 +503,60 @@ def _download_file(self, remote_path: str, local_path: str) -> PDCErrorCode:
493503
raise Exception(f"exec cmd {download_cmd_args} with error: {e}")
494504
return error_code
495505

496-
def pdc_fc_generate_checksum(self, path: str) -> PDCErrorCode:
506+
def _pdc_backup_failed_directory(self, path):
    """Move ``path`` aside to a sibling named ``<name>_failed`` so it can be inspected later.

    Does nothing when ``path`` does not exist. Any previous ``<name>_failed`` entry is
    removed first, whatever its type, so that only the most recent failure is kept.

    Args:
        path (str): directory (or file) that failed validation.
    """
    if not os.path.exists(path):
        return

    base_dir, target_path = os.path.split(os.path.normpath(path))
    failed_path = os.path.join(base_dir, f"{target_path}_failed")

    # lexists() also reports dangling symlinks, which exists() would miss and which
    # would otherwise make the rename below fail.
    if os.path.lexists(failed_path):
        if os.path.isdir(failed_path) and not os.path.islink(failed_path):
            shutil.rmtree(failed_path)
        else:
            # rmtree refuses regular files and symlinks; remove them directly
            os.remove(failed_path)

    # Backup failed files for debug
    os.rename(path, failed_path)
514+
515+
def pdc_backup_to_flash_device(self, persistent_path: str, flash_device_path: str) -> PDCErrorCode:
    """backup data to flash device

    The backup runs in three steps: generate a checksum for ``persistent_path``, copy the
    tree to ``flash_device_path``, then run the checksum verification on the copy. When the
    copy or the verification fails, the flash copy is moved aside to ``<name>_failed``.

    Args:
        persistent_path str: persistent path
        flash_device_path str: flash device path

    Returns:
        PDCErrorCode: ``Success`` when the copy exists and passed verification;
        ``LocalPathNotExist`` when ``persistent_path`` is missing; ``CopyTreeFailed`` when
        the copy raised; otherwise the error code returned by the checksum step that failed.
    """
    if not os.path.exists(persistent_path):
        logger.error(f"{persistent_path} not exist")
        return PDCErrorCode.LocalPathNotExist

    logger.info("starting backup to flash device...")

    # step 1: generate checksum for recovery
    result = self.pdc_generate_dir_checksum(persistent_path)
    if result != PDCErrorCode.Success:
        logger.error(f"[Error] [pdc_sdk] generating checksum for {persistent_path} failed")
        return result

    # step 2: copy persistent data to flash device
    # shutil.copytree is used instead of distutils' copy_tree: distutils is deprecated and
    # caches the directories it has created, so a retry after the destination was renamed
    # away by _pdc_backup_failed_directory would skip re-creating them and fail.
    # dirs_exist_ok=True keeps the "merge into existing directory" behaviour.
    try:
        shutil.copytree(persistent_path, flash_device_path, dirs_exist_ok=True)
    except Exception as e:
        logger.error(f"[Error] [pdc_sdk] copy tree {persistent_path} to {flash_device_path} failed, error: {e}")
        self._pdc_backup_failed_directory(flash_device_path)
        return PDCErrorCode.CopyTreeFailed
    logger.info(f"backup {persistent_path} to {flash_device_path} successed.")

    # step 3: do checksum for storage on flash device
    result = self.pdc_flash_do_check(flash_device_path)
    if result == PDCErrorCode.Success:
        return result

    logger.error(f"[Error] [pdc_sdk] checksum failed on {flash_device_path} after copy, backup for debug")
    # no-op if the check itself already moved the directory aside
    self._pdc_backup_failed_directory(flash_device_path)
    return result
551+
552+
def pdc_generate_dir_checksum(self, path: str) -> PDCErrorCode:
497553
"""
498554
Args
499555
:param localPath:
500556
:return:
501557
"""
502558
if not os.path.exists(path):
503-
logger.error(f"pdc_fc_generate_checksum gi{path} not exist")
559+
logger.error(f"pdc_generate_dir_checksum gi{path} not exist")
504560
return PDCErrorCode.CommandFail
505561
generate_checksum_args = [self._pdc_agent_bin, "-mode", "command", "-type", "generateSum", "-path", f"{path}"]
506562
error_code = PDCErrorCode.Success
@@ -514,14 +570,14 @@ def pdc_fc_generate_checksum(self, path: str) -> PDCErrorCode:
514570
return PDCErrorCode.CheckSumCommandFail
515571
return error_code
516572

517-
def pdc_fc_do_check(self, path: str) -> PDCErrorCode:
573+
def pdc_flash_do_check(self, path: str) -> PDCErrorCode:
518574
"""
519575
Args
520576
:param localPath:
521577
:return:
522578
"""
523579
if not os.path.exists(path):
524-
logger.error(f"pdc_fc_do_check {path} not exist")
580+
logger.error(f"pdc_flash_do_check {path} not exist")
525581
return PDCErrorCode.CommandFail
526582
generate_checksum_args = [self._pdc_agent_bin, "-mode", "command", "-type", "checkSum", "-path", f"{path}"]
527583
error_code = PDCErrorCode.Success
@@ -530,8 +586,12 @@ def pdc_fc_do_check(self, path: str) -> PDCErrorCode:
530586
res, error_code = self._exec_cmd(generate_checksum_args)
531587
if error_code == PDCErrorCode.Success:
532588
logger.info(f"check_sum {path} successfully")
589+
else:
590+
logger.error(f"[Error] [pdc_sdk] check_sum {path} failed, error code: {error_code}")
591+
self._pdc_backup_failed_directory(path)
533592
except Exception as e:
534-
logger.error(f"exec cmd {generate_checksum_args} with error: {e}")
593+
logger.error(f"[Error] [pdc_sdk] exec cmd {generate_checksum_args} with error: {e}")
594+
self._pdc_backup_failed_directory(path)
535595
return PDCErrorCode.CheckSumCommandFail
536596
return error_code
537597

@@ -560,8 +620,10 @@ def _clean_tmp_files(self, tmp_files: List[str]):
560620
PDCErrorCode.AFSToolsNotExist: "afs tools not exist",
561621
PDCErrorCode.TrainConfigNotExist: "train config not exist",
562622
PDCErrorCode.LocalPathNotExist: "local path not exist",
563-
PDCErrorCode.CommandFail: "download command fail",
623+
PDCErrorCode.CommandFail: "pdc agent command fail",
564624
PDCErrorCode.CalculateHashFail: "calculate hash fail",
565625
PDCErrorCode.InvalidArgument: "invalid argument",
566-
PDCErrorCode.CommandTimeout: "command timeout",
626+
PDCErrorCode.CommandTimeout: "pdc agent command timeout",
627+
PDCErrorCode.CheckSumCommandFail: "checksum command fail",
628+
PDCErrorCode.CopyTreeFailed: "copy directory failed",
567629
}

0 commit comments

Comments
 (0)