Skip to content

Commit 7f6de3f

Browse files
committed
extend burgers_DeepONet example
1 parent 206f68d commit 7f6de3f

File tree

1 file changed

+2
-10
lines changed

1 file changed

+2
-10
lines changed

examples/burgers_DeepONet.jl

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ grid = collect(range(0, 1, length=1024))' |> device
3939
# Create the DeepONet:
4040
# IC is given on grid of 1024 points, and we solve for a fixed time t in one
4141
# spatial dimension x, making the branch input of size 1024 and trunk size 1
42-
model = DeepONet((1024,1024,1024),(1,1024,1024))
42+
# We choose GeLU activation for both subnets
43+
model = DeepONet((1024,1024,1024),(1,1024,1024),gelu,gelu) |> device
4344

4445
# We use the ADAM optimizer for training
4546
learning_rate = 0.001
@@ -60,12 +61,3 @@ throttled_cb = throttle(evalcb, 5)
6061

6162
# Do the training loop
6263
Flux.@epochs 500 train!(loss, parameters, [(xtrain,ytrain,grid)], opt, cb = evalcb)
63-
64-
# Accuracy metrics
65-
val_loader = Flux.Data.DataLoader((xtest, ytest), batchsize=1, shuffle=false) |> device
66-
loss = 0.0 |> device
67-
68-
for (x,y) in val_loader
69-
= model(x)
70-
loss += Flux.Losses.mse(ŷ,y)
71-
end

0 commit comments

Comments
 (0)