Skip to content

Improve downloads of sharded variants #9869

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 8, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,20 @@ def convert_to_variant(filename):
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
return variant_filename

components_with_variant = set()
for filename in variant_filenames:
if not len(filename.split("/")) == 2:
continue
component, component_filename = filename.split("/")
components_with_variant.add(component)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if not len(filename.split("/")) == 2:
continue
component, component_filename = filename.split("/")
components_with_variant.add(component)
components_with_variant.add(filename.split("/")[0])

I think we should match the logic to find "component" between variant_filenames and non_variant_filenames

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ohh I realize you actually just updated this to account for files that are not put inside any subfolders
maybe make this a in-line function find_component or something and do the same the for f in non_variant_filenames too


for f in non_variant_filenames:
variant_filename = convert_to_variant(f)
if variant_filename not in usable_filenames:
component = f.split("/")[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this work in case we're loading from a non-subfolder location?

# If a component already has a variant skip including any other files
if component in components_with_variant:
continue
# If a variant version of a file doesn't exist add the file to the allowed patterns list
if convert_to_variant(f) not in variant_filenames:
usable_filenames.add(f)

return usable_filenames, variant_filenames
Expand Down
Loading