File tree Expand file tree Collapse file tree 2 files changed +2
-2
lines changed Expand file tree Collapse file tree 2 files changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -137,7 +137,7 @@ def main():
137
137
)
138
138
model .config .use_flash_attention = True
139
139
140
- if not any (isinstance (model , cls ) for cls in flash_mask_support_list ):
140
+ if model_args . flash_mask and not any (isinstance (model , cls ) for cls in flash_mask_support_list ):
141
141
raise NotImplementedError (f"{ model .__class__ } not support flash mask." )
142
142
143
143
if model_args .tokenizer_name_or_path is not None :
Original file line number Diff line number Diff line change @@ -173,7 +173,7 @@ def main():
173
173
data_args .zero_padding = True
174
174
model .config .use_flash_attention = True
175
175
176
- if not any (isinstance (model , cls ) for cls in flash_mask_support_list ):
176
+ if model_args . flash_mask and not any (isinstance (model , cls ) for cls in flash_mask_support_list ):
177
177
raise NotImplementedError (f"{ model .__class__ } not support flash mask." )
178
178
179
179
if training_args .do_train and model_args .neftune :
You can’t perform that action at this time.
0 commit comments