Disable delay_optimizer_creation in Trainer to support fsdp2 #37147

Merged: 20 commits into huggingface:main on Apr 4, 2025

Conversation

@byi8220 (Contributor) commented on Mar 31, 2025

What does this PR do?

To support FSDP2 (via accelerate) in Trainer, the model and optimizer have to be passed into Accelerator.prepare() at the same time (https://github.com/huggingface/accelerate/pull/3394/files#r2017637611).

Note: This may not be sufficient for full FSDP2 support in Trainer. This PR might be scrapped and replaced with something more complete if some issues arise.
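
For context, here is a minimal sketch of the difference (a toy model and optimizer, not the Trainer's actual code): with delayed optimizer creation the model could be prepared on its own and the optimizer created afterwards, while FSDP2 needs both to go through Accelerator.prepare() in the same call so the optimizer's parameter references stay in sync with the sharded model.

from accelerate import Accelerator
import torch

accelerator = Accelerator()

model = torch.nn.Linear(8, 2)  # placeholder model, not the real one
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Delayed flow (delay_optimizer_creation=True): prepare the model alone,
# then build and prepare the optimizer afterwards.
# model = accelerator.prepare(model)
# optimizer = accelerator.prepare(torch.optim.AdamW(model.parameters(), lr=1e-3))

# FSDP2 flow: hand both objects to prepare() together so accelerate can
# re-point the optimizer at the sharded parameters.
model, optimizer = accelerator.prepare(model, optimizer)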

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Testing

Two unit tests were added to tests/fsdp/test_fsdp.py:

  1. test_fsdp2_cpu_offloading, which mirrors test_fsdp_cpu_offloading but is skipped, since that existing test appears to be non-functional.
  2. test_accelerate_fsdp2_integration, which mirrors test_training_and_can_resume_normally with SHARDED_STATE_DICT (the only config that can run under accelerate) and passes.

These tests will not be exercised by CI, as they depend on an unreleased feature; they were instead run manually with the command RUN_SLOW=1 pytest tests/fsdp/test_fsdp.py:

# RUN_SLOW=1 pytest tests/fsdp/test_fsdp.py 

tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_basic_run_full_shard_bf16 PASSED                                                                                                                                             [  4%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_basic_run_full_shard_fp16 PASSED                                                                                                                                             [  8%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_basic_run_shard_grad_op_bf16 PASSED                                                                                                                                          [ 13%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_basic_run_shard_grad_op_fp16 PASSED                                                                                                                                          [ 17%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_basic_run_with_cpu_offload_0_bf16 PASSED                                                                                                                                     [ 21%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_basic_run_with_cpu_offload_1_fp16 PASSED                                                                                                                                     [ 26%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_basic_run_with_gradient_accumulation_full_shard_bf16 PASSED                                                                                                                  [ 30%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_basic_run_with_gradient_accumulation_full_shard_fp16 PASSED                                                                                                                  [ 34%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_basic_run_with_gradient_accumulation_shard_grad_op_bf16 PASSED                                                                                                               [ 39%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_basic_run_with_gradient_accumulation_shard_grad_op_fp16 PASSED                                                                                                               [ 43%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_fsdp_cpu_offloading SKIPPED (FSDP CPU offloading script not found!)                                                                                                          [ 47%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_training_and_can_resume_normally_FULL_STATE_DICT PASSED                                                                                                                      [ 52%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_training_and_can_resume_normally_SHARDED_STATE_DICT PASSED                                                                                                                   [ 56%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_accelerate_fsdp2_integration PASSED                                                                                                                                          [ 60%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_fsdp2_cpu_offloading SKIPPED (FSDP 2 CPU offloading script not found!)                                                                                                       [ 65%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_fsdp_config_full_shard_bf16 PASSED                                                                                                                                           [ 69%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_fsdp_config_full_shard_fp16 PASSED                                                                                                                                           [ 73%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_fsdp_config_shard_grad_op_bf16 PASSED                                                                                                                                        [ 78%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_fsdp_config_shard_grad_op_fp16 PASSED                                                                                                                                        [ 82%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_fsdp_config_transformers_auto_wrap_full_shard_bf16 PASSED                                                                                                                    [ 86%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_fsdp_config_transformers_auto_wrap_full_shard_fp16 PASSED                                                                                                                    [ 91%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_fsdp_config_transformers_auto_wrap_shard_grad_op_bf16 PASSED                                                                                                                 [ 95%]
tests/fsdp/test_fsdp.py::TrainerIntegrationFSDP::test_fsdp_config_transformers_auto_wrap_shard_grad_op_fp16 PASSED                                                                                                                 [100%]

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@S1ro1 (Member) left a comment:
LGTM. cc @SunMarc to double-check the added tests.

@byi8220 marked this pull request as ready for review on March 31, 2025 at 23:36
@SunMarc (Member) left a comment:
Thanks! I left a couple of comments. It would be nice to check that all the FSDP tests work with FSDP2 as well, @S1ro1, but that can be done in a follow-up PR.

@byi8220 changed the title from "[wip] Disable delay_optimizer_creation in Trainer to support fsdp2" to "Disable delay_optimizer_creation in Trainer to support fsdp2" on Apr 1, 2025
@byi8220 (Contributor, Author) commented on Apr 1, 2025

@SunMarc I've addressed your comments and the test suite passes on torch==2.6.0.

I'm a bit worried about how this interacts with torch 2.5.1 (it seems FSDP2 wasn't "officially" supported until 2.6.0), but when testing on torch==2.5.1 on my machine, even the main branch doesn't pass.
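
As a rough illustration of how such a gate might look (the commit history mentions adding a require_fsdp_v2_version helper; the sketch below is only an assumed implementation using packaging and unittest, not the actual code added in this PR):

import unittest

import torch
from packaging import version

# Assumed minimum version, based on the 2.5.1 vs 2.6.0 discussion above.
FSDP2_MIN_TORCH_VERSION = "2.6.0"


def require_fsdp_v2_version(test_case):
    """Skip the decorated test unless the installed torch is recent enough for FSDP2."""
    torch_ok = version.parse(torch.__version__) >= version.parse(FSDP2_MIN_TORCH_VERSION)
    return unittest.skipUnless(torch_ok, f"test requires torch >= {FSDP2_MIN_TORCH_VERSION}")(test_case)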

@SunMarc (Member) left a comment:
Thanks! Make sure to update your ruff version to fix the quality CI check.

@SunMarc (Member) commented on Apr 2, 2025

cc @S1ro1 for the FSDP2_PYTORCH_VERSION comment. Maybe we need to update it to 2.6.0?

@byi8220 (Contributor, Author) commented on Apr 2, 2025

Yeah, after updating ruff, the two files being changed appear to be unaffected.

@SunMarc merged commit a4e55fc into huggingface:main on Apr 4, 2025. 18 checks passed.
yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request Apr 5, 2025
…ngface#37147)

* github why you do this

* fix

* make fixup

* disable cpu offload test

* fixup

* tmp reworks

* git branch movement

* make fixup

* add require_fsdp_v2_version

* dep issues

* update ruff and fixup

zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request on May 14, 2025 (same commit message as above).