Commit 5a255ea

stable_softmax, wandb_entity, visible discord, replace buggy colab (#320)
* update vae.py for new f8 gumbel vqgan
* Download and cache gumbel vae per flag.
* Download correct 8 bit gumbel url.
* tiny fix to backend code.
* add correct urls for gumbel vqgan
* rearrange codebook indices if using gumbel.
* Fix gumbel decode() as well
* Fix decode for GumbelVQ
* add `--stable_softmax` for fp16/amp training
* PyTorch LTS CUDA 10.2: build all DeepSpeed ops.
* `--wandb_entity` arg
* Feature Discord, replace buggy notebook, quick start link
* Fix discord widget
* Revert header to original
* ditch `latest best` idea, rearrange header. The "best currently trained model" idea was good, but there's clearly no way we can keep the README up to date on something like that.
* formatting, add links, add my latest checkpoint. Added @rom1504's awesome DALL-E pseudo-serverless web frontend/backend provider; added generations from my most recent OpenAI blog checkpoint; added mega b's colab notebook for running inference on that checkpoint; fixed some bolding and other formatting issues per @rom1504's suggestion; decreased image width on a few images for the sake of scrollability.
* Remove unnecessary download/setup.

Co-authored-by: Sam Sepiol <>
1 parent: 7eb2e34

File tree: 4 files changed, +50 −24 lines

README.md

Lines changed: 36 additions & 17 deletions
@@ -1,39 +1,58 @@
-<img src="./images/birds.png" width="500px"></img>
+# DALL-E in Pytorch
 
-** current best, trained by <a href="https://github.com/kobiso">Kobiso</a> **
+<p align='center'>
+<a href="https://colab.research.google.com/gist/afiaka87/b29213684a1dd633df20cab49d05209d/train_dalle_pytorch.ipynb">
+<img alt="Train DALL-E w/ DeepSpeed" src="https://colab.research.google.com/assets/colab-badge.svg">
+</a>
+<a href="https://discord.gg/dall-e"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a></br>
+<a href="https://github.com/robvanvolt/DALLE-models">Released DALLE Models</a></br>
+<a href="https://github.com/rom1504/dalle-service">Web-Hostable DALLE Checkpoints</a></br>
 
-## DALL-E in Pytorch
+<a href="https://www.youtube.com/watch?v=j4xgkjWlfL4">Yannic Kilcher's video</a>
+<p>
+Implementation / replication of <a href="https://openai.com/blog/dall-e/">DALL-E</a> (<a href="https://arxiv.org/abs/2102.12092">paper</a>), OpenAI's Text to Image Transformer, in Pytorch. It will also contain <a href="https://openai.com/blog/clip/">CLIP</a> for ranking the generations.
 
-Implementation / replication of <a href="https://openai.com/blog/dall-e/">DALL-E</a> (<a href="https://arxiv.org/abs/2102.12092">paper</a>), OpenAI's Text to Image Transformer, in Pytorch. It will also contain <a href="https://openai.com/blog/clip/">CLIP</a> for ranking the generations.
+---
 
-<a href="https://github.com/sdtblck">Sid</a>, <a href="http://github.com/kingoflolz">Ben</a>, and <a href="https://github.com/AranKomat">Aran</a> over at <a href="https://www.eleuther.ai/">Eleuther AI</a> are working on <a href="https://github.com/EleutherAI/DALLE-mtf">DALL-E for Mesh Tensorflow</a>! Please lend them a hand if you would like to see DALL-E trained on TPUs.
 
-<a href="https://www.youtube.com/watch?v=j4xgkjWlfL4">Yannic Kilcher's video</a>
 
-Before we replicate this, we can settle for <a href="https://github.com/lucidrains/deep-daze">Deep Daze</a> or <a href="https://github.com/lucidrains/big-sleep">Big Sleep</a>
+[Quick Start](https://github.com/lucidrains/DALLE-pytorch/wiki)
 
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1dWvA54k4fH8zAmiix3VXbg95uEIMfqQM?usp=sharing) Train in Colab
+<a href="https://github.com/lucidrains/deep-daze">Deep Daze</a> or <a href="https://github.com/lucidrains/big-sleep">Big Sleep</a> are great alternatives!
 
 ## Status
+<p align='center'>
 
 - <a href="https://github.com/htoyryla">Hannu</a> has managed to train a small 6 layer DALL-E on a dataset of just 2000 landscape images! (2048 visual tokens)
 
 <img src="./images/landscape.png"></img>
 
 - <a href="https://github.com/kobiso">Kobiso</a>, a research engineer from Naver, has trained on the CUB200 dataset <a href="https://github.com/lucidrains/DALLE-pytorch/discussions/131">here</a>, using full and deepspeed sparse attention
-- <a href="https://github.com/afiaka87">afiaka87</a> has managed one epoch using a 32 layer reversible DALL-E <a href="https://github.com/lucidrains/DALLE-pytorch/issues/86#issue-832121328">here</a>
-- <a href="https://github.com/robvanvolt">robvanvolt</a> has started a <a href="https://discord.gg/UhR4kKCSp6">Discord channel</a> for replication efforts
-
-- <a href="https://github.com/robvanvolt">TheodoreGalanos</a> has trained on 150k layouts with the following results
 
-<img src="./images/layouts-1.jpg" width="400px"></img>
+<img src="./images/birds.png" width="256"></img>
 
-<img src="./images/layouts-2.jpg" width="400px"></img>
+- (3/15/21) <a href="https://github.com/afiaka87">afiaka87</a> has managed one epoch using a reversible DALL-E and the dVaE <a href="https://github.com/lucidrains/DALLE-pytorch/issues/86#issue-832121328">here</a>
 
+- <a href="https://github.com/robvanvolt">TheodoreGalanos</a> has trained on 150k layouts with the following results
+<p>
+<img src="./images/layouts-1.jpg" width="256"></img>
+<img src="./images/layouts-2.jpg" width="256"></img>
+</p>
 - <a href="https://github.com/rom1504">Rom1504</a> has trained on 50k fashion images with captions with a really small DALL-E (2 layers) for just 24 hours with the following results
-
-<img src="./images/clothing.png" width="500px"></img>
-
+<p/>
+<img src="./images/clothing.png" width="420"></img>
+
+- <a href="https://github.com/afiaka87">afiaka87</a> trained for 6 epochs on the same dataset as before thanks to the efficient 16k VQGAN with the following <a href="https://github.com/lucidrains/DALLE-pytorch/discussions/322">results</a>
+
+<p align='centered'>
+<img src="https://user-images.githubusercontent.com/3994972/123564891-b6f18780-d780-11eb-9019-8a1b6178f861.png" width="420" alt-text='a photo of westwood park, san francisco, from the water in the afternoon'></img>
+<img src="https://user-images.githubusercontent.com/3994972/123564776-4c404c00-d780-11eb-9c8e-3356df358df3.png" width="420" alt-text='a female mannequin dressed in an olive button-down shirt and gold palazzo pants'> </img>
+</p>
+
+Thanks to the amazing "mega b#6696" you can generate from this checkpoint in colab -
+<a href="https://colab.research.google.com/drive/11V2xw1eLPfZvzW8UQyTUhqCEU71w6Pr4?usp=sharing">
+<img alt="Run inference on the Afiaka checkpoint in Colab" src="https://colab.research.google.com/assets/colab-badge.svg">
+</a>
 ## Install
 
 ```bash

dalle_pytorch/distributed_backends/distributed_backend.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def check_batch_size(self, batch_size):
             (f"batch size can't be smaller than number of processes "
              f'({batch_size} < {self.get_world_size()})')
 
-    def wrap_arg_parser(parser):
+    def wrap_arg_parser(self, parser):
         """Add arguments to support optional distributed backend usage."""
         raise NotImplementedError
 
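The fix above adds a missing `self` to an instance method. Without it, Python binds the instance itself to the method's only parameter, so passing an actual parser raises a `TypeError`. A minimal sketch of the failure mode (the class body here is illustrative, not the repo's full code):

```python
import argparse

class DistributedBackend:
    """Illustrative stand-in for the class patched above."""
    def wrap_arg_parser(parser):  # bug: missing `self`
        raise NotImplementedError

backend = DistributedBackend()
try:
    backend.wrap_arg_parser(argparse.ArgumentParser())
except TypeError as err:
    # The instance is passed implicitly as the first positional argument,
    # so the explicit parser becomes an unexpected second argument.
    print(err)  # wrap_arg_parser() takes 1 positional argument but 2 were given
```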

docker/Dockerfile

Lines changed: 3 additions & 5 deletions
@@ -1,15 +1,13 @@
 
-ARG IMG_TAG=1.6.0-cuda10.1-cudnn7-devel
+ARG IMG_TAG=1.8.1-cuda10.2-cudnn7-devel
 ARG IMG_REPO=pytorch
-ARG BRANCH=main
-ARG REMOTE=lucidrains
 
 FROM pytorch/$IMG_REPO:$IMG_TAG
 
-RUN apt-get -y update && apt-get -y install git gcc llvm-9-dev cmake libaio-dev
+RUN apt-get -y update && apt-get -y install git gcc llvm-9-dev cmake libaio-dev vim wget
 
 RUN git clone https://github.com/microsoft/DeepSpeed.git /tmp/DeepSpeed
-RUN cd /tmp/DeepSpeed && DS_BUILD_SPARSE_ATTN=1 ./install.sh -r
+RUN cd /tmp/DeepSpeed && DS_BUILD_OPS=1 ./install.sh -r
 RUN pip install git+https://github.com/lucidrains/DALLE-pytorch.git
 
 WORKDIR dalle
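For context, `DS_BUILD_OPS=1` tells DeepSpeed's installer to pre-compile all of its C++/CUDA extensions at image build time, instead of only the sparse-attention op (the old `DS_BUILD_SPARSE_ATTN=1`) and JIT-compiling the rest on first use. A quick sanity check inside the container is DeepSpeed's bundled `ds_report` tool; a minimal sketch, assuming the image built successfully:

```python
# Invoke DeepSpeed's own environment report (a CLI installed alongside the
# `deepspeed` package) to list which ops were pre-built in the image.
import subprocess

subprocess.run(["ds_report"], check=True)
```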

train_dalle.py

Lines changed: 10 additions & 1 deletion
@@ -80,6 +80,12 @@
 parser.add_argument('--wandb_name', default='dalle_train_transformer',
                     help='Name W&B will use when saving results.\ne.g. `--wandb_name "coco2017-full-sparse"`')
 
+parser.add_argument('--wandb_entity', default=None,
+                    help='(optional) Name of W&B team/entity to log to.')
+
+parser.add_argument('--stable_softmax', dest='stable_softmax', action='store_true',
+                    help='Prevent values from becoming too large during softmax. Helps with stability in fp16 and Mixture of Quantization training.')
+
 parser = distributed_utils.wrap_arg_parser(parser)
 
 train_group = parser.add_argument_group('Training settings')
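Of the two new flags, `--stable_softmax` is the substantive one for fp16/amp runs. The standard trick behind a numerically stable softmax is to subtract the row-wise max before exponentiating: softmax is shift-invariant, so the result is unchanged, but the largest exponent becomes exp(0) = 1 and can no longer overflow. A minimal sketch of the technique (this shows the general idea, not necessarily the exact code behind the model's `stable=` option):

```python
import torch

def stable_softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # softmax(t) == softmax(t - c) for any constant c, so shifting by the
    # row-wise max is free; .detach() keeps the shift out of autograd.
    t = t - t.amax(dim=dim, keepdim=True).detach()
    return t.softmax(dim=dim)

logits = torch.tensor([[1000.0, 999.0, 0.0]])
print(torch.exp(logits))       # tensor([[inf, inf, 1.]]) -- naive exp overflows
print(stable_softmax(logits))  # tensor([[0.7311, 0.2689, 0.0000]])
```

The headroom is much smaller in fp16, whose maximum value is about 65504 (exp overflows past roughly 11), which is why the flag matters most under amp.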
@@ -176,6 +182,7 @@ def cp_path_to_dir(cp_path, tag):
 LOSS_IMG_WEIGHT = args.loss_img_weight
 FF_DROPOUT = args.ff_dropout
 ATTN_DROPOUT = args.attn_dropout
+STABLE = args.stable_softmax
 
 ATTN_TYPES = tuple(args.attn_types.split(','))
 
@@ -287,6 +294,7 @@ def cp_path_to_dir(cp_path, tag):
     attn_types=ATTN_TYPES,
     ff_dropout=FF_DROPOUT,
     attn_dropout=ATTN_DROPOUT,
+    stable=STABLE,
 )
 resume_epoch = 0
 
@@ -434,7 +442,8 @@ def tokenize(s):
 )
 
 run = wandb.init(
-    project=args.wandb_name,  # 'dalle_train_transformer' by default
+    project=args.wandb_name,
+    entity=args.wandb_entity,
     resume=False,
     config=model_config,
 )
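`entity` is a standard `wandb.init` argument: it routes the run to a named W&B team/organization, and `None` (the new flag's default) keeps the previous behavior of logging to the caller's personal account. A usage sketch with a hypothetical team name:

```python
import wandb

run = wandb.init(
    project="dalle_train_transformer",  # the --wandb_name default
    entity="my-team",                   # hypothetical; use None for a personal account
    resume=False,
    config={"depth": 2, "heads": 8},    # stands in for the real model_config
)
run.finish()
```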
