Commit 99ea497

Put the legacy processing code back

1 parent 0d29bc3

2 files changed: 59 additions & 1 deletion

src/transformers/models/llava/configuration_llava.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -48,6 +48,8 @@ class LlavaConfig(PretrainedConfig):
             Can be one of `"default"` or `"full"`.
         vision_feature_layer (`int`, *optional*, defaults to -2):
             The index of the layer to select the vision feature.
+        image_seq_length (`int`, *optional*, defaults to 576):
+            Sequence length of one image embedding.
         multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
             Whether to use bias in the multimodal projector.
@@ -84,12 +86,14 @@ def __init__(
         projector_hidden_act="gelu",
         vision_feature_select_strategy="default",
         vision_feature_layer=-2,
+        image_seq_length=576,
         multimodal_projector_bias=True,
         **kwargs,
     ):
         self.ignore_index = ignore_index
         self.image_token_index = image_token_index
         self.projector_hidden_act = projector_hidden_act
+        self.image_seq_length = image_seq_length
 
         if vision_feature_select_strategy not in ["default", "full"]:
             raise ValueError(
```
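
The new `image_seq_length` field records how many placeholder tokens one image expands to; the default of 576 matches LLaVa-1.5's 336x336 CLIP backbone with 14-pixel patches (24x24 = 576). A minimal, hedged sketch of the field in use (illustrative only, not part of the commit):

```python
from transformers import LlavaConfig

# Build a config and read the new field; 576 is the default added in this commit.
config = LlavaConfig(image_seq_length=576)
print(config.image_seq_length)  # 576
```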

src/transformers/models/llava/modeling_llava.py

Lines changed: 55 additions & 1 deletion

```diff
@@ -464,9 +464,17 @@ def forward(
                 "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
             )
 
+        legacy_processing = False
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
 
+            # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
+            # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
+            # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
+            legacy_processing = (
+                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
+            ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
+
         image_features = None
         if pixel_values is not None:
             image_features = self.get_image_features(
```
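
The heuristic restored above decides between the legacy and the new code path by counting `<image>` placeholder tokens: a prompt already expanded by the processor contains `image_seq_length` of them per image, while a legacy prompt contains only one. A toy sketch of that check (the token ids, including the 32000 image-token index, are made-up values for illustration):

```python
import torch

image_token_index = 32000  # hypothetical id of the <image> placeholder token
image_seq_length = 576     # embeddings produced per image

# Legacy-style prompt: the processor left a single <image> placeholder.
legacy_ids = torch.tensor([[1, 32000, 3148, 1001]])
# New-style prompt: the processor already expanded <image> into 576 tokens.
expanded_ids = torch.cat(
    [torch.tensor([[1]]), torch.full((1, 576), 32000), torch.tensor([[3148, 1001]])], dim=-1
)

for ids in (legacy_ids, expanded_ids):
    legacy = (ids == image_token_index).sum(1).max() < image_seq_length
    print(bool(legacy))  # True for legacy_ids, False for expanded_ids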
```diff
@@ -475,7 +483,53 @@ def forward(
                 vision_feature_select_strategy=vision_feature_select_strategy,
             )
 
-        if image_features is not None:
+        if legacy_processing:
+            logger.warning_once(
+                "Expanding inputs for image tokens in LLaVa should be done in processing. "
+                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
+                "with `processor.patch_size = {{patch_size}}` and `processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
+                "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
+            )
+            # prefill stage vs decoding stage (legacy behavior copied)
+            if input_ids.shape[1] != 1:
+                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                    image_features, inputs_embeds, input_ids, attention_mask, labels
+                )
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
+            else:
+                # Retrieve the first layer to inspect the logits and mask out the hidden states
+                # that are set to 0
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+                # Get the target length
+                target_length = input_ids.shape[1]
+                past_length = first_layer_past_key_value.shape[-1]
+
+                extended_attention_mask = torch.ones(
+                    (attention_mask.shape[0], past_length),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+
+                # Filter out only the tokens that can be un-attended, this can happen
+                # if one uses Llava + Fused modules where the cache on the
+                # first iteration is already big enough, or if one passes custom cache
+                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+                new_batch_index = batch_index[valid_indices]
+                new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+                # Zero-out the places where we don't need to attend
+                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
+                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:]
+
+        # TODO: @raushan retain only the new behavior after v4.47
+        elif image_features is not None:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0] * image_features.shape[1]
 
```
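
The restored warning points users toward the new behaviour: set `patch_size` and `vision_feature_select_strategy` on the processor so that image tokens are expanded during processing rather than inside `forward`. A hedged usage sketch (the checkpoint name is illustrative, and the values shown assume LLaVa-1.5's CLIP ViT-L/14 backbone):

```python
from transformers import AutoProcessor

# Configure the processor so it expands <image> into image_seq_length tokens itself,
# keeping the model on the non-legacy code path and silencing the deprecation warning.
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
processor.patch_size = 14
processor.vision_feature_select_strategy = "default"
```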
