From 042ae8199529b5a0551807dc97ce3f1bdb9f4cf7 Mon Sep 17 00:00:00 2001
From: Tian <121000916+SylarTiaNII@users.noreply.github.com>
Date: Fri, 13 Dec 2024 14:24:00 +0800
Subject: [PATCH] [LLM] support flash device on static model (#9619)

* [LLM] support flash device on static model

* [LLM] adapt pdc sdk
---
 paddlenlp/utils/downloader.py | 98 ++++++++++++++++++++++++++++++++++-
 paddlenlp/utils/pdc_sdk.py    | 87 +++++++++++++++++++++++++++---
 2 files changed, 176 insertions(+), 9 deletions(-)

diff --git a/paddlenlp/utils/downloader.py b/paddlenlp/utils/downloader.py
index a382a4dd265b..66cf6f7ab23e 100644
--- a/paddlenlp/utils/downloader.py
+++ b/paddlenlp/utils/downloader.py
@@ -24,6 +24,7 @@
 from collections import OrderedDict
 from typing import Optional, Union
 
+import paddle.distributed as dist
 import requests
 from filelock import FileLock
 from huggingface_hub import get_hf_file_metadata, hf_hub_url
@@ -33,7 +34,13 @@
 from .env import DOWNLOAD_SERVER, FAILED_STATUS, SUCCESS_STATUS
 from .fault_tolerance import PDC_DOWNLOAD_ERROR
 from .log import logger
-from .pdc_sdk import PDCErrorCode, PDCErrorMessageMap, pdc_tool
+from .pdc_sdk import (
+    FLASH_DEVICE,
+    PDCErrorCode,
+    PDCErrorMessageMap,
+    pdc_flash_device_available,
+    pdc_tool,
+)
 
 __all__ = ["get_weights_path_from_url"]
 
@@ -487,7 +494,7 @@ def download_from_pdc(remote_path, local_path, timeout):
     """
 
     try:
-        base_dir, _ = os.path.split(os.path.normpath(remote_path))
+        base_dir, _ = os.path.split(os.path.normpath(local_path))
         if not os.path.exists(base_dir) and base_dir != "":
             os.makedirs(base_dir, exist_ok=True)
     except Exception as e:
@@ -505,3 +512,90 @@ def download_from_pdc(remote_path, local_path, timeout):
         raise RuntimeError(
             f"{PDC_DOWNLOAD_ERROR}; Error occurred when trying to download object from PDC, remote_path: {remote_path}, local_path: {local_path}, timeout: {timeout}; error details: {PDCErrorMessageMap[result]}"
         )
+
+
+def get_static_model_on_pdc(remote_path, local_path, timeout, enable_flash_device=False):
+    """
+    Get static model from PDC. Use flash device if possible.
+    This function has to be called after distributed env is initialized in distributed mode.
+
+    Only the process whose `FLAGS_selected_gpus` is 0 (or the single process of a
+    non-distributed job) checks the flash device, downloads and backs up the model;
+    every other process waits on a barrier, then returns the flash path if it is
+    enabled and exists, otherwise `local_path`.
+    Args:
+        remote_path (`str`):
+            remote path url for download
+        local_path (`str`):
+            local path to place downloaded object
+        timeout (`int`):
+            max wait time for download
+        enable_flash_device (`bool`):
+            Whether to use flash device
+    Returns:
+        str: path to load static model
+    Raises:
+        RuntimeError: if the parent directory of `local_path` cannot be prepared or the download fails
+        AssertionError: if `local_path` normalizes to "." (e.g. an empty string)
+    """
+    try:
+        base_dir, target_dir = os.path.split(os.path.normpath(local_path))
+        if not os.path.exists(base_dir) and base_dir != "":
+            os.makedirs(base_dir, exist_ok=True)
+    except Exception as e:
+        raise RuntimeError(f"{PDC_DOWNLOAD_ERROR}; Failed to parse checkpoint path, details: {e}")
+
+    assert target_dir != ".", f"{PDC_DOWNLOAD_ERROR}, illegal local_path: {local_path}."
+
+    flash_path = os.path.join(FLASH_DEVICE, target_dir)
+    persistent_path = local_path
+
+    device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
+    # Only wait when another process exists to do the work: a single process
+    # pinned to a non-zero GPU has to fetch the model itself.
+    if device_id != 0 and dist.get_world_size() > 1:
+        logger.info("Waiting local process 0...")
+        dist.barrier()
+        return flash_path if (enable_flash_device and os.path.exists(flash_path)) else persistent_path
+
+    # step 1: load from flash device if possible
+    need_download_from_remote = True
+    need_backup_to_flash = False
+    if enable_flash_device and pdc_flash_device_available():
+        logger.info(f"flash device is available, checking status on {flash_path}...")
+        # skip download SC as default when flash device is available
+        need_download_from_remote = False
+        if os.path.exists(flash_path) and pdc_tool.pdc_flash_do_check(flash_path) == PDCErrorCode.Success:
+            logger.info("Static model checked successfully on flash device, ready to load...")
+        else:
+            logger.warning(
+                "flash device is available but no valid static model found on flash device, need to download from remote."
+            )
+            need_download_from_remote = True
+            need_backup_to_flash = True
+    else:
+        logger.info("Flash device is not enabled or available, will download static model from remote.")
+
+    # step 2: download from remote if necessary
+    if need_download_from_remote:
+        logger.info("Beging download static model from remote...")
+        download_from_pdc(remote_path, persistent_path, timeout)
+        logger.info(f"downloaded static model from remote, path:{persistent_path}")
+
+    # step 3: backup to flash device if flash device is available
+    if enable_flash_device and need_backup_to_flash:
+        result = pdc_tool.pdc_backup_to_flash_device(persistent_path, flash_path)
+        if result == PDCErrorCode.Success:
+            logger.info(f"Backup static model to flash device {flash_path} successfully.")
+        else:
+            logger.error(f"Backup static model to flash device failed, error details: {PDCErrorMessageMap[result]}.")
+
+    # step 4: return flash path if available, otherwise return persistent path
+    if dist.get_world_size() > 1:
+        logger.info("Local node process done, waiting other nodes...")
+        dist.barrier()
+    if enable_flash_device and os.path.exists(flash_path):
+        logger.info(f"static model is ready on flash device, path: {flash_path}")
+        return flash_path
+    else:
+        logger.info(f"static model is only ready on persistent storage, path: {persistent_path}")
+        return persistent_path
diff --git a/paddlenlp/utils/pdc_sdk.py b/paddlenlp/utils/pdc_sdk.py
index 028850e4d388..c306eedd92c7 100644
--- a/paddlenlp/utils/pdc_sdk.py
+++ b/paddlenlp/utils/pdc_sdk.py
@@ -15,6 +15,7 @@
 import json
 import os
 import queue
+import shutil
 import subprocess
 import threading
 import time
@@ -28,6 +29,14 @@
 TRAIN_CONFIG = "/root/paddlejob/workspace/env_run/longjob/train.conf"
 TAR_BIN = "tar"
 
+FLASH_DEVICE = os.getenv("PDC_FLASH_DEVICE", "/shared/dev/shm/flash")
+
+
+def pdc_flash_device_available():
+    """Return whether the flash device path `FLASH_DEVICE` exists."""
+    # TODO(@gexiao): need better check
+    return os.path.exists(FLASH_DEVICE)
+
 
 class PDCErrorCode(Enum):
     """Error Code For PDCTools usage"""
@@ -48,6 +57,7 @@ class PDCErrorCode(Enum):
     InvalidArgument = 1503
     CommandTimeout = 1504
     CheckSumCommandFail = 1505
+    CopyTreeFailed = 1506
     UnknownError = 1999
 
 
@@ -493,14 +503,71 @@ def _download_file(self, remote_path: str, local_path: str) -> PDCErrorCode:
             raise Exception(f"exec cmd {download_cmd_args} with error: {e}")
         return error_code
 
-    def pdc_fc_generate_checksum(self, path: str) -> PDCErrorCode:
+    def _pdc_backup_failed_directory(self, path):
+        """Rename `path` to `<path>_failed` so the broken data is kept for debugging.
+
+        An existing `<path>_failed` directory is removed first. Does nothing if
+        `path` does not exist.
+        """
+        base_dir, target_path = os.path.split(os.path.normpath(path))
+        failed_path = os.path.join(base_dir, f"{target_path}_failed")
+        if os.path.exists(path):
+            if os.path.exists(failed_path):
+                shutil.rmtree(failed_path)
+            # Backup failed files for debug
+            os.rename(path, failed_path)
+
+    def pdc_backup_to_flash_device(self, persistent_path: str, flash_device_path: str) -> PDCErrorCode:
+        """backup data to flash device
+
+        Args:
+            persistent_path str: persistent path
+            flash_device_path str: flash device path
+        Returns:
+            PDCErrorCode: Success if the copy on flash device passed the checksum, otherwise
+            LocalPathNotExist, CopyTreeFailed or the error code of the checksum step that failed
+        """
+        if not os.path.exists(persistent_path):
+            logger.error(f"{persistent_path} not exist")
+            return PDCErrorCode.LocalPathNotExist
+
+        logger.info("starting backup to flash device...")
+
+        # step 1: generate checksum for recovery
+        result = self.pdc_generate_dir_checksum(persistent_path)
+        if result != PDCErrorCode.Success:
+            logger.error(f"[Error] [pdc_sdk] generating checksum for {persistent_path} failed")
+            return result
+
+        # step 2: copy persistent data to flash device
+        try:
+            # shutil.copytree instead of distutils' copy_tree: distutils is removed in
+            # Python 3.12, and copy_tree caches the directories it created, which breaks
+            # a retry after the target was renamed by _pdc_backup_failed_directory.
+            shutil.copytree(persistent_path, flash_device_path, dirs_exist_ok=True)
+            logger.info(f"backup {persistent_path} to {flash_device_path} successed.")
+        except Exception as e:
+            logger.error(f"[Error] [pdc_sdk] copy tree {persistent_path} to {flash_device_path} failed, error: {e}")
+            self._pdc_backup_failed_directory(flash_device_path)
+            return PDCErrorCode.CopyTreeFailed
+
+        # step 3: do checksum for storage on flash device
+        result = self.pdc_flash_do_check(flash_device_path)
+        if result == PDCErrorCode.Success:
+            return result
+
+        logger.error(f"[Error] [pdc_sdk] checksum failed on {flash_device_path} after copy, backup for debug")
+        self._pdc_backup_failed_directory(flash_device_path)
+        return result
+
+    def pdc_generate_dir_checksum(self, path: str) -> PDCErrorCode:
         """
         Args
         :param localPath:
         :return:
         """
         if not os.path.exists(path):
-            logger.error(f"pdc_fc_generate_checksum gi{path} not exist")
+            logger.error(f"pdc_generate_dir_checksum gi{path} not exist")
             return PDCErrorCode.CommandFail
         generate_checksum_args = [self._pdc_agent_bin, "-mode", "command", "-type", "generateSum", "-path", f"{path}"]
         error_code = PDCErrorCode.Success
@@ -514,14 +581,14 @@ def pdc_fc_generate_checksum(self, path: str) -> PDCErrorCode:
             return PDCErrorCode.CheckSumCommandFail
         return error_code
 
-    def pdc_fc_do_check(self, path: str) -> PDCErrorCode:
+    def pdc_flash_do_check(self, path: str) -> PDCErrorCode:
         """
         Args
         :param localPath:
         :return:
         """
         if not os.path.exists(path):
-            logger.error(f"pdc_fc_do_check {path} not exist")
+            logger.error(f"pdc_flash_do_check {path} not exist")
             return PDCErrorCode.CommandFail
         generate_checksum_args = [self._pdc_agent_bin, "-mode", "command", "-type", "checkSum", "-path", f"{path}"]
         error_code = PDCErrorCode.Success
@@ -530,8 +597,12 @@ def pdc_fc_do_check(self, path: str) -> PDCErrorCode:
             res, error_code = self._exec_cmd(generate_checksum_args)
             if error_code == PDCErrorCode.Success:
                 logger.info(f"check_sum {path} successfully")
+            else:
+                logger.error(f"[Error] [pdc_sdk] check_sum {path} failed, error code: {error_code}")
+                self._pdc_backup_failed_directory(path)
         except Exception as e:
-            logger.error(f"exec cmd {generate_checksum_args} with error: {e}")
+            logger.error(f"[Error] [pdc_sdk] exec cmd {generate_checksum_args} with error: {e}")
+            self._pdc_backup_failed_directory(path)
             return PDCErrorCode.CheckSumCommandFail
         return error_code
 
@@ -560,8 +631,10 @@ def _clean_tmp_files(self, tmp_files: List[str]):
     PDCErrorCode.AFSToolsNotExist: "afs tools not exist",
     PDCErrorCode.TrainConfigNotExist: "train config not exist",
     PDCErrorCode.LocalPathNotExist: "local path not exist",
-    PDCErrorCode.CommandFail: "download command fail",
+    PDCErrorCode.CommandFail: "pdc agent command fail",
     PDCErrorCode.CalculateHashFail: "calculate hash fail",
     PDCErrorCode.InvalidArgument: "invalid argument",
-    PDCErrorCode.CommandTimeout: "command timeout",
+    PDCErrorCode.CommandTimeout: "pdc agent command timeout",
+    PDCErrorCode.CheckSumCommandFail: "checksum command fail",
+    PDCErrorCode.CopyTreeFailed: "copy directory failed",
 }