Skip to content

Commit 206f68d

Browse files
committed
add DeepONet example
1 parent 5678361 commit 206f68d

File tree

3 files changed

+79
-6
lines changed

3 files changed

+79
-6
lines changed

examples/burgers_DeepONet.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
using Flux: length, reshape, train!, throttle, @epochs
using OperatorLearning, Flux, MAT

device = cpu;

#=
Train a DeepONet that infers the solution u(x, 1) of the Burgers equation
on a grid of 1024 points, given the initial condition a(x) = u(x, 0).
=#

# Read the data from the MAT file and store it in a dict:
#   key "a" is the initial condition,
#   key "u" is the desired solution at time 1.
vars = matread("burgers_data_R10.mat") |> device

# For trial purposes, we might want to train with different resolutions,
# so we sample only every n-th element.
subsample = 2^3;

# Branch-net inputs: ICs sampled on the grid, shape [grid points x samples],
# hence the transpose. First 1000 samples train, last 100 test.
xtrain = vars["a"][1:1000, 1:subsample:end]' |> device;
xtest = vars["a"][end-99:end, 1:subsample:end]' |> device;

# Targets: the solution at time 1 for the same samples.
ytrain = vars["u"][1:1000, 1:subsample:end] |> device;
ytest = vars["u"][end-99:end, 1:subsample:end] |> device;

# The dataset carries no grid coordinates, so we create them.
# `collect` converts the `range` into an array; the trunk net expects a
# row vector of sensor locations.
grid = collect(range(0, 1, length=1024))' |> device

# Create the DeepONet:
# The IC is given on a grid of 1024 points, and we solve for a fixed time t
# in one spatial dimension x, making the branch input of size 1024 and the
# trunk input of size 1.
model = DeepONet((1024, 1024, 1024), (1, 1024, 1024))

# We use the ADAM optimizer for training.
learning_rate = 0.001
opt = ADAM(learning_rate)

# Specify the model parameters to be trained.
parameters = params(model)

# The loss function.
# We can't use the "vanilla" mse closure here since the DeepONet takes two
# distinct inputs (function samples and sensor locations), so the loss takes
# the sensor grid as a third argument.
loss(x, y, sensor) = Flux.Losses.mse(model(x, sensor), y)

# Callback that reports the test loss during training ...
evalcb() = @show(loss(xtest, ytest, grid))
# ... throttled so it prints at most once every 5 seconds.
throttled_cb = throttle(evalcb, 5)

# Do the training loop.
# FIX: pass the throttled callback — `throttled_cb` was defined but the
# unthrottled `evalcb` was passed, spamming output every iteration (the
# sibling burgers_FNO.jl example was fixed the same way in this commit).
Flux.@epochs 500 train!(loss, parameters, [(xtrain, ytrain, grid)], opt, cb = throttled_cb)

# Accuracy metrics on the held-out samples.
val_loader = Flux.Data.DataLoader((xtest, ytest), batchsize=1, shuffle=false) |> device
# FIX: use a fresh name for the accumulator — rebinding `loss` would clash
# with the loss *function* defined above (invalid redefinition of a constant).
val_loss = 0.0 |> device

for (x, y) in val_loader
    # FIX: the DeepONet needs the sensor grid as its second input, and the
    # assignment target `ŷ` was missing entirely.
    ŷ = model(x, grid)
    # FIX: `global` is required to mutate a top-level binding from inside a
    # script-level loop (file soft-scope rules).
    global val_loss += Flux.Losses.mse(ŷ, y)
end

examples/burgers.jl renamed to examples/burgers_FNO.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Flux: length, reshape, train!, @epochs
1+
using Flux: length, reshape, train!, throttle, @epochs
22
using OperatorLearning, Flux, MAT
33

44
device = gpu;
@@ -74,10 +74,12 @@ parameters = params(model)
7474
loss(x,y) = Flux.Losses.mse(model(x),y)
7575

7676
# Define a callback function that gives some output during training
77-
evalcb() = @show(loss(x,y))
77+
evalcb() = @show(loss(xtest,ytest))
78+
# Print the callback only every 5 seconds
79+
throttled_cb = throttle(evalcb, 5)
7880

7981
# Do the training loop
80-
Flux.@epochs 500 train!(loss, parameters, train_loader, opt, cb = evalcb)
82+
Flux.@epochs 500 train!(loss, parameters, train_loader, opt, cb = throttled_cb)
8183

8284
# Accuracy metrics
8385
val_loader = Flux.Data.DataLoader((xtest, ytest), batchsize=1, shuffle=false) |> device
@@ -86,4 +88,4 @@ loss = 0.0 |> device
8688
for (x,y) in val_loader
8789
ŷ = model(x)
8890
loss += Flux.Losses.mse(ŷ,y)
89-
end
91+
end

src/DeepONet.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Strictly speaking, DeepONet does not imply either of the branch or trunk net to
3636
Consider a transient 1D advection problem ∂ₜu + u ⋅ ∇u = 0, with an IC u(x,0) = g(x).
3737
We are given several (b = 200) instances of the IC, discretized at 50 points each and want to query the solution for 100 different locations and times [0;1].
3838
39-
That makes the branch input of shape [50 x 200] and the trunk input of shape [2 x 100].
39+
That makes the branch input of shape [50 x 200] and the trunk input of shape [2 x 100]. So the input for the branch net is 50 and 100 for the trunk net.
4040
4141
# Usage
4242
@@ -101,7 +101,7 @@ Flux.@functor DeepONet
101101
x is the input function, evaluated at m locations (or m x b in case of batches)
102102
y is the array of sensors, i.e. the variables of the output function
103103
with shape (N x n) - N different variables with each n evaluation points =#
104-
function (a::DeepONet)(x::AbstractVecOrMat, y::AbstractVecOrMat)
104+
function (a::DeepONet)(x::AbstractArray, y::AbstractVecOrMat)
105105
# Assign the parameters
106106
branch, trunk = a.branch_net, a.trunk_net
107107

0 commit comments

Comments
 (0)