Commit 29f322d

[generate, cache] handle more complex device maps (#37014)
Parent: fb8e6c5

File tree: 2 files changed (+109, -10)

src/transformers/generation/utils.py

Lines changed: 54 additions & 10 deletions
```diff
@@ -1656,11 +1656,10 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
         model_kwargs["cache_position"] = cache_position
         return model_kwargs
 
-    def _get_layer_device_map_for_cache_init(self):
+    def _get_layer_device_map_for_cache_init(self) -> Optional[Dict[int, Union[str, int]]]:
         """
-        Taken from `dispatch_model` from accelerate.
-        This is needed here if we don't want to make changes in accelerate in order to save execution_device
-        For offloaded case, we need to get the execution device, not just the device where it is offloaded
+        Returns the device map for each decoder layer, to allocate the cache on the right device.
+        Inspired from `dispatch_model` in accelerate.
         """
         execution_device_map = None
 
```
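For orientation, `execution_device_map` (built just above the second hunk) is derived from the model's `hf_device_map`: modules offloaded to "cpu" or "disk" are remapped to the device where they actually execute, which is what the removed docstring alluded to. A hypothetical pair of maps, with made-up module names and devices:

```python
# Hypothetical `hf_device_map` for a small decoder-only model split across
# one GPU and CPU offload (module names and devices are illustrative only).
hf_device_map = {
    "model.embed_tokens": 0,
    "model.layers.0": 0,
    "model.layers.1": 0,
    "model.layers.2": "cpu",  # weights stored on CPU, executed on the main device
    "lm_head": "cpu",
}

# The derived execution device map: offloaded modules are remapped to the
# main device, since that is where their forward pass actually runs.
execution_device_map = {
    "model.embed_tokens": 0,
    "model.layers.0": 0,
    "model.layers.1": 0,
    "model.layers.2": 0,
    "lm_head": 0,
}
```

With per-layer entries like these, the `.{idx}.` matching below is enough; the new code paths in the second hunk handle the harder case where only a parent module (e.g. `"language_model"`) appears in the map.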

```diff
@@ -1674,17 +1673,62 @@ def _get_layer_device_map_for_cache_init(self):
                 for name, device in self.hf_device_map.items()
             }
 
-        num_hidden_layers = self.config.get_text_config().num_hidden_layers
+        # No `execution_device_map` -> rely on `self.device` to allocate the cache
         if execution_device_map is None:
             return None
-        elif len(execution_device_map) == 1 and "" in execution_device_map:
+
+        # Single device for all layers
+        num_hidden_layers = self.config.get_text_config().num_hidden_layers
+        if len(execution_device_map) == 1 and "" in execution_device_map:
             return dict.fromkeys(range(num_hidden_layers), execution_device_map[""])
+
+        # Multiple devices in `execution_device_map` -> we need to map decoder layers to the correct device.
         layer_device_map = {}
-        for layer in execution_device_map:
-            for idx in range(num_hidden_layers):
-                if f".{idx}." in f"{layer}.":
-                    layer_device_map[idx] = execution_device_map[layer]
+        # Case 1: The model has a `get_decoder` method, we can use it to find the decoder name.
+        if hasattr(self, "get_decoder"):
+            decoder_name = None
+            for name, module in self.named_modules():
+                if module is self.get_decoder():
+                    decoder_name = name
                     break
+            if decoder_name is None:
+                raise RuntimeError(
+                    "`model.get_decoder()` is not returning a named module of the model. This is unexpected, please "
+                    "open an issue on GitHub."
+                )
+
+            decoder_mapped_modules = [
+                module_name for module_name in execution_device_map.keys() if decoder_name in module_name
+            ]
+            # The decoder name may be present in `execution_device_map` in two forms:
+            # a) each layer has a device mapping
+            if len(decoder_mapped_modules) >= num_hidden_layers:
+                for idx in range(num_hidden_layers):
+                    for module_name in decoder_mapped_modules:
+                        if f".{idx}." in f"{module_name}.":
+                            layer_device_map[idx] = execution_device_map[module_name]
+                            break
+
+            # b) the whole module is mapped to a single device. If the decoder name is NOT present in the device map,
+            # then the mapping is done in a parent module
+            else:
+                while True:
+                    if decoder_name in execution_device_map:
+                        layer_device_map = dict.fromkeys(range(num_hidden_layers), execution_device_map[decoder_name])
+                        break
+                    elif "." in decoder_name:
+                        decoder_name = decoder_name.rsplit(".", 1)[0]  # gets the name of the parent module
+                    else:
+                        raise RuntimeError(f"Decoder name {decoder_name} not found in execution device map")
+
+        # Case 2: Legacy code path: assume the decoder layers are named as `(...).X` (X being the layer index)
+        else:
+            for layer in execution_device_map:
+                for idx in range(num_hidden_layers):
+                    if f".{idx}." in f"{layer}.":
+                        layer_device_map[idx] = execution_device_map[layer]
+                        break
+
         for idx in range(num_hidden_layers):
             if idx not in layer_device_map:
                 raise RuntimeError(f"layer {idx} has not been mapped to a device.")
```

tests/generation/test_utils.py

Lines changed: 55 additions & 0 deletions
```diff
@@ -56,6 +56,7 @@
 
 from transformers import (
     AutoModelForCausalLM,
+    AutoModelForImageTextToText,
     AutoModelForSeq2SeqLM,
     AutoModelForSpeechSeq2Seq,
     AutoModelForVision2Seq,
```

```diff
@@ -4720,6 +4721,60 @@ def test_generate_vision2text_conditioning(self):
         self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids))
         self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input))
 
+    @slow
+    @require_torch_gpu
+    def test_cache_device_map_with_vision_layer_device_map(self):
+        """
+        Test that the cache device map is correctly set when the vision layer has a device map. Regression test for
+        #36942
+        """
+        # gemma 3 uses hybrid cache, which can be compiled -> needs a device map at allocation time
+        model_id = "google/gemma-3-4b-it"
+
+        # important part of this device map: the `.layers.` pattern is NOT present in the decoder
+        device_map = {
+            "vision_tower.vision_model.embeddings": 0,
+            "vision_tower.vision_model.encoder.layers.0": 0,
+            "vision_tower.vision_model.encoder.layers.1": 0,
+            "vision_tower.vision_model.encoder.layers.2": 0,
+            "vision_tower.vision_model.encoder.layers.3": 0,
+            "vision_tower.vision_model.encoder.layers.4": 0,
+            "vision_tower.vision_model.encoder.layers.5": 0,
+            "vision_tower.vision_model.encoder.layers.6": 0,
+            "vision_tower.vision_model.encoder.layers.7": 0,
+            "vision_tower.vision_model.encoder.layers.8": 0,
+            "vision_tower.vision_model.encoder.layers.9": 0,
+            "vision_tower.vision_model.encoder.layers.10": 0,
+            "vision_tower.vision_model.encoder.layers.11": 0,
+            "vision_tower.vision_model.encoder.layers.12": 0,
+            "vision_tower.vision_model.encoder.layers.13": 0,
+            "vision_tower.vision_model.encoder.layers.14": "cpu",
+            "vision_tower.vision_model.encoder.layers.15": "cpu",
+            "vision_tower.vision_model.encoder.layers.16": "cpu",
+            "vision_tower.vision_model.encoder.layers.17": "cpu",
+            "vision_tower.vision_model.encoder.layers.18": "cpu",
+            "vision_tower.vision_model.encoder.layers.19": "cpu",
+            "vision_tower.vision_model.encoder.layers.20": "cpu",
+            "vision_tower.vision_model.encoder.layers.21": "cpu",
+            "vision_tower.vision_model.encoder.layers.22": "cpu",
+            "vision_tower.vision_model.encoder.layers.23": "cpu",
+            "vision_tower.vision_model.encoder.layers.24": "cpu",
+            "vision_tower.vision_model.encoder.layers.25": "cpu",
+            "vision_tower.vision_model.encoder.layers.26": "cpu",
+            "vision_tower.vision_model.post_layernorm": "cpu",
+            "multi_modal_projector": "cpu",
+            "language_model": "cpu",
+        }
+
+        model = AutoModelForImageTextToText.from_pretrained(
+            model_id, device_map=device_map, torch_dtype=torch.bfloat16
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        inputs = tokenizer(["This is a text input"], return_tensors="pt").to(model.device)
+
+        # If the generate doesn't infer the DECODER device map correctly, this will fail
+        _ = model.generate(**inputs, max_new_tokens=2, do_sample=False)
+
 
 @require_torch
 class TokenHealingTestCase(unittest.TestCase):
```
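Device maps like the one in this test usually come from accelerate rather than being written by hand. A sketch of how a comparable GPU/CPU split can arise (the memory limits are illustrative, not values from the commit):

```python
# Sketch: derive a mixed GPU/CPU device map with accelerate.
# `max_memory` values are illustrative; real limits depend on the hardware.
from accelerate import infer_auto_device_map
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained("google/gemma-3-4b-it")
device_map = infer_auto_device_map(model, max_memory={0: "4GiB", "cpu": "30GiB"})
# Whatever does not fit on GPU 0 spills over to "cpu". Whole submodules such as
# "language_model" can end up as a single entry, i.e. without any `.layers.{idx}.`
# keys, which is exactly the case the new cache-init path handles.
```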
