Skip to content

Commit 0013fb0

Browse files
committed
[WEB] Add shark-web logging
1. This commit adds support for displaying logs in shark-web. 2. It also adds the Nod logo to the home page. 3. Stable-diffusion outputs are now saved to disk. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
1 parent 56f8a0d commit 0013fb0

File tree

9 files changed

+135
-35
lines changed

9 files changed

+135
-35
lines changed

web/Nod_logo.jpg

40.6 KB
Loading

web/index.py

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,46 +4,81 @@
44

55
# from models.diffusion.v_diffusion import vdiff_inf
66
import gradio as gr
7+
from PIL import Image
8+
9+
10+
def debug_event(debug):
    """Build a textbox update whose visibility follows the DEBUG checkbox.

    This is registered as the ``debug.change`` handler on each tab: the
    checkbox value arrives as ``debug`` and the returned update is applied
    to the ``std_output`` textbox listed in ``outputs``.
    """
    textbox_update = gr.Textbox.update(visible=debug)
    return textbox_update
12+
713

814
with gr.Blocks() as shark_web:
9-
gr.Markdown("Shark Models Demo.")
10-
with gr.Tabs():
1115

16+
with gr.Row():
17+
with gr.Group():
18+
with gr.Column(scale=1):
19+
img = Image.open("./Nod_logo.jpg")
20+
gr.Image(value=img, show_label=False, interactive=False).style(
21+
height=70, width=70
22+
)
23+
with gr.Column(scale=9):
24+
gr.Label(value="Shark Models Demo.")
25+
26+
with gr.Tabs():
1227
with gr.TabItem("ResNet50"):
13-
image = device = resnet = output = None
28+
image = device = debug = resnet = output = std_output = None
1429
with gr.Row():
1530
with gr.Column(scale=1, min_width=600):
1631
image = gr.Image(label="Image")
1732
device = gr.Textbox(label="Device", value="cpu")
33+
debug = gr.Checkbox(label="DEBUG", value=False)
1834
resnet = gr.Button("Recognize Image").style(
1935
full_width=True
2036
)
2137
with gr.Column(scale=1, min_width=600):
2238
output = gr.Label(label="Output")
2339
std_output = gr.Textbox(
24-
label="Std Output", value="Nothing."
40+
label="Std Output",
41+
value="Nothing to show.",
42+
visible=False,
2543
)
44+
debug.change(
45+
debug_event,
46+
inputs=[debug],
47+
outputs=[std_output],
48+
show_progress=False,
49+
)
2650
resnet.click(
2751
resnet_inf,
2852
inputs=[image, device],
2953
outputs=[output, std_output],
3054
)
3155

3256
with gr.TabItem("Albert MaskFill"):
33-
masked_text = device = albert_mask = decoded_res = None
57+
masked_text = (
58+
device
59+
) = debug = albert_mask = decoded_res = std_output = None
3460
with gr.Row():
3561
with gr.Column(scale=1, min_width=600):
3662
masked_text = gr.Textbox(
3763
label="Masked Text",
3864
placeholder="Give me a sentence with [MASK] to fill",
3965
)
4066
device = gr.Textbox(label="Device", value="cpu")
67+
debug = gr.Checkbox(label="DEBUG", value=False)
4168
albert_mask = gr.Button("Decode Mask")
4269
with gr.Column(scale=1, min_width=600):
4370
decoded_res = gr.Label(label="Decoded Results")
4471
std_output = gr.Textbox(
45-
label="Std Output", value="Nothing."
72+
label="Std Output",
73+
value="Nothing to show.",
74+
visible=False,
4675
)
76+
debug.change(
77+
debug_event,
78+
inputs=[debug],
79+
outputs=[std_output],
80+
show_progress=False,
81+
)
4782
albert_mask.click(
4883
albert_maskfill_inf,
4984
inputs=[masked_text, device],
@@ -74,28 +109,33 @@
74109
with gr.TabItem("Stable-Diffusion"):
75110
prompt = (
76111
iters
77-
) = mlir_loc = device = stable_diffusion = generated_img = None
112+
) = (
113+
device
114+
) = debug = stable_diffusion = generated_img = std_output = None
78115
with gr.Row():
79116
with gr.Column(scale=1, min_width=600):
80117
prompt = gr.Textbox(
81118
label="Prompt",
82119
value="a photograph of an astronaut riding a horse",
83120
)
84121
iters = gr.Number(label="Steps", value=2)
85-
mlir_loc = gr.Textbox(
86-
label="Location of MLIR(Relative to SHARK/web/)",
87-
value="./stable_diffusion.mlir",
88-
)
89122
device = gr.Textbox(label="Device", value="vulkan")
123+
debug = gr.Checkbox(label="DEBUG", value=False)
90124
stable_diffusion = gr.Button("Generate image from prompt")
91125
with gr.Column(scale=1, min_width=600):
92126
generated_img = gr.Image(type="pil", shape=(100, 100))
93127
std_output = gr.Textbox(
94-
label="Std Output", value="Nothing."
128+
label="Std Output", value="Nothing.", visible=False
95129
)
130+
debug.change(
131+
debug_event,
132+
inputs=[debug],
133+
outputs=[std_output],
134+
show_progress=False,
135+
)
96136
stable_diffusion.click(
97137
stable_diff_inf,
98-
inputs=[prompt, iters, mlir_loc, device],
138+
inputs=[prompt, iters, device],
99139
outputs=[generated_img, std_output],
100140
)
101141

web/logs/albert_maskfill_log.txt

Whitespace-only changes.

web/logs/resnet50_log.txt

Whitespace-only changes.

web/logs/stable_diffusion_log.txt

Whitespace-only changes.

web/models/albert_maskfill.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def forward(self, input_ids, attention_mask):
2121

2222
################################## Preprocessing inputs ####################
2323

24+
DEBUG = False
2425
compiled_module = {}
2526
compiled_module["tokenizer"] = AutoTokenizer.from_pretrained("albert-base-v2")
2627

@@ -42,10 +43,13 @@ def preprocess_data(text):
4243
return inputs
4344

4445

45-
def top5_possibilities(text, inputs, token_logits):
46+
def top5_possibilities(text, inputs, token_logits, log_write):
4647

48+
global DEBUG
4749
global compiled_module
4850

51+
if DEBUG:
52+
log_write.write("Retrieving top 5 possible outcomes.\n")
4953
tokenizer = compiled_module["tokenizer"]
5054
mask_id = torch.where(inputs[0] == tokenizer.mask_token_id)[1]
5155
mask_token_logits = token_logits[0, mask_id, :]
@@ -55,6 +59,8 @@ def top5_possibilities(text, inputs, token_logits):
5559
for token in top_5_tokens:
5660
label = text.replace(tokenizer.mask_token, tokenizer.decode(token))
5761
top5[label] = percentage[token].item()
62+
if DEBUG:
63+
log_write.write("Done.\n")
5864
return top5
5965

6066

@@ -63,10 +69,18 @@ def top5_possibilities(text, inputs, token_logits):
6369

6470
def albert_maskfill_inf(masked_text, device):
6571

72+
global DEBUG
6673
global compiled_module
6774

75+
DEBUG = False
76+
log_write = open(r"logs/albert_maskfill_log.txt", "w")
77+
if log_write:
78+
DEBUG = True
79+
6880
inputs = preprocess_data(masked_text)
6981
if device not in compiled_module.keys():
82+
if DEBUG:
83+
log_write.write("Compiling the Albert Maskfill module.\n")
7084
mlir_importer = SharkImporter(
7185
AlbertModule(),
7286
inputs,
@@ -80,6 +94,15 @@ def albert_maskfill_inf(masked_text, device):
8094
)
8195
shark_module.compile()
8296
compiled_module[device] = shark_module
97+
if DEBUG:
98+
log_write.write("Compilation successful.\n")
8399

84100
token_logits = torch.tensor(compiled_module[device].forward(inputs))
85-
return top5_possibilities(masked_text, inputs, token_logits), "Testing.."
101+
output = top5_possibilities(masked_text, inputs, token_logits, log_write)
102+
log_write.close()
103+
104+
std_output = ""
105+
with open(r"logs/albert_maskfill_log.txt", "r") as log_read:
106+
std_output = log_read.read()
107+
108+
return output, std_output

web/models/resnet50.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
################################## Preprocessing inputs and helper functions ########
99

10+
DEBUG = False
11+
compiled_module = {}
12+
1013

1114
def preprocess_image(img):
1215
image = Image.fromarray(img)
@@ -33,43 +36,57 @@ def load_labels():
3336
return labels
3437

3538

36-
def top3_possibilities(res):
39+
def top3_possibilities(res, log_write):
    """Return the three highest-scoring labels with their softmax probabilities.

    Args:
        res: Torch tensor of raw scores. It is treated as 2-D (softmax is
            taken over ``dim=1``) and only row 0 is reported.
        log_write: Writable text handle. Progress lines are written to it
            when the module-level ``DEBUG`` flag is set.

    Returns:
        A dict mapping each label (looked up in ``load_labels()``) to its
        probability as a Python float, inserted in descending score order.
    """
    # DEBUG is only read here, so no ``global`` declaration is needed.
    if DEBUG:
        log_write.write("Retrieving top 3 possible outcomes.\n")

    labels = load_labels()
    _, indexes = torch.sort(res, descending=True)
    percentage = torch.nn.functional.softmax(res, dim=1)[0]

    # A dict comprehension avoids building a throwaway list of tuples.
    top3 = {labels[idx]: percentage[idx].item() for idx in indexes[0][:3]}

    if DEBUG:
        log_write.write("Done.\n")
    return top3
4454

4555

4656
##############################################################################
4757

48-
compiled_module = {}
49-
5058

5159
def resnet_inf(numpy_img, device):
5260

61+
global DEBUG
5362
global compiled_module
5463

55-
std_output = ""
64+
DEBUG = False
65+
log_write = open(r"logs/resnet50_log.txt", "w")
66+
if log_write:
67+
DEBUG = True
5668

5769
if device not in compiled_module.keys():
58-
std_output += "Compiling the Resnet50 module.\n"
70+
if DEBUG:
71+
log_write.write("Compiling the Resnet50 module.\n")
5972
mlir_model, func_name, inputs, golden_out = download_torch_model(
6073
"resnet50"
6174
)
62-
6375
shark_module = SharkInference(
6476
mlir_model, func_name, device=device, mlir_dialect="linalg"
6577
)
6678
shark_module.compile()
67-
std_output += "Compilation successful.\n"
6879
compiled_module[device] = shark_module
80+
if DEBUG:
81+
log_write.write("Compilation successful.\n")
6982

7083
img = preprocess_image(numpy_img)
7184
result = compiled_module[device].forward((img.detach().numpy(),))
85+
output = top3_possibilities(torch.from_numpy(result), log_write)
86+
log_write.close()
87+
88+
std_output = ""
89+
with open(r"logs/resnet50_log.txt", "r") as log_read:
90+
std_output = log_read.read()
7291

73-
# print("The top 3 results obtained via shark_runner is:")
74-
std_output += "Retrieving top 3 possible outcomes.\n"
75-
return top3_possibilities(torch.from_numpy(result)), std_output
92+
return output, std_output

web/models/stable_diffusion.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,14 @@
1010
import torch_mlir
1111
import tempfile
1212
import numpy as np
13+
import os
1314

1415
##############################################################################
1516

1617

1718
def load_mlir(mlir_loc):
    """Read an MLIR module from disk and return its text.

    Args:
        mlir_loc: Path of the MLIR file, or ``None`` when no file is given.

    Returns:
        The file contents as a string, or ``None`` if ``mlir_loc`` is ``None``.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Identity comparison is the correct test for the None singleton.
    if mlir_loc is None:
        return None

    # os.path.join with a single argument returns the path unchanged,
    # so the path is opened directly.
    with open(mlir_loc) as f:
        return f.read()
@@ -85,21 +83,30 @@ def strip_overloads(gm):
8583

8684
##############################################################################
8785

86+
DEBUG = False
8887
compiled_module = {}
8988

9089

91-
def stable_diff_inf(prompt: str, steps, mlir_loc: str, device: str):
90+
def stable_diff_inf(prompt: str, steps, device: str):
9291

9392
args = {}
9493
args["prompt"] = [prompt]
9594
args["steps"] = steps
9695
args["device"] = device
97-
args["mlir_loc"] = mlir_loc
96+
args["mlir_loc"] = "./stable_diffusion.mlir"
97+
output_loc = (
98+
f"stored_results/stable_diffusion/{prompt}_{int(steps)}_{device}.jpg"
99+
)
98100

101+
global DEBUG
99102
global compiled_module
100103

101-
if args["device"] not in compiled_module.keys():
104+
DEBUG = False
105+
log_write = open(r"logs/stable_diffusion_log.txt", "w")
106+
if log_write:
107+
DEBUG = True
102108

109+
if args["device"] not in compiled_module.keys():
103110
YOUR_TOKEN = "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"
104111

105112
# 1. Load the autoencoder model which will be used to decode the latents into image space.
@@ -116,6 +123,8 @@ def stable_diff_inf(prompt: str, steps, mlir_loc: str, device: str):
116123
compiled_module["text_encoder"] = CLIPTextModel.from_pretrained(
117124
"openai/clip-vit-large-patch14"
118125
)
126+
if DEBUG:
127+
log_write.write("Compiling the Unet module.\n")
119128

120129
# Wrap the unet model to return tuples.
121130
class UnetModel(torch.nn.Module):
@@ -143,14 +152,16 @@ def forward(self, x, y, z):
143152
args["mlir_loc"],
144153
)
145154
compiled_module[args["device"]] = shark_unet
155+
if DEBUG:
156+
log_write.write("Compilation successful.\n")
146157

158+
compiled_module["unet"] = unet
147159
compiled_module["scheduler"] = LMSDiscreteScheduler(
148160
beta_start=0.00085,
149161
beta_end=0.012,
150162
beta_schedule="scaled_linear",
151163
num_train_timesteps=1000,
152164
)
153-
compiled_module["unet"] = unet
154165

155166
shark_unet = compiled_module[args["device"]]
156167
vae = compiled_module["vae"]
@@ -202,7 +213,8 @@ def forward(self, x, y, z):
202213

203214
for i, t in tqdm(enumerate(scheduler.timesteps)):
204215

205-
print(f"i = {i} t = {t}")
216+
if DEBUG:
217+
log_write.write(f"i = {i} t = {t}\n")
206218
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
207219
latent_model_input = torch.cat([latents] * 2)
208220
sigma = scheduler.sigmas[i]
@@ -232,11 +244,19 @@ def forward(self, x, y, z):
232244

233245
# scale and decode the image latents with vae
234246
latents = 1 / 0.18215 * latents
235-
print(latents.shape)
236247
image = vae.decode(latents).sample
237248

238249
image = (image / 2 + 0.5).clamp(0, 1)
239250
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
240251
images = (image * 255).round().astype("uint8")
241252
pil_images = [Image.fromarray(image) for image in images]
242-
return pil_images[0], "Testing.."
253+
output = pil_images[0]
254+
# save the output image with the prompt name.
255+
output.save(os.path.join(output_loc))
256+
log_write.close()
257+
258+
std_output = ""
259+
with open(r"logs/stable_diffusion_log.txt", "r") as log_read:
260+
std_output = log_read.read()
261+
262+
return output, std_output

web/stored_results/stable_diffusion/empty.jpg

Loading

0 commit comments

Comments
 (0)