@@ -109,6 +109,15 @@ def get_master_port(real_launcher=False):
     require_fsdp_version = partial(require_fsdp, min_version=FSDP_PYTORCH_VERSION)
 
 
+FSDP2_ACCELERATE_VERSION = "1.6.0"
+require_accelerate_fsdp2 = partial(require_accelerate, min_version=FSDP2_ACCELERATE_VERSION)
+require_fsdp_v2_version = require_fsdp
+if is_accelerate_available(min_version=FSDP2_ACCELERATE_VERSION):
+    from accelerate.utils.constants import FSDP2_PYTORCH_VERSION
+
+    require_fsdp_v2_version = partial(require_fsdp, min_version=FSDP2_PYTORCH_VERSION)
+
+
 def get_launcher(distributed=False, use_accelerate=False):
     # 1. explicitly set --num_nodes=1 just in case these tests end up run on a multi-node setup
     # - it won't be able to handle that
@@ -316,6 +325,73 @@ def test_fsdp_cpu_offloading(self):
         except:  # noqa
             raise AssertionError("CPU offloading failed with FSDP!")
 
+    @require_torch_multi_accelerator
+    @slow
+    @require_fsdp
+    @require_fsdp_v2_version
+    @require_accelerate_fsdp2
+    def test_accelerate_fsdp2_integration(self):
+        output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
+        sharding_strategy = "full_shard"
+        use_accelerate = True
+
+        num_gpus = min(2, backend_device_count(torch_device))
+        master_port = get_master_port(real_launcher=True)
+        launcher = f"""accelerate launch
+            --num_processes {num_gpus}
+            --main_process_port {master_port}
+            --use_fsdp
+            --fsdp_version 2
+            --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP
+            --fsdp_state_dict_type SHARDED_STATE_DICT
+            --fsdp_transformer_layer_cls_to_wrap BertLayer""".split()
+        args = self.get_base_args(output_dir, 2, 25).split()
+        script = [f"{self.examples_dir_str}/pytorch/text-classification/run_glue.py"]
+        logs = self.run_cmd_and_get_logs(use_accelerate, sharding_strategy, launcher, script, args, output_dir)
+
+        # resume from ckpt
+        checkpoint = os.path.join(output_dir, "checkpoint-115")
+        resume_args = args + f"--resume_from_checkpoint {checkpoint}".split()
+
+        is_fsdp_ckpt = os.path.isdir(checkpoint) and (
+            # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
+            any(
+                FSDP_MODEL_NAME in folder_name
+                for folder_name in os.listdir(checkpoint)
+                if os.path.isdir(os.path.join(checkpoint, folder_name))
+            )
+            # this checks the FSDP state dict when `FULL_STATE_DICT` is used
+            or os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
+        )
+        self.assertTrue(is_fsdp_ckpt)
+
+        logs_resume = self.run_cmd_and_get_logs(
+            use_accelerate, sharding_strategy, launcher, script, resume_args, output_dir
+        )
+
+        for log, log1 in zip(logs, logs_resume):
+            if "learning_rate" in log:
+                self.assertAlmostEqual(log["learning_rate"], log1["learning_rate"], delta=1e-5)
+
+    @require_torch_multi_accelerator
+    @slow
+    @require_fsdp
+    @require_fsdp_v2_version
+    @require_accelerate_fsdp2
+    def test_fsdp2_cpu_offloading(self):
+        # TODO: This file is missing and should be added or the test should be removed
+        if not os.path.exists("utils/testing_scripts/fsdp_cpu_offloading.py"):
+            raise unittest.SkipTest("FSDP 2 CPU offloading script not found!")
+
+        try:
+            subprocess.run(
+                "accelerate launch --fsdp_version 2 utils/testing_scripts/fsdp_cpu_offloading.py --config utils/testing_scripts/dummy_fsdp_config.yml",
+                shell=True,
+                check=True,
+            )
+        except:  # noqa
+            raise AssertionError("CPU offloading failed with FSDP!")
+
     def run_cmd_and_get_logs(self, use_accelerate, sharding_strategy, launcher, script, args, output_dir):
         if not use_accelerate:
             fsdp_args = [