
Commit a72f1c9

stancld, PhungVanDuy, sgugger, patrickvonplaten, and patil-suraj authored
Add LongT5 model (#16792)
* Initial commit
* Make some fixes
* Make PT model full forward pass
* Drop TF & Flax implementation, fix copies etc
* Add Flax model and update some corresponding stuff
* Drop some TF things
* Update config and flax local attn
* Add encoder_attention_type to config
* .
* Update docs
* Do some cleansing
* Fix some issues -> make style; add some docs
* Fix position_bias + mask addition + Update tests
* Fix repo consistency
* Fix model consistency by removing flax operation over attn_mask
* [WIP] Add PT TGlobal LongT5
* .
* [WIP] Add flax tglobal model
* [WIP] Update flax model to use the right attention type in the encoder
* Fix flax tglobal model forward pass
* Make the use of global_relative_attention_bias
* Add test suites for TGlobal model
* Fix minor bugs, clean code
* Fix pt-flax equivalence though not convinced with correctness
* Fix LocalAttn implementation to match the original impl. + update READMEs
* Few updates
* Update: [Flax] improve large model init and loading #16148
* Add ckpt conversion script accoring to #16853 + handle torch device placement
* Minor updates to conversion script.
* Typo: AutoModelForSeq2SeqLM -> FlaxAutoModelForSeq2SeqLM
* gpu support + dtype fix
* Apply some suggestions from code review
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Remove (de)parallelize stuff
* Edit shape comments
* Update README.md
* make fix-copies
* Remove caching logic for local & tglobal attention
* Apply another batch of suggestions from code review
* Add missing checkpoints
* Format converting scripts
* Drop (de)parallelize links from longT5 mdx
* Fix converting script + revert config file change
* Revert "Remove caching logic for local & tglobal attention"
  This reverts commit 2a61982.
* Stash caching logic in Flax model
* Make side relative bias used always
* Drop caching logic in PT model
* Return side bias as it was
* Drop all remaining model parallel logic
* Remove clamp statements
* Move test files to the proper place
* Update docs with new version of hf-doc-builder
* Fix test imports
* Make some minor improvements
* Add missing checkpoints to docs
* Make TGlobal model compatible with torch.onnx.export
* Replace some np.ndarray with jnp.ndarray
* Fix TGlobal for ONNX conversion + update docs
* fix _make_global_fixed_block_ids and masked neg value
* update flax model
* style and quality
* fix imports
* remove load_tf_weights_in_longt5 from init and fix copies
* add slow test for TGlobal model
* typo fix
* Drop obsolete is_parallelizable and one warning
* Update __init__ files to fix repo-consistency
* fix pipeline test
* Fix some device placements
* [wip]: Update tests -- need to generate summaries to update expected_summary
* Fix quality
* Update LongT5 model card
* Update (slow) summarization tests
* make style
* rename checkpoitns
* finish
* fix flax tests

Co-authored-by: phungvanduy <pvduy23@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: patil-suraj <surajp815@gmail.com>
1 parent 1690094 commit a72f1c9

30 files changed: +7389 −2 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -284,6 +284,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
 1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
 1. **[LeViT](https://huggingface.co/docs/transformers/main/model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze.
 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
+1. **[LongT5](https://huggingface.co/docs/transformers/main/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
 1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
 1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
 1. **[M-CTC-T](https://huggingface.co/docs/transformers/main/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.

README_ko.md

Lines changed: 1 addition & 0 deletions
@@ -265,6 +265,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
 1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
 1. **[LeViT](https://huggingface.co/docs/transformers/main/model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze.
 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
+1. **[LongT5](https://huggingface.co/docs/transformers/main/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
 1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
 1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
 1. **[M-CTC-T](https://huggingface.co/docs/transformers/main/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.

README_zh-hans.md

Lines changed: 1 addition & 0 deletions
@@ -289,6 +289,7 @@ conda install -c huggingface transformers
 1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (来自 AllenAI) 伴随论文 [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) 由 Iz Beltagy, Matthew E. Peters, Arman Cohan 发布。
 1. **[LeViT](https://huggingface.co/docs/transformers/main/model_doc/levit)** (来自 Meta AI) 伴随论文 [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) 由 Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze 发布。
 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (来自 AllenAI) 伴随论文 [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) 由 Iz Beltagy, Matthew E. Peters, Arman Cohan 发布。
+1. **[LongT5](https://huggingface.co/docs/transformers/main/model_doc/longt5)** (来自 Google AI) released 伴随论文 [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) 由 Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang 发布。
 1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (来自 Studio Ousia) 伴随论文 [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) 由 Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto 发布。
 1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (来自 UNC Chapel Hill) 伴随论文 [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) 由 Hao Tan and Mohit Bansal 发布。
 1. **[M-CTC-T](https://huggingface.co/docs/transformers/main/model_doc/mctct)** (来自 Facebook) 伴随论文 [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) 由 Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert 发布。

README_zh-hant.md

Lines changed: 1 addition & 0 deletions
@@ -301,6 +301,7 @@ conda install -c huggingface transformers
 1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
 1. **[LeViT](https://huggingface.co/docs/transformers/main/model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze.
 1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
+1. **[LongT5](https://huggingface.co/docs/transformers/main/model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
 1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
 1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
 1. **[M-CTC-T](https://huggingface.co/docs/transformers/main/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -256,6 +256,8 @@
       title: LeViT
     - local: model_doc/longformer
       title: Longformer
+    - local: model_doc/longt5
+      title: LongT5
     - local: model_doc/luke
       title: LUKE
     - local: model_doc/lxmert

docs/source/en/index.mdx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ The library currently contains JAX, PyTorch and TensorFlow implementations, pret
107107
1. **[LED](model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
108108
1. **[LeViT](model_doc/levit)** (from Meta AI) released with the paper [LeViT: A Vision Transformer in ConvNet's Clothing for Faster Inference](https://arxiv.org/abs/2104.01136) by Ben Graham, Alaaeldin El-Nouby, Hugo Touvron, Pierre Stock, Armand Joulin, Hervé Jégou, Matthijs Douze.
109109
1. **[Longformer](model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
110+
1. **[LongT5](model_doc/longt5)** (from Google AI) released with the paper [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung, Yinfei Yang.
110111
1. **[LUKE](model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
111112
1. **[LXMERT](model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
112113
1. **[M-CTC-T](model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.
@@ -233,6 +234,7 @@ Flax), PyTorch, and/or TensorFlow.
233234
| LED | | | | | |
234235
| LeViT | | | | | |
235236
| Longformer | | | | | |
237+
| LongT5 | | | | | |
236238
| LUKE | | | | | |
237239
| LXMERT | | | | | |
238240
| M-CTC-T | | | | | |

docs/source/en/model_doc/longt5.mdx

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# LongT5

## Overview

The LongT5 model was proposed in [LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916)
by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo Ni, Yun-Hsuan Sung and Yinfei Yang. It is an
encoder-decoder transformer pre-trained in a text-to-text denoising generative setting. LongT5 is an extension of the
T5 model, and it enables using one of two efficient attention mechanisms in the encoder: (1) Local attention, or (2)
Transient-Global attention.

The abstract from the paper is the following:

*Recent work has shown that either (1) increasing the input length or (2) increasing model size can improve the
performance of Transformer-based neural models. In this paper, we present a new model, called LongT5, with which we
explore the effects of scaling both the input length and model size at the same time. Specifically, we integrated
attention ideas from long-input transformers (ETC), and adopted pre-training strategies from summarization pre-training
(PEGASUS) into the scalable T5 architecture. The result is a new attention mechanism we call {\em Transient Global}
(TGlobal), which mimics ETC's local/global attention mechanism, but without requiring additional side-inputs. We are
able to achieve state-of-the-art results on several summarization tasks and outperform the original T5 models on
question answering tasks.*

Tips:

- [`LongT5ForConditionalGeneration`] is an extension of [`T5ForConditionalGeneration`] that replaces the traditional
encoder *self-attention* layer with either efficient *local* attention or *transient-global* (*tglobal*) attention.
- Unlike the T5 model, LongT5 does not use a task prefix. Furthermore, it uses a different pre-training objective
inspired by the pre-training of [`PegasusForConditionalGeneration`].
- The LongT5 model is designed to work efficiently on long-range *sequence-to-sequence* tasks where the
input sequence exceeds the commonly used 512 tokens. It can handle input sequences up to 16,384 tokens long.
- For *Local Attention*, the sparse sliding-window local attention operation allows a given token to attend only to the `r`
tokens to its left and right (with `r=127` by default). *Local Attention* does not introduce any new parameters
to the model. The complexity of the mechanism is linear in the input sequence length `l`: `O(l*r)`.
- *Transient Global Attention* is an extension of *Local Attention*. It additionally allows each input token to
interact with all other tokens in the layer. This is achieved by splitting the input sequence into blocks of a fixed
length `k` (with a default `k=16`). A global token for each block is then obtained by summing and normalizing the embeddings of every token
in the block. Thanks to this, each token can attend both to nearby tokens, as in Local attention, and
to every global token, as in standard global attention (*transient* refers to the fact that the global tokens
are constructed dynamically within each attention operation). As a consequence, *TGlobal* attention introduces
a few new parameters -- global relative position biases and a layer normalization for the global tokens' embeddings.
The complexity of this mechanism is `O(l(r + l/k))` (see the sketch after the evaluation example below).
- An example showing how to evaluate a fine-tuned LongT5 model on the [pubmed dataset](https://huggingface.co/datasets/scientific_papers) is below.

```python
>>> import evaluate
>>> from datasets import load_dataset
>>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration

>>> dataset = load_dataset("scientific_papers", "pubmed", split="validation")
>>> model = (
...     LongT5ForConditionalGeneration.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
...     .to("cuda")
...     .half()
... )
>>> tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")


>>> def generate_answers(batch):
...     inputs_dict = tokenizer(
...         batch["article"], max_length=16384, padding="max_length", truncation=True, return_tensors="pt"
...     )
...     input_ids = inputs_dict.input_ids.to("cuda")
...     attention_mask = inputs_dict.attention_mask.to("cuda")
...     output_ids = model.generate(input_ids, attention_mask=attention_mask, max_length=512, num_beams=2)
...     batch["predicted_abstract"] = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
...     return batch


>>> result = dataset.map(generate_answers, batched=True, batch_size=2)
>>> rouge = evaluate.load("rouge")
>>> rouge.compute(predictions=result["predicted_abstract"], references=result["abstract"])
```
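
To make the *Transient Global* block construction described in the tips above more concrete, the snippet below is a minimal, illustrative sketch (not the library's internal implementation) of how one global token per block of `k` tokens can be formed by summing and layer-normalizing the token embeddings. The function name and shapes are hypothetical.

```python
import torch
import torch.nn.functional as F


def make_global_tokens(hidden_states: torch.Tensor, block_len: int = 16) -> torch.Tensor:
    """Illustrative only: build one "global" token per block of `block_len` tokens
    by summing the token embeddings inside the block and layer-normalizing the sum."""
    batch_size, seq_len, hidden_dim = hidden_states.shape
    # Pad the sequence so it splits evenly into blocks of `block_len` tokens.
    pad = -seq_len % block_len
    if pad:
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad))
    num_blocks = hidden_states.shape[1] // block_len
    blocks = hidden_states.view(batch_size, num_blocks, block_len, hidden_dim)
    # One global token per block: sum over the block, then layer-normalize.
    global_tokens = F.layer_norm(blocks.sum(dim=2), (hidden_dim,))
    return global_tokens  # shape: (batch_size, num_blocks, hidden_dim)


# A 16,384-token sequence with k=16 yields 1,024 global tokens per example.
hidden = torch.randn(1, 16384, 512)
print(make_global_tokens(hidden, block_len=16).shape)  # torch.Size([1, 1024, 512])
```

In the actual model, attention over these global tokens additionally uses the global relative position biases mentioned above; the sketch only shows how the per-block global embeddings are derived.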
This model was contributed by [stancld](https://huggingface.co/stancld).
The original code can be found [here](https://github.com/google-research/longt5).
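
The attention variant is selected at configuration time. This commit adds an `encoder_attention_type` option to the config; the sketch below assumes `"local"` and `"transient-global"` as its values and `local_radius`/`global_block_size` as the names of the radius and block-size options, so treat the exact parameter names as assumptions rather than confirmed API.

```python
from transformers import LongT5Config, LongT5ForConditionalGeneration

# Sketch only: parameter names other than `encoder_attention_type` are assumptions.
local_config = LongT5Config(
    encoder_attention_type="local",  # sliding-window local attention in the encoder
    local_radius=127,                # assumed name for the window radius r
)

tglobal_config = LongT5Config(
    encoder_attention_type="transient-global",  # local attention + per-block global tokens
    local_radius=127,
    global_block_size=16,            # assumed name for the block length k
)

# Randomly initialized model using the transient-global encoder configuration.
model = LongT5ForConditionalGeneration(tglobal_config)
```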

## LongT5Config

[[autodoc]] LongT5Config

## LongT5Model

[[autodoc]] LongT5Model
    - forward

## LongT5ForConditionalGeneration

[[autodoc]] LongT5ForConditionalGeneration
    - forward

## LongT5EncoderModel

[[autodoc]] LongT5EncoderModel
    - forward

## FlaxLongT5Model

[[autodoc]] FlaxLongT5Model
    - __call__
    - encode
    - decode

## FlaxLongT5ForConditionalGeneration

[[autodoc]] FlaxLongT5ForConditionalGeneration
    - __call__
    - encode
    - decode

docs/source/en/serialization.mdx

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ Ready-made configurations include the following architectures:
 - GPT-J
 - I-BERT
 - LayoutLM
+- LongT5
 - M2M100
 - Marian
 - mBART
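
Since LongT5 is added to the ready-made ONNX configurations above (and the commit makes the TGlobal model compatible with `torch.onnx.export`), it should be exportable through the `transformers.onnx` package. The following is a hedged sketch based on that package's generic export API; the checkpoint name and the `seq2seq-lm` feature are assumptions, not taken from this commit.

```python
from pathlib import Path

from transformers import AutoTokenizer, LongT5ForConditionalGeneration
from transformers.onnx import FeaturesManager, export

# Assumed checkpoint name; replace with any LongT5 checkpoint you have access to.
checkpoint = "google/long-t5-local-base"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = LongT5ForConditionalGeneration.from_pretrained(checkpoint)

# Look up the ONNX config registered for this architecture and feature.
feature = "seq2seq-lm"  # assumed feature name for an encoder-decoder LM export
model_kind, onnx_config_cls = FeaturesManager.check_supported_model_or_raise(model, feature=feature)
onnx_config = onnx_config_cls(model.config)

# Export the model graph to an .onnx file.
onnx_inputs, onnx_outputs = export(
    preprocessor=tokenizer,
    model=model,
    config=onnx_config,
    opset=onnx_config.default_onnx_opset,
    output=Path("longt5.onnx"),
)
```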
