Skip to content

Commit 74bc7ad

Browse files
committed
bugfix
1 parent 62b74d8 commit 74bc7ad

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

llm/alignment/dpo/run_dpo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def main():
137137
)
138138
model.config.use_flash_attention = True
139139

140-
if not any(isinstance(model, cls) for cls in flash_mask_support_list):
140+
if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
141141
raise NotImplementedError(f"{model.__class__} not support flash mask.")
142142

143143
if model_args.tokenizer_name_or_path is not None:

llm/run_finetune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def main():
173173
data_args.zero_padding = True
174174
model.config.use_flash_attention = True
175175

176-
if not any(isinstance(model, cls) for cls in flash_mask_support_list):
176+
if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
177177
raise NotImplementedError(f"{model.__class__} not support flash mask.")
178178

179179
if training_args.do_train and model_args.neftune:

0 commit comments

Comments (0)