Commit 0c2c1f7

add token shift feature, which should greatly improve convergence. bump to 1.0
1 parent 8fb3fc5 commit 0c2c1f7

File tree

README.md
dalle_pytorch/dalle_pytorch.py
dalle_pytorch/transformer.py
generate.py
setup.py

5 files changed: +67 -12 lines changed

README.md
Lines changed: 3 additions & 2 deletions

@@ -147,8 +147,9 @@ images.shape # (4, 3, 256, 256)
 ```
 
 You may also want to generate text using DALL-E. For that call this function:
-```
-text_tokens, texts = dalle.generate_texts(text)
+
+```python
+text_tokens, texts = dalle.generate_texts(tokenizer, text)
 ```
 
 ## OpenAI's Pretrained VAE
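For context, a minimal end-to-end sketch of the updated call. The tokenizer import below is an assumption (the default tokenizer object that generate.py also passes in, exposed by dalle_pytorch.tokenizer), and dalle is a trained DALLE instance as in the README examples.

```python
# assumed import path for the default BPE tokenizer; swap in your own tokenizer object
from dalle_pytorch.tokenizer import tokenizer

# generate_texts now takes the tokenizer explicitly instead of importing it inside dalle_pytorch.py
text_tokens, texts = dalle.generate_texts(tokenizer, text = 'a running horse', filter_thres = 0.5)
```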

dalle_pytorch/dalle_pytorch.py
Lines changed: 7 additions & 6 deletions

@@ -7,7 +7,7 @@
 from axial_positional_embedding import AxialPositionalEmbedding
 from einops import rearrange
 
-from dalle_pytorch import distributed_utils, tokenizer
+from dalle_pytorch import distributed_utils
 from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
 from dalle_pytorch.transformer import Transformer, DivideMax
 
@@ -322,7 +322,8 @@ def __init__(
         sparse_attn = False,
         attn_types = None,
         loss_img_weight = 7,
-        stable = False
+        stable = False,
+        shift_tokens = True
     ):
         super().__init__()
         assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE'
@@ -367,7 +368,8 @@ def __init__(
             attn_types = attn_types,
             image_fmap_size = image_fmap_size,
             sparse_attn = sparse_attn,
-            stable = stable
+            stable = stable,
+            shift_tokens = shift_tokens
         )
 
         self.stable = stable
@@ -399,7 +401,8 @@ def __init__(
     @eval_decorator
     def generate_texts(
         self,
-        text=None,
+        tokenizer,
+        text = None,
         *,
         filter_thres = 0.5,
         temperature = 1.
@@ -577,5 +580,3 @@ def forward(
 
         loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1)
         return loss
-
-
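To show where the new flag surfaces at the model level, here is a minimal constructor sketch; every argument other than shift_tokens mirrors the existing README example, and shift_tokens defaults to True, so it only needs to be passed to turn the feature off.

```python
from dalle_pytorch import OpenAIDiscreteVAE, DALLE

vae = OpenAIDiscreteVAE()       # pretrained OpenAI discrete VAE

dalle = DALLE(
    dim = 1024,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 12,
    heads = 16,
    dim_head = 64,
    shift_tokens = True         # new in this commit, forwarded to the Transformer; set False to disable token shifting
)
```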

dalle_pytorch/transformer.py
Lines changed: 55 additions & 2 deletions

@@ -52,6 +52,8 @@ def __init__(self, dim, depth, fn):
     def forward(self, x, **kwargs):
         return self.fn(x, **kwargs) * self.scale
 
+# layer norm
+
 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
@@ -61,6 +63,8 @@ def __init__(self, dim, fn):
     def forward(self, x, **kwargs):
         return self.fn(self.norm(x), **kwargs)
 
+# feed forward
+
 class GEGLU(nn.Module):
     def forward(self, x):
         x, gates = x.chunk(2, dim = -1)
@@ -79,6 +83,49 @@ def __init__(self, dim, dropout = 0., mult = 4.):
     def forward(self, x):
         return self.net(x)
 
+# token shift classes
+
+class PreShiftToken(nn.Module):
+    def __init__(self, fn, image_size, seq_len):
+        super().__init__()
+        self.fn = fn
+        self.image_size = image_size
+        self.seq_len = seq_len
+
+    def forward(self, x, **kwargs):
+        n = x.shape[1]
+        seq_len, image_size = self.seq_len, self.image_size
+        img_seq_len = image_size ** 2
+        text_len = seq_len - img_seq_len + 1
+        padding = seq_len - n + 1
+
+        # get text and image tokens
+
+        x_text, x_img = x[:, :text_len], x[:, text_len:]
+        x_img = F.pad(x_img, (0, 0, 0, padding))
+        x_img = rearrange(x_img, 'b (h w) d -> b h w d', h = image_size)
+
+        # shift 1 from the left for text tokens
+
+        x_text_shift, x_text_pass = x_text.chunk(2, dim = -1)
+        x_text_shift = F.pad(x_text_shift, (0, 0, 1, -1))
+        x_text = torch.cat((x_text_shift, x_text_pass), dim = -1)
+
+        # shift from top, left for image tokens
+
+        x_img_shift_top, x_img_shift_left, *x_img_pass = x_img.chunk(4, dim = -1)
+        x_img_shift_left = F.pad(x_img_shift_left, (0, 0, 1, -1))
+        x_img_shift_top = F.pad(x_img_shift_top, (0, 0, 0, 0, 1, -1))
+        x_img = torch.cat((x_img_shift_top, x_img_shift_left, *x_img_pass), dim = -1)
+
+        # merge text and image sequence back together
+
+        x_img = rearrange(x_img, 'b h w d -> b (h w) d')
+        x = torch.cat((x_text, x_img[:, :-padding]), dim = 1)
+        return self.fn(x, **kwargs)
+
+# main transformer class
+
 class Transformer(nn.Module):
     def __init__(
         self,
@@ -96,7 +143,8 @@ def __init__(
         attn_types = None,
         image_fmap_size = None,
         sparse_attn = False,
-        stable = False
+        stable = False,
+        shift_tokens = True
     ):
         super().__init__()
         layers = nn.ModuleList([])
@@ -127,9 +175,14 @@ def __init__(
             else:
                 attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)
 
+            ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
+
+            if shift_tokens:
+                attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
+
             layers.append(nn.ModuleList([
                 LayerScale(dim, ind + 1, PreNorm(dim, attn)),
-                LayerScale(dim, ind + 1, PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout)))
+                LayerScale(dim, ind + 1, PreNorm(dim, ff))
             ]))
 
         execute_type = ReversibleSequence if reversible else SequentialSequence
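For intuition about the new PreShiftToken wrapper: the shift itself is just padding. F.pad with a (1, -1) pad on the sequence (or row/column) dimension prepends a zero position and drops the last one, so every position receives a slice of its causal neighbor's features before the wrapped attention or feed-forward layer runs. A tiny illustration of the 1-d case used for the text tokens (illustrative only, not part of the commit):

```python
import torch
import torch.nn.functional as F

# one batch, four tokens, one feature channel
x = torch.arange(1., 5.).view(1, 4, 1)   # sequence [1, 2, 3, 4]

# prepend a zero position along the sequence dimension and drop the last position
shifted = F.pad(x, (0, 0, 1, -1))

print(shifted.view(-1))                   # tensor([0., 1., 2., 3.])
```

In PreShiftToken only half of the text tokens' feature channels are shifted this way (the chunk(2, dim = -1) split), while the image tokens are reshaped to the 2-d feature map and a quarter of their channels each is shifted from the row above and from the column to the left; the remaining channels pass through unchanged.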

generate.py
Lines changed: 1 addition & 1 deletion

@@ -103,7 +103,7 @@ def exists(val):
 
 for j, text in tqdm(enumerate(texts)):
     if args.gentxt:
-        text_tokens, gen_texts = dalle.generate_texts(text=text, filter_thres = args.top_k)
+        text_tokens, gen_texts = dalle.generate_texts(tokenizer, text=text, filter_thres = args.top_k)
         text = gen_texts[0]
     else:
         text_tokens = tokenizer.tokenize([text], dalle.text_seq_len).cuda()

setup.py
Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '0.14.3',
+  version = '1.0.0',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
