@@ -464,9 +464,17 @@ def forward(
                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
            )

+        legacy_processing = False
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

+            # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
+            # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
+            # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
+            legacy_processing = (
+                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
+            ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
+
        image_features = None
        if pixel_values is not None:
            image_features = self.get_image_features(
@@ -475,7 +483,53 @@ def forward(
                vision_feature_select_strategy=vision_feature_select_strategy,
            )

-        if image_features is not None:
+        if legacy_processing:
+            logger.warning_once(
+                "Expanding inputs for image tokens in LLaVa should be done in processing. "
+                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
+                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
+                "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
+            )
+            # prefill stage vs decoding stage (legacy behavior copied)
+            if input_ids.shape[1] != 1:
+                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                    image_features, inputs_embeds, input_ids, attention_mask, labels
+                )
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
+            else:
+                # Retrieve the first layer to inspect the logits and mask out the hidden states
+                # that are set to 0
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                # Get the target length
+                target_length = input_ids.shape[1]
+                past_length = first_layer_past_key_value.shape[-1]
+
+                extended_attention_mask = torch.ones(
+                    (attention_mask.shape[0], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+
+                # Filter out only the tokens that can be un-attended, this can happen
+                # if one uses Llava + Fused modules where the cache on the
+                # first iteration is already big enough, or if one passes custom cache
+                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                # Zero-out the places where we don't need to attend
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
+                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
+
+        # TODO: @raushan retain only the new behavior after v4.47
+        elif image_features is not None:
            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
            n_image_features = image_features.shape[0] * image_features.shape[1]
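
For readers skimming the diff, the heuristic added in the first hunk can be read as a standalone predicate. The sketch below is illustrative only and is not part of the PR; `is_legacy_processing`, the image token id `32000`, and the sequence length `576` are made-up stand-ins for the values that the model normally reads from its config.

```python
from typing import Optional

import torch


def is_legacy_processing(
    input_ids: torch.Tensor,              # (batch, seq_len)
    pixel_values: Optional[torch.Tensor],
    image_token_index: int,
    image_seq_length: int,
) -> bool:
    # Prefill: if no prompt in the batch carries a full run of image placeholder
    # tokens, the processor did not expand them, so the legacy merge path is needed.
    too_few_image_tokens = (
        (input_ids == image_token_index).sum(1).max() < image_seq_length
    )
    # Decoding: a single new token passed together with pixel_values also signals legacy usage.
    decoding_with_pixels = input_ids.shape[-1] == 1 and pixel_values is not None
    return bool(too_few_image_tokens) or decoding_with_pixels


# Toy example: one image placeholder in the prompt, but 576 embeddings expected per image.
ids = torch.tensor([[1, 32000, 5, 6]])
print(is_legacy_processing(ids, None, 32000, 576))  # True -> model-side expansion (legacy)
```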
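The `warning_once` message in the second hunk points users at the new, processor-side expansion path. A minimal sketch of what that configuration might look like follows; the checkpoint name and the concrete values `14` / `"default"` are assumptions, not part of this diff, and should be taken from the model's actual vision config.

```python
from transformers import AutoProcessor

# Give the processor the attributes it needs so that image-token expansion happens
# during processing rather than inside the model's forward pass.
# Checkpoint name and fallback values are assumed for illustration.
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
if getattr(processor, "patch_size", None) is None:
    processor.patch_size = 14
if getattr(processor, "vision_feature_select_strategy", None) is None:
    processor.vision_feature_select_strategy = "default"
```

With these attributes set, the processor emits prompts whose image placeholder tokens already span the full image sequence length, so the model takes the non-legacy branch above and the deprecation warning is not triggered.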