Commit a9b636a

byi8220 authored and zucchini-nlp committed

Disable delay_optimizer_creation in Trainer to support fsdp2 (huggingface#37147)

* github why you do this
* fix
* make fixup
* disable cpu offload test
* fixup
* tmp reworks
* git branch movement
* make fixup
* add require_fsdp_v2_version
* dep issues
* update ruff and fixup
1 parent 51d7aaf commit a9b636a

2 files changed: +81 −0 lines changed

src/transformers/trainer.py

Lines changed: 5 additions & 0 deletions
@@ -2313,6 +2313,11 @@ def _inner_training_loop(
 
         delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
 
+        # Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404
+        is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2)
+        if is_fsdp2:
+            delay_optimizer_creation = False
+
         # We need to reset the scheduler, as its parameters may be different on subsequent calls
         if self._created_lr_scheduler:
             self.lr_scheduler = None
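
For context, here is a minimal standalone sketch of the check the patch introduces (illustrative only; a `SimpleNamespace` stands in for `self.accelerator.state.fsdp_plugin`, whose `fsdp_version` attribute is only present on Accelerate versions that support FSDP2, so older plugins fall back to the default of 1):

# Illustrative sketch, not part of the diff: behavior of the FSDP2 check
# with and without an `fsdp_version` attribute on the FSDP plugin.
from types import SimpleNamespace


def should_delay_optimizer_creation(is_fsdp_enabled: bool, fsdp_plugin) -> bool:
    # Mirrors the Trainer logic above: delaying optimizer creation is fine
    # for FSDP1, but must be disabled for FSDP2.
    delay = is_fsdp_enabled
    is_fsdp2 = is_fsdp_enabled and getattr(fsdp_plugin, "fsdp_version", 1) == 2
    if is_fsdp2:
        delay = False
    return delay


# Older plugins have no `fsdp_version` attribute -> treated as FSDP1.
print(should_delay_optimizer_creation(True, SimpleNamespace()))                # True
print(should_delay_optimizer_creation(True, SimpleNamespace(fsdp_version=2)))  # False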

tests/fsdp/test_fsdp.py

Lines changed: 76 additions & 0 deletions
@@ -109,6 +109,15 @@ def get_master_port(real_launcher=False):
 require_fsdp_version = partial(require_fsdp, min_version=FSDP_PYTORCH_VERSION)
 
 
+FSDP2_ACCELERATE_VERSION = "1.6.0"
+require_accelerate_fsdp2 = partial(require_accelerate, min_version=FSDP2_ACCELERATE_VERSION)
+require_fsdp_v2_version = require_fsdp
+if is_accelerate_available(min_version=FSDP2_ACCELERATE_VERSION):
+    from accelerate.utils.constants import FSDP2_PYTORCH_VERSION
+
+    require_fsdp_v2_version = partial(require_fsdp, min_version=FSDP2_PYTORCH_VERSION)
+
+
 def get_launcher(distributed=False, use_accelerate=False):
     # 1. explicitly set --num_nodes=1 just in case these tests end up run on a multi-node setup
     #    - it won't be able to handle that
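
The `require_*` helpers above gate tests on the installed package versions. As a rough, hedged sketch of how such a version gate can be built (this is not the actual implementation in `transformers.testing_utils`, which may differ in detail):

# Illustrative sketch only: a version-gated skip decorator in the spirit of
# require_fsdp / require_accelerate.
import unittest
from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse


def require_min_version(package: str, min_version: str):
    """Skip the decorated test unless `package` is installed at >= `min_version`."""
    try:
        ok = parse(version(package)) >= parse(min_version)
    except PackageNotFoundError:
        ok = False
    return unittest.skipUnless(ok, f"test requires {package}>={min_version}")


# Hypothetical usage, gating a test on the Accelerate version from the diff above:
# @require_min_version("accelerate", "1.6.0")
# def test_something_fsdp2(self): ...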
@@ -316,6 +325,73 @@ def test_fsdp_cpu_offloading(self):
         except:  # noqa
             raise AssertionError("CPU offloading failed with FSDP!")
 
+    @require_torch_multi_accelerator
+    @slow
+    @require_fsdp
+    @require_fsdp_v2_version
+    @require_accelerate_fsdp2
+    def test_accelerate_fsdp2_integration(self):
+        output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
+        sharding_strategy = "full_shard"
+        use_accelerate = True
+
+        num_gpus = min(2, backend_device_count(torch_device))
+        master_port = get_master_port(real_launcher=True)
+        launcher = f"""accelerate launch
+            --num_processes {num_gpus}
+            --main_process_port {master_port}
+            --use_fsdp
+            --fsdp_version 2
+            --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP
+            --fsdp_state_dict_type SHARDED_STATE_DICT
+            --fsdp_transformer_layer_cls_to_wrap BertLayer""".split()
+        args = self.get_base_args(output_dir, 2, 25).split()
+        script = [f"{self.examples_dir_str}/pytorch/text-classification/run_glue.py"]
+        logs = self.run_cmd_and_get_logs(use_accelerate, sharding_strategy, launcher, script, args, output_dir)
+
+        # resume from ckpt
+        checkpoint = os.path.join(output_dir, "checkpoint-115")
+        resume_args = args + f"--resume_from_checkpoint {checkpoint}".split()
+
+        is_fsdp_ckpt = os.path.isdir(checkpoint) and (
+            # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
+            any(
+                FSDP_MODEL_NAME in folder_name
+                for folder_name in os.listdir(checkpoint)
+                if os.path.isdir(os.path.join(checkpoint, folder_name))
+            )
+            # this checks the FSDP state dict when `FULL_STATE_DICT` is used
+            or os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
+        )
+        self.assertTrue(is_fsdp_ckpt)
+
+        logs_resume = self.run_cmd_and_get_logs(
+            use_accelerate, sharding_strategy, launcher, script, resume_args, output_dir
+        )
+
+        for log, log1 in zip(logs, logs_resume):
+            if "learning_rate" in log:
+                self.assertAlmostEqual(log["learning_rate"], log1["learning_rate"], delta=1e-5)
+
+    @require_torch_multi_accelerator
+    @slow
+    @require_fsdp
+    @require_fsdp_v2_version
+    @require_accelerate_fsdp2
+    def test_fsdp2_cpu_offloading(self):
+        # TODO: This file is missing and should be added or the test should be removed
+        if not os.path.exists("utils/testing_scripts/fsdp_cpu_offloading.py"):
+            raise unittest.SkipTest("FSDP 2 CPU offloading script not found!")
+
+        try:
+            subprocess.run(
+                "accelerate launch --fsdp_version 2 utils/testing_scripts/fsdp_cpu_offloading.py --config utils/testing_scripts/dummy_fsdp_config.yml",
+                shell=True,
+                check=True,
+            )
+        except:  # noqa
+            raise AssertionError("CPU offloading failed with FSDP!")
+
     def run_cmd_and_get_logs(self, use_accelerate, sharding_strategy, launcher, script, args, output_dir):
         if not use_accelerate:
             fsdp_args = [
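
Roughly speaking, the `launcher`, `script`, and `args` lists assembled in `test_accelerate_fsdp2_integration` are concatenated into a single `accelerate launch` invocation. A hedged sketch of that shape is below; the paths, port, and arguments are placeholders, and the real `run_cmd_and_get_logs` helper additionally returns the trainer logs that the test compares across the resumed run:

# Illustrative sketch only: combining launcher + script + args into one command.
import subprocess

launcher = (
    "accelerate launch --num_processes 2 --main_process_port 10999 "
    "--use_fsdp --fsdp_version 2 "
    "--fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP "
    "--fsdp_state_dict_type SHARDED_STATE_DICT "
    "--fsdp_transformer_layer_cls_to_wrap BertLayer"
).split()
script = ["examples/pytorch/text-classification/run_glue.py"]  # placeholder path
args = ["--output_dir", "/tmp/fsdp2_test", "--num_train_epochs", "2"]  # placeholder args

cmd = launcher + script + args
subprocess.run(cmd, check=True)  # raises CalledProcessError if the launch fails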
