diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index f8ef9a8a74ab..c0cbfc713857 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -23,6 +23,7 @@ from ..utils import deprecate, is_transformers_available, logging from .single_file_utils import ( SingleFileComponentError, + _is_legacy_scheduler_kwargs, _is_model_weights_in_cached_folder, _legacy_load_clip_tokenizer, _legacy_load_safety_checker, @@ -42,7 +43,6 @@ # Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"] - if is_transformers_available(): import transformers from transformers import PreTrainedModel, PreTrainedTokenizer @@ -135,7 +135,7 @@ def load_single_file_sub_model( class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only ) - elif is_diffusers_scheduler and is_legacy_loading: + elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)): loaded_sub_model = _legacy_load_scheduler( class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs ) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 9c2a2cbf2942..d099f722187a 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -269,6 +269,7 @@ ] OPEN_CLIP_PREFIX = "conditioner.embedders.0.model." LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024 +SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"] VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"] @@ -318,6 +319,10 @@ def _is_model_weights_in_cached_folder(cached_folder, name): return weights_exist +def _is_legacy_scheduler_kwargs(kwargs): + return any(k in SCHEDULER_LEGACY_KWARGS for k in kwargs.keys()) + + def load_single_file_checkpoint( pretrained_model_link_or_path, force_download=False, @@ -1477,14 +1482,22 @@ def _legacy_load_scheduler( if scheduler_type is not None: deprecation_message = ( - "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`." + "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`\n\n" + "Example:\n\n" + "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n" + "scheduler = DDIMScheduler()\n" + "pipe = StableDiffusionPipeline.from_single_file(, scheduler=scheduler)\n" ) deprecate("scheduler_type", "1.0.0", deprecation_message) if prediction_type is not None: deprecation_message = ( - "Please configure an instance of a Scheduler with the appropriate `prediction_type` " - "and pass the object directly to the `scheduler` argument in `from_single_file`." + "Please configure an instance of a Scheduler with the appropriate `prediction_type` and " + "pass the object directly to the `scheduler` argument in `from_single_file`.\n\n" + "Example:\n\n" + "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n" + 'scheduler = DDIMScheduler(prediction_type="v_prediction")\n' + "pipe = StableDiffusionPipeline.from_single_file(, scheduler=scheduler)\n" ) deprecate("prediction_type", "1.0.0", deprecation_message)