Improvements in Gemma2 model card #37076
<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
    </div>
</div>

# Gemma2
[Gemma 2](https://huggingface.co/papers/2408.00118) is a family of language models with pretrained and instruction-tuned variants, available in 2B, 9B, and 27B parameter sizes. The architecture is similar to the previous Gemma, except it features interleaved local attention (4096 tokens) and global attention (8192 tokens) and grouped-query attention (GQA) to increase inference performance. At 27B parameters, it delivers performance competitive with models more than twice its size.
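
The local/global attention pattern and GQA are visible directly in the model configuration. The short sketch below is illustrative only; it assumes the `google/gemma-2-9b` checkpoint and the attribute names used by `Gemma2Config` in Transformers (`sliding_window`, `max_position_embeddings`, `num_attention_heads`, `num_key_value_heads`).

```python
from transformers import AutoConfig

# Inspect the architecture details described above.
config = AutoConfig.from_pretrained("google/gemma-2-9b")
print(config.sliding_window)             # local attention window size (4096 tokens)
print(config.max_position_embeddings)    # global context length (8192 tokens)
print(config.num_attention_heads, config.num_key_value_heads)  # GQA: fewer key/value heads than query heads
```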

The 2B and 9B models were trained with knowledge distillation, and the instruction-tuned variants were post-trained with supervised fine-tuning and reinforcement learning.

You can find all the original Gemma 2 checkpoints under the [Gemma 2](https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315) collection. This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [Pedro Cuenca](https://huggingface.co/pcuenq), and Tom Arsen.

> [!TIP]
> Click on the Gemma 2 models in the right sidebar for more examples of how to apply Gemma 2 to different language tasks.

The example below demonstrates how to chat with the model using [`Pipeline`] or the [`AutoModel`] class, and from the command line.

<hfoptions id="usage">
<hfoption id="Pipeline">

```python
import torch
from transformers import pipeline

pipe = pipeline(
    task="text-generation",
    model="google/gemma-2-9b",
    torch_dtype=torch.bfloat16,
    device="cuda",
)

pipe("Explain quantum computing simply.", max_new_tokens=50)
```

</hfoption>
<hfoption id="AutoModel">

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa"
)

input_text = "Explain quantum computing simply."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=32, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

</hfoption>
<hfoption id="transformers-cli">

```bash
echo -e "Explain quantum computing simply." | transformers-cli run --task text-generation --model google/gemma-2-2b --device 0
```

</hfoption>
</hfoptions>
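
The checkpoints above are pretrained base models and generate free-form continuations. For the instruction-tuned variants, format the prompt with the tokenizer's chat template. The sketch below is an illustration only; it assumes the `google/gemma-2-2b-it` checkpoint and enough GPU memory for bfloat16 weights.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Gemma 2 chat templates only use "user" and "model" turns (no system role).
messages = [{"role": "user", "content": "Explain quantum computing simply."}]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)

outputs = model.generate(input_ids, max_new_tokens=64)
print(tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True))
```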

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.

The example below uses [bitsandbytes](../quantization/bitsandbytes) to quantize only the weights to int4.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-27b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
    quantization_config=quantization_config,
)

input_text = "Explain quantum computing simply."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=32, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
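
To sanity-check the savings, you can print the loaded model's parameter memory with `get_memory_footprint`. This is a rough sketch and assumes the quantized 27B model above is already loaded.

```python
# Report the memory taken by the quantized weights (in GB).
print(f"Memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
```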

Use the [AttentionMaskVisualizer](https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139) to better understand what tokens the model can and cannot attend to.

```python
from transformers.utils.attention_visualizer import AttentionMaskVisualizer

visualizer = AttentionMaskVisualizer("google/gemma-2-2b")
visualizer("You are an assistant. Make sure you print me")
```

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/gemma-2-attn-mask.png"/>
</div>

## Notes

- The original checkpoints can be converted with the conversion script `src/transformers/models/gemma2/convert_gemma2_weights_to_hf.py`.
- Use a [`HybridCache`] instance to enable caching in Gemma 2. Gemma 2 doesn't support kv-caching strategies like [`DynamicCache`] or tuples of tensors because it uses sliding window attention every second layer. Initialize a [`HybridCache`] and pass it to the forward call as `past_key_values`. If `past_key_values` already contains previous keys and values, prepare `cache_position` as well.

    ```python
    from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache

    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

    inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
    max_generated_length = inputs.input_ids.shape[1] + 10
    past_key_values = HybridCache(config=model.config, max_batch_size=1,
                                  max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
    outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    ```
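
    When the cache already holds earlier keys and values, pass an explicit `cache_position` on the next forward call. The continuation below is a minimal sketch that picks the next token greedily purely for illustration.

    ```python
    import torch

    # Feed one new token and tell the model where it sits in the sequence.
    next_token = outputs.logits[:, -1:].argmax(dim=-1)
    cache_position = torch.tensor([inputs.input_ids.shape[1]], device=model.device)
    outputs = model(next_token, past_key_values=past_key_values, use_cache=True, cache_position=cache_position)
    ```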

## Gemma2Config