13 proper handling of batches #17

Merged: 27 commits, merged Nov 9, 2021
Commits
b5b6ed3  replace OMEinsum with batched_mul (Nov 2, 2021)
62d5d8e  🚮 get rid of bias params (Nov 2, 2021)
1b3986e  remove old code (pzimbrod, Nov 2, 2021)
8cbd84c  remove batch labels in burgers script (pzimbrod, Nov 2, 2021)
d4421e1  add custom batched_mul! methods (Nov 4, 2021)
4303e7b  fix CUDA detection for FFT plan (Nov 4, 2021)
f764a35  add CUDA implementation of batched_mul! (pzimbrod, Nov 4, 2021)
096bab9  add proper (but slow) CUDA version of batched_mul! (pzimbrod, Nov 4, 2021)
dbc1964  stick with non-planned FFT for now (Nov 5, 2021)
ccb0e43  extend training script (Nov 5, 2021)
6aa630b  pre-allocate fourier path (Nov 5, 2021)
b85e693  make only fourier modes trainable (Nov 8, 2021)
d66d0b9  update test syntax (Nov 8, 2021)
353dcfa  correct trainable parameter syntax (pzimbrod, Nov 8, 2021)
f4deb17  correct trainable parameter syntax (pzimbrod, Nov 8, 2021)
a925df7  👨‍🏫 introduce strict typing for struct (pzimbrod, Nov 8, 2021)
adcbd4c  use the allocated arrays (pzimbrod, Nov 8, 2021)
269fd9b  revert dot operations as this breaks CUDA support (pzimbrod, Nov 8, 2021)
6a9bbcb  some cleanups, fix # of output dims in burgers.jl (Nov 9, 2021)
1383209  use allocation in batch multiplication, fix gpu transfer of arrays (Nov 9, 2021)
59ba985  fix test dim of Wl (Nov 9, 2021)
6e0c24d  swap batch dim to ensure compatibility with Flux.DataLoader (Nov 9, 2021)
1b84fc1  fix bias creation (Nov 9, 2021)
47d1798  fix params assignment of bias (Nov 9, 2021)
6deedc1  fix params syntax error (Nov 9, 2021)
f690d52  re-introduce OMEinsum, get rid of batch argument (Nov 9, 2021)
75c095e  fix Wl dim in test & update constructor in burgers (Nov 9, 2021)

2 changes: 1 addition & 1 deletion Project.toml
@@ -8,6 +8,7 @@
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
PkgTemplates = "14b8a8f1-9102-5b29-a752-f990bacb7fe1"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -19,6 +20,5 @@
CUDA = "3"
FFTW = "1"
Flux = "0.12"
MAT = "0.10"
OMEinsum = "0.4, 0.6"
PkgTemplates = "0.7"
Revise = "3"
2 changes: 1 addition & 1 deletion benchmarks/benchFourierLayer.jl
@@ -1,7 +1,7 @@
# Stolen from Flux as well

for n in [2, 20, 200, 2000]
x = randn(Float32, n, 2000, n)
x = randn(Float32, n, 2000, 100)
model = FourierLayer(n, n, 2000, 100, 16)
println("CPU n=$n")
run_benchmark(model, x, cuda=false)
104 changes: 52 additions & 52 deletions src/FourierLayer.jl
@@ -12,9 +12,9 @@
The output though only contains the relevant Fourier modes with the rest padded
in the last axis as a result of the filtering.

The input `x` should be a 3D tensor of shape
(num parameters (`in`) x batch size (`batch`) x num grid points (`grid`))
(num parameters (`in`) x num grid points (`grid`) x batch size (`batch`))
The output `y` will be a 3D tensor of shape
(`out` x batch size (`batch`) x num grid points (`grid`))
(`out` x num grid points (`grid`) x batch size (`batch`))

You can specify biases for the paths as you like, though the convolutional path is
originally not intended to perform an affine transformation.
@@ -23,35 +23,38 @@
originally not intended to perform an affine transformation.
Say you're considering a 1D diffusion problem on a 64 point grid. The input is comprised
of the grid points as well as the IC at this point.
The data consists of 200 instances of the solution.
So the input takes the dimension `2 x 200 x 64`.
So the input takes the dimension `2 x 64 x 200`.
The output would be the diffused variable at a later time, which makes the output of the form
`2 x 200 x 64` as well.
"""
struct FourierLayer{F, Mf<:AbstractArray, Ml<:AbstractArray, Bf<:AbstractArray,
Bl<:AbstractArray, fplan, ifplan,
Modes<:Int}
weight_f::Mf
weight_l::Ml
bias_f::Bf
bias_l::Bl
𝔉::fplan
i𝔉::ifplan
struct FourierLayer{F,Tc<:Complex{<:AbstractFloat},Tr<:AbstractFloat,Bf,Bl}
# F: Activation, Tc/Tr: Complex/Real eltype
Wf::AbstractArray{Tc,3}
Wl::AbstractMatrix{Tr}
grid::Int
σ::F
λ::Modes
λ::Int
bf::Bf
bl::Bl
# Constructor for the entire fourier layer
function FourierLayer(Wf::Mf, Wl::Ml, bf::Bf, bl::Bl, 𝔉::fplan, i𝔉::ifplan,
σ::F = identity, λ::Modes = 12) where {Mf<:AbstractArray, Ml<:AbstractArray,
Bf<:AbstractArray, Bl<:AbstractArray, fplan,
ifplan, F, Modes<:Int}
new{F,Mf,Ml,Bf,Bl,fplan,ifplan,Modes}(Wf, Wl, bf, bl, 𝔉, i𝔉, σ, λ)
function FourierLayer(
Wf::AbstractArray{Tc,3}, Wl::AbstractMatrix{Tr},
grid::Int, σ::F = identity,
λ::Int = 12, bf = true, bl = true) where
{F,Tc<:Complex{<:AbstractFloat},Tr<:AbstractFloat}

# create the biases with one singleton dimension
bf = Flux.create_bias(Wf, bf, 1, size(Wf,2), size(Wf,3))
bl = Flux.create_bias(Wl, bl, 1, size(Wl,1), grid)
new{F,Tc,Tr,typeof(bf),typeof(bl)}(Wf, Wl, grid, σ, λ, bf, bl)
end
end

# Declare the function that assigns Weights and biases to the layer
# `in` and `out` refer to the dimensionality of the number of parameters
# `modes` specifies the number of modes not to be filtered out
# `grid` specifies the number of grid points in the data
function FourierLayer(in::Integer, out::Integer, batch::Integer, grid::Integer, modes = 12,
function FourierLayer(in::Integer, out::Integer, grid::Integer, modes = 12,
σ = identity; initf = cglorot_uniform, initl = Flux.glorot_uniform,
bias_fourier=true, bias_linear=true)

@@ -66,67 +69,64 @@
function FourierLayer(in::Integer, out::Integer, batch::Integer, grid::Integer,
# Initialize Linear weight matrix
Wl = initl(out, in)

bf = Flux.create_bias(Wf, bias_fourier, out, batch, floor(Int, grid/2 + 1))
bl = Flux.create_bias(Wl, bias_linear, out, batch, grid)
# Pass the bias bools
bf = bias_fourier
bl = bias_linear

# Pass the modes for output
λ = modes
# Pre-allocate the interim arrays for the forward pass

# Create linear operators for the FFT and IFFT for efficiency
# So that it has to be only pre-allocated once
# First, an ugly workaround: FFTW.jl passes keywords that cuFFT complains about when the
# constructor is wrapped with |> gpu. Instead, you have to pass a CuArray as input to plan_rfft
# Ugh.
template𝔉 = Flux.use_cuda[] != Nothing ? Array{Float32}(undef,in,batch,grid) :
CuArray{Float32}(undef,in,batch,grid)
templatei𝔉 = Flux.use_cuda[] != Nothing ? Array{Complex{Float32}}(undef,out,batch,floor(Int, grid/2 + 1)) :
CuArray{Complex{Float32}}(undef,out,batch,floor(Int, grid/2 + 1))

𝔉 = plan_rfft(template𝔉,3)
i𝔉 = plan_irfft(templatei𝔉,grid, 3)

return FourierLayer(Wf, Wl, bf, bl, 𝔉, i𝔉, σ, λ)
return FourierLayer(Wf, Wl, grid, σ, λ, bf, bl)
end

Flux.@functor FourierLayer
# Only train the weight array with non-zero modes
Flux.@functor FourierLayer
Flux.trainable(a::FourierLayer) = (a.Wf[:,:,1:a.λ], a.Wl,
typeof(a.bf) != Flux.Zeros ? a.bf[:,:,1:a.λ] : nothing,
typeof(a.bl) != Flux.Zeros ? a.bl : nothing)

# The actual layer that does stuff
function (a::FourierLayer)(x::AbstractArray)
# Assign the parameters
Wf, Wl, bf, bl, σ, 𝔉, i𝔉 = a.weight_f, a.weight_l, a.bias_f, a.bias_l, a.σ, a.𝔉, a.i𝔉
Wf, Wl, bf, bl, σ, = a.Wf, a.Wl, a.bf, a.bl, a.σ
# Do a permutation: DataLoader requires batch to be the last dim
# for the rest, it's more convenient to have it in the first one
xp = permutedims(x, [3,1,2])

# The linear path
# x -> Wl
@ein linear[dim_out, batchsize, dim_grid] := Wl[dim_out, dim_in] *
x[dim_in, batchsize, dim_grid]
linear += bl
# linear .= batched_mul!(linear, Wl, x) .+ bl
@ein linear[batch, out, grid] := Wl[out, in] * xp[batch, in, grid]
linear .+ bl

# The convolution path
# x -> 𝔉 -> Wf -> i𝔉
# Do the Fourier transform (FFT) along the last axis of the input
fourier = 𝔉 * x
# Do the Fourier transform (FFT) along the grid dimension of the input and
# Multiply the weight matrix with the input using batched multiplication
# We need to permute the input to (channel,batch,grid), otherwise batching won't work
# 𝔉 .= batched_mul!(𝔉, Wf, rfft(permutedims(x, [1,3,2]),3)) .+ bf
@ein 𝔉[batch, out, grid] := Wf[in, out, grid] * rfft(xp, 3)[batch, in, grid]
𝔉 .+ bf

# Multiply the weight matrix with the input using the Einstein convention
@ein fourier[dim_out, batchsize, dim_grid] := Wf[dim_in, dim_out, dim_grid] *
fourier[dim_in, batchsize, dim_grid]
fourier += bf
# Do the inverse transform
fourier = i𝔉 * fourier
# We need to permute back to match the shape of the linear path
#i𝔉 = permutedims(irfft(𝔉, size(x,2), 3), [1,3,2])
i𝔉 = irfft(𝔉, size(xp,3),3)

# Return the activated sum
return σ.(linear + fourier)
return permutedims(σ.(linear + i𝔉), [2,3,1])
end

# Overload function to deal with higher-dimensional input arrays
#(a::FourierLayer)(x::AbstractArray) = reshape(a(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)

# Print nicely
function Base.show(io::IO, l::FourierLayer)
print(io, "FourierLayer with\nConvolution path: (", size(l.weight_f, 2), ", ",
size(l.weight_f, 1), ", ", size(l.weight_f, 3))
print(io, "FourierLayer with\nConvolution path: (", size(l.Wf, 2), ", ",
size(l.Wf, 1), ", ", size(l.Wf, 3))
print(io, ")\n")
print(io, "Linear path: (", size(l.weight_l, 2), ", ", size(l.weight_l, 1), ", ",
size(l.weight_l, 3))
print(io, "Linear path: (", size(l.Wl, 2), ", ", size(l.Wl, 1))
print(io, ")\n")
print(io, "Fourier modes: ", l.λ)
print(io, "\n")
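To make the new shape convention concrete, here is a small Julia usage sketch (illustrative only, not part of the diff). It assumes the constructor signature introduced above, FourierLayer(in, out, grid, modes, σ; ...), and the (in, grid, batch) input layout described in the docstring; the concrete sizes are made up for the example.

# Illustrative sketch, not part of this PR. Assumes the constructor and the
# (in, grid, batch) input layout from the diff above.
using NeuralOperator, Flux

in_ch, out_ch = 2, 2                    # channels, e.g. a(x) and x
grid, batch, modes = 64, 200, 16

layer = FourierLayer(in_ch, out_ch, grid, modes, gelu)

x = rand(Float32, in_ch, grid, batch)   # input of shape (in, grid, batch)
y = layer(x)                            # expected shape: (out, grid, batch)
@show size(y)                           # should print (2, 64, 200)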
2 changes: 2 additions & 0 deletions src/NeuralOperator.jl
@@ -4,6 +4,7 @@
using Base: Integer, ident_cmp, Float32
using CUDA
using Flux
using FFTW
using FFTW: assert_applicable, unsafe_execute!, FORWARD, BACKWARD, rFFTWPlan
using Random
using Random: AbstractRNG
using Flux: nfan, glorot_uniform, batch
@@ -13,5 +14,6 @@
export FourierLayer

include("FourierLayer.jl")
include("ComplexWeights.jl")
include("batched.jl")

end # module
39 changes: 39 additions & 0 deletions src/batched.jl
@@ -0,0 +1,39 @@
using CUDA: DenseCuArray
using CUDA.CUFFT: CuFFTPlan

"""
Function overloads for `batched_mul` provided by `NNlib` so that FFTW Plans can be handled.

For `mul!`, `FFTW.jl` already provides an implementation. However, we need to loop over the batch dimension.
"""

# Extension for CPU Arrays
for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
# Note: use $FORWARD and $BACKWARD below because of issue #9775
@eval begin
function NNlib.batched_mul!(y::StridedArray{$Tc}, p::rFFTWPlan{$Tr,$FORWARD}, x::StridedArray{$Tr})
assert_applicable(p, x[:,:,1], y[:,:,1]) # no need to check every batch dim
@inbounds for k ∈ 1:size(y,3)
@views unsafe_execute!(p, x[:,:,k], y[:,:,k])
end
return y
end
function NNlib.batched_mul!(y::StridedArray{$Tr}, p::rFFTWPlan{$Tc,$BACKWARD}, x::StridedArray{$Tc})
assert_applicable(p, x[:,:,1], y[:,:,1]) # no need to check every batch dim
@inbounds for k ∈ 1:size(y,3)
@views unsafe_execute!(p, x[:,:,k], y[:,:,k])
end
return y
end
end
end

# Methods for GPU Arrays, borrowed from CUDA.jl -> fft.jl #490
function NNlib.batched_mul!(y::DenseCuArray{Ty}, p::CuFFTPlan{T,K,false}, x::DenseCuArray{T}
) where {Ty,T,K}
CUFFT.assert_applicable(p, x[:,:,1], y[:,:,1])
@inbounds for k ∈ 1:size(y,3)
@views CUFFT.unsafe_execute!(p, x[:,:,k] ,y[:,:,k])
end
return y
end
46 changes: 26 additions & 20 deletions test/burgers.jl
@@ -1,26 +1,28 @@
using Flux: length, reshape
using Flux: length, reshape, train!, @epochs
using NeuralOperator, Flux, MAT

device = gpu;

# Read the data from MAT file and store it in a dict
vars = matread("burgers_data_R10.mat")
vars = matread("burgers_data_R10.mat") |> device

# For trial purposes, we might want to train with different resolutions
# So we sample only every n-th element
subsample = 2^3;

# create the x training array, according to our desired grid size
xtrain = vars["a"][1:1000, 1:subsample:end];
xtrain = vars["a"][1:1000, 1:subsample:end] |> device;
# create the x test array
xtest = vars["a"][end-99:end, 1:subsample:end];
xtest = vars["a"][end-99:end, 1:subsample:end] |> device;

# Create the y training array
ytrain = vars["u"][1:1000, 1:subsample:end];
ytrain = vars["u"][1:1000, 1:subsample:end] |> device;
# Create the y test array
ytest = vars["u"][end-99:end, 1:subsample:end];
ytest = vars["u"][end-99:end, 1:subsample:end] |> device;

# The data is missing grid data, so we create it
# `collect` converts data type `range` into an array
grid = collect(range(0, 1, length=length(xtrain[1,:])))
grid = collect(range(0, 1, length=length(xtrain[1,:]))) |> device

# Merge the created grid with the data
# Output has the dims: batch x grid points x 2 (a(x), x)
@@ -29,40 +31,41 @@
grid = collect(range(0, 1, length=length(xtrain[1,:])))
# and concatenate them along the newly created 3rd dim
xtrain = cat(reshape(xtrain,(1000,1024,1)),
reshape(repeat(grid,1000),(1000,1024,1));
dims=3)
dims=3) |> device
ytrain = cat(reshape(ytrain,(1000,1024,1)),
reshape(repeat(grid,1000),(1000,1024,1));
dims=3)
dims=3) |> device
# Same treatment with the test data
xtest = cat(reshape(xtest,(100,1024,1)),
reshape(repeat(grid,100),(100,1024,1));
dims=3)
dims=3) |> device
ytest = cat(reshape(ytest,(100,1024,1)),
reshape(repeat(grid,100),(100,1024,1));
dims=3)
dims=3) |> device

# Our net wants the input in the form (2,batch,grid), though,
# Our net wants the input in the form (2,grid,batch), though,
# So we permute
xtrain, xtest = permutedims(xtrain,(3,1,2)), permutedims(xtest,(3,1,2))
ytrain, ytest = permutedims(ytrain,(3,1,2)), permutedims(ytest,(3,1,2))
xtrain, xtest = permutedims(xtrain,(3,2,1)), permutedims(xtest,(3,2,1)) |> device
ytrain, ytest = permutedims(ytrain,(3,2,1)), permutedims(ytest,(3,2,1)) |> device

# Pass the data to the Flux DataLoader and give it a batch of 20
train_loader = Flux.Data.DataLoader((data=xtrain, label=ytrain), batchsize=20)
test_loader = Flux.Data.DataLoader((data=xtest, label=ytest), batchsize=20)
train_loader = Flux.Data.DataLoader((xtrain, ytrain), batchsize=20, shuffle=true) |> device
test_loader = Flux.Data.DataLoader((xtest, ytest), batchsize=20, shuffle=true) |> device

# Set up the Fourier Layer
# 128 in- and outputs, batch size 20 as given above, grid size 1024
# 16 modes to keep, σ activation on the gpu
layer = FourierLayer(128,128,20,1024,16,σ)
layer = FourierLayer(128,128,1024,16,gelu,bias_fourier=false) |> device

# The whole architecture
# linear transform into the latent space, 4 Fourier Layers,
# then transform it back
model = Chain(Dense(2,128;bias=false), layer, layer, layer, layer,
Dense(128,2;bias=false))
Dense(128,2;bias=false)) |> device

# We use the ADAM optimizer for training
opt = ADAM()
learning_rate = 0.001
opt = ADAM(learning_rate)

# Specify the model parameters
parameters = params(model)
@@ -71,4 +74,7 @@
parameters = params(model)
loss(x,y) = Flux.Losses.mse(model(x),y)

# Define a callback function that gives some output during training
evalcb() = @show(loss(x,y))
evalcb() = @show(loss(x,y))

# Do the training loop
Flux.@epochs 500 train!(loss, parameters, train_loader, opt, cb = evalcb)
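As a possible follow-up to the training loop (a suggestion, not something this PR adds), the held-out error could be reported by averaging the same MSE loss over the test DataLoader defined above:

# Suggested check, not part of the PR: average MSE over the test batches.
test_loss = sum(Flux.Losses.mse(model(x), y) for (x, y) in test_loader) / length(test_loader)
@show test_loss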
12 changes: 6 additions & 6 deletions test/fourierlayer.jl
@@ -2,14 +2,14 @@
using Test, Random, Flux

@testset "FourierLayer" begin
# Test the proper construction
@test size(FourierLayer(128, 64, 200, 100, 20).weight_f) == (128, 64, 51)
@test size(FourierLayer(128, 64, 200, 100, 20).weight_l) == (64, 128)
@test size(FourierLayer(128, 64, 100, 20).Wf) == (128, 64, 51)
@test size(FourierLayer(128, 64, 100, 20).Wl) == (64, 128)
#@test size(FourierLayer(10, 100).bias_f) == (51,)
#@test size(FourierLayer(10, 100).bias_l) == (100,)

# Accept only Int as architecture parameters
@test_throws MethodError FourierLayer(128.5, 64, 200, 100, 20)
@test_throws MethodError FourierLayer(128.5, 64, 200, 100, 20, tanh)
@test_throws AssertionError FourierLayer(100, 100, 100, 100, 60, σ)
@test_throws AssertionError FourierLayer(100, 100, 100, 100, 60)
@test_throws MethodError FourierLayer(128.5, 64, 100, 20)
@test_throws MethodError FourierLayer(128.5, 64, 100, 20, tanh)
@test_throws AssertionError FourierLayer(100, 100, 100, 60, σ)
@test_throws AssertionError FourierLayer(100, 100, 100, 60)
end