diff --git a/Project.toml b/Project.toml
index 217dcb7..069f047 100644
--- a/Project.toml
+++ b/Project.toml
@@ -8,6 +8,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
+NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
 PkgTemplates = "14b8a8f1-9102-5b29-a752-f990bacb7fe1"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -19,6 +20,5 @@ CUDA = "3"
 FFTW = "1"
 Flux = "0.12"
 MAT = "0.10"
-OMEinsum = "0.4, 0.6"
 PkgTemplates = "0.7"
 Revise = "3"
diff --git a/benchmarks/benchFourierLayer.jl b/benchmarks/benchFourierLayer.jl
index 6460026..71e6e31 100644
--- a/benchmarks/benchFourierLayer.jl
+++ b/benchmarks/benchFourierLayer.jl
@@ -1,7 +1,7 @@
 # Stolen from Flux as well
 
 for n in [2, 20, 200, 2000]
-    x = randn(Float32, n, 2000, n)
-    model = FourierLayer(n, n, 2000, 100, 16)
+    x = randn(Float32, n, 2000, 100)
+    model = FourierLayer(n, n, 2000, 16)
     println("CPU n=$n")
     run_benchmark(model, x, cuda=false)
diff --git a/src/FourierLayer.jl b/src/FourierLayer.jl
index 6606b3d..311b9b9 100644
--- a/src/FourierLayer.jl
+++ b/src/FourierLayer.jl
@@ -12,9 +12,9 @@ The output though only contains the relevant Fourier modes with the rest padded
 in the last axis as a result of the filtering.
 
 The input `x` should be a 3D tensor of shape
-(num parameters (`in`) x batch size (`batch`) x num grid points (`grid`))
+(num parameters (`in`) x num grid points (`grid`) x batch size (`batch`))
 The output `y` will be a 3D tensor of shape
-(`out` x batch size (`batch`) x num grid points (`grid`))
+(`out` x num grid points (`grid`) x batch size (`batch`))
 
 You can specify biases for the paths as you like, though the convolutional path is
 originally not intended to perform an affine transformation.
@@ -23,27 +23,30 @@ originally not intended to perform an affine transformation.
 Say you're considering a 1D diffusion problem on a 64 point grid.
 The input is comprised of the grid points as well as the IC at this point.
 The data consists of 200 instances of the solution.
-So the input takes the dimension `2 x 200 x 64`.
-The output would be the diffused variable at a later time, which makes the output of the form `2 x 200 x 64` as well.
+So the input takes the dimension `2 x 64 x 200`.
+The output would be the diffused variable at a later time, which makes the output of the form `2 x 64 x 200` as well.
""" -struct FourierLayer{F, Mf<:AbstractArray, Ml<:AbstractArray, Bf<:AbstractArray, - Bl<:AbstractArray, fplan, ifplan, - Modes<:Int} - weight_f::Mf - weight_l::Ml - bias_f::Bf - bias_l::Bl - 𝔉::fplan - i𝔉::ifplan +struct FourierLayer{F,Tc<:Complex{<:AbstractFloat},Tr<:AbstractFloat,Bf,Bl} + # F: Activation, Tc/Tr: Complex/Real eltype + Wf::AbstractArray{Tc,3} + Wl::AbstractMatrix{Tr} + grid::Int σ::F - λ::Modes + λ::Int + bf::Bf + bl::Bl # Constructor for the entire fourier layer - function FourierLayer(Wf::Mf, Wl::Ml, bf::Bf, bl::Bl, 𝔉::fplan, i𝔉::ifplan, - σ::F = identity, λ::Modes = 12) where {Mf<:AbstractArray, Ml<:AbstractArray, - Bf<:AbstractArray, Bl<:AbstractArray, fplan, - ifplan, F, Modes<:Int} - new{F,Mf,Ml,Bf,Bl,fplan,ifplan,Modes}(Wf, Wl, bf, bl, 𝔉, i𝔉, σ, λ) + function FourierLayer( + Wf::AbstractArray{Tc,3}, Wl::AbstractMatrix{Tr}, + grid::Int, σ::F = identity, + λ::Int = 12, bf = true, bl = true) where + {F,Tc<:Complex{<:AbstractFloat},Tr<:AbstractFloat} + + # create the biases with one singleton dimension + bf = Flux.create_bias(Wf, bf, 1, size(Wf,2), size(Wf,3)) + bl = Flux.create_bias(Wl, bl, 1, size(Wl,1), grid) + new{F,Tc,Tr,typeof(bf),typeof(bl)}(Wf, Wl, grid, σ, λ, bf, bl) end end @@ -51,7 +54,7 @@ end # `in` and `out` refer to the dimensionality of the number of parameters # `modes` specifies the number of modes not to be filtered out # `grid` specifies the number of grid points in the data -function FourierLayer(in::Integer, out::Integer, batch::Integer, grid::Integer, modes = 12, +function FourierLayer(in::Integer, out::Integer, grid::Integer, modes = 12, σ = identity; initf = cglorot_uniform, initl = Flux.glorot_uniform, bias_fourier=true, bias_linear=true) @@ -66,55 +69,53 @@ function FourierLayer(in::Integer, out::Integer, batch::Integer, grid::Integer, # Initialize Linear weight matrix Wl = initl(out, in) - bf = Flux.create_bias(Wf, bias_fourier, out, batch, floor(Int, grid/2 + 1)) - bl = Flux.create_bias(Wl, bias_linear, out, batch, grid) + # Pass the bias bools + bf = bias_fourier + bl = bias_linear # Pass the modes for output λ = modes + # Pre-allocate the interim arrays for the forward pass - # Create linear operators for the FFT and IFFT for efficiency - # So that it has to be only pre-allocated once - # First, an ugly workaround: FFTW.jl passes keywords that cuFFT complains about when the - # constructor is wrapped with |> gpu. Instead, you have to pass a CuArray as input to plan_rfft - # Ugh. - template𝔉 = Flux.use_cuda[] != Nothing ? Array{Float32}(undef,in,batch,grid) : - CuArray{Float32}(undef,in,batch,grid) - templatei𝔉 = Flux.use_cuda[] != Nothing ? Array{Complex{Float32}}(undef,out,batch,floor(Int, grid/2 + 1)) : - CuArray{Complex{Float32}}(undef,out,batch,floor(Int, grid/2 + 1)) - - 𝔉 = plan_rfft(template𝔉,3) - i𝔉 = plan_irfft(templatei𝔉,grid, 3) - - return FourierLayer(Wf, Wl, bf, bl, 𝔉, i𝔉, σ, λ) + return FourierLayer(Wf, Wl, grid, σ, λ, bf, bl) end -Flux.@functor FourierLayer +# Only train the weight array with non-zero modes +Flux.@functor FourierLayer +Flux.trainable(a::FourierLayer) = (a.Wf[:,:,1:a.λ], a.Wl, + typeof(a.bf) != Flux.Zeros ? a.bf[:,:,1:a.λ] : nothing, + typeof(a.bl) != Flux.Zeros ? 
+    typeof(a.bl) != Flux.Zeros ? a.bl : nothing)
 
 # The actual layer that does stuff
 function (a::FourierLayer)(x::AbstractArray)
     # Assign the parameters
-    Wf, Wl, bf, bl, σ, 𝔉, i𝔉 = a.weight_f, a.weight_l, a.bias_f, a.bias_l, a.σ, a.𝔉, a.i𝔉
+    Wf, Wl, bf, bl, σ = a.Wf, a.Wl, a.bf, a.bl, a.σ
+
+    # Do a permutation: DataLoader requires batch to be the last dim
+    # for the rest, it's more convenient to have it in the first one
+    xp = permutedims(x, [3,1,2])
 
     # The linear path
     # x -> Wl
-    @ein linear[dim_out, batchsize, dim_grid] := Wl[dim_out, dim_in] *
-        x[dim_in, batchsize, dim_grid]
-    linear += bl
+    # linear .= batched_mul!(linear, Wl, x) .+ bl
+    @ein linear[batch, out, grid] := Wl[out, in] * xp[batch, in, grid]
+    linear = linear .+ bl
 
     # The convolution path
     # x -> 𝔉 -> Wf -> i𝔉
-    # Do the Fourier transform (FFT) along the last axis of the input
-    fourier = 𝔉 * x
+    # Do the Fourier transform (FFT) along the grid dimension of the input and
+    # multiply the weight matrix with the input using batched multiplication
+    # We need to permute the input to (batch,channel,grid), otherwise batching won't work
+    # 𝔉 .= batched_mul!(𝔉, Wf, rfft(permutedims(x, [1,3,2]),3)) .+ bf
+    @ein 𝔉[batch, out, grid] := Wf[in, out, grid] * rfft(xp, 3)[batch, in, grid]
+    𝔉 = 𝔉 .+ bf
 
-    # Multiply the weight matrix with the input using the Einstein convention
-    @ein fourier[dim_out, batchsize, dim_grid] := Wf[dim_in, dim_out, dim_grid] *
-        fourier[dim_in, batchsize, dim_grid]
-    fourier += bf
 
     # Do the inverse transform
-    fourier = i𝔉 * fourier
+    # We need to permute back to match the shape of the linear path
+    # i𝔉 = permutedims(irfft(𝔉, size(x,2), 3), [1,3,2])
+    i𝔉 = irfft(𝔉, size(xp,3), 3)
 
     # Return the activated sum
-    return σ.(linear + fourier)
+    return permutedims(σ.(linear + i𝔉), [2,3,1])
 end
 
 # Overload function to deal with higher-dimensional input arrays
@@ -122,11 +123,10 @@ end
 
 # Print nicely
 function Base.show(io::IO, l::FourierLayer)
-    print(io, "FourierLayer with\nConvolution path: (", size(l.weight_f, 2), ", ",
-            size(l.weight_f, 1), ", ", size(l.weight_f, 3))
+    print(io, "FourierLayer with\nConvolution path: (", size(l.Wf, 2), ", ",
+            size(l.Wf, 1), ", ", size(l.Wf, 3))
     print(io, ")\n")
-    print(io, "Linear path: (", size(l.weight_l, 2), ", ", size(l.weight_l, 1), ", ",
-            size(l.weight_l, 3))
+    print(io, "Linear path: (", size(l.Wl, 2), ", ", size(l.Wl, 1))
     print(io, ")\n")
     print(io, "Fourier modes: ", l.λ)
     print(io, "\n")
diff --git a/src/NeuralOperator.jl b/src/NeuralOperator.jl
index 0a5542f..26f1ee2 100644
--- a/src/NeuralOperator.jl
+++ b/src/NeuralOperator.jl
@@ -4,6 +4,7 @@ using Base: Integer, ident_cmp, Float32
 using CUDA
 using Flux
 using FFTW
+using FFTW: assert_applicable, unsafe_execute!, FORWARD, BACKWARD, rFFTWPlan
 using Random
 using Random: AbstractRNG
 using Flux: nfan, glorot_uniform, batch
@@ -13,5 +14,6 @@ export FourierLayer
 
 include("FourierLayer.jl")
 include("ComplexWeights.jl")
+include("batched.jl")
 
 end # module
diff --git a/src/batched.jl b/src/batched.jl
new file mode 100644
index 0000000..1aff86d
--- /dev/null
+++ b/src/batched.jl
@@ -0,0 +1,39 @@
+using CUDA: DenseCuArray
+using CUDA.CUFFT: CuFFTPlan
+
+"""
+Function overloads for `batched_mul` provided by `NNlib` so that FFTW Plans can be handled.
+
+For `mul!`, `FFTW.jl` already provides an implementation. However, we need to loop over the batch dimension.
+""" + +# Extension for CPU Arrays +for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64}))) + # Note: use $FORWARD and $BACKWARD below because of issue #9775 + @eval begin + function NNlib.batched_mul!(y::StridedArray{$Tc}, p::rFFTWPlan{$Tr,$FORWARD}, x::StridedArray{$Tr}) + assert_applicable(p, x[:,:,1], y[:,:,1]) # no need to check every batch dim + @inbounds for k ∈ 1:size(y,3) + @views unsafe_execute!(p, x[:,:,k], y[:,:,k]) + end + return y + end + function NNlib.batched_mul!(y::StridedArray{$Tr}, p::rFFTWPlan{$Tc,$BACKWARD}, x::StridedArray{$Tc}) + assert_applicable(p, x[:,:,1], y[:,:,1]) # no need to check every batch dim + @inbounds for k ∈ 1:size(y,3) + @views unsafe_execute!(p, x[:,:,k], y[:,:,k]) + end + return y + end + end +end + +# Methods for GPU Arrays, borrowed from CUDA.jl -> fft.jl #490 +function NNlib.batched_mul!(y::DenseCuArray{Ty}, p::CuFFTPlan{T,K,false}, x::DenseCuArray{T} + ) where {Ty,T,K} + CUFFT.assert_applicable(p, x[:,:,1], y[:,:,1]) + @inbounds for k ∈ 1:size(y,3) + @views CUFFT.unsafe_execute!(p, x[:,:,k] ,y[:,:,k]) + end + return y +end diff --git a/test/burgers.jl b/test/burgers.jl index df531d8..f6c7794 100644 --- a/test/burgers.jl +++ b/test/burgers.jl @@ -1,26 +1,28 @@ -using Flux: length, reshape +using Flux: length, reshape, train!, @epochs using NeuralOperator, Flux, MAT +device = gpu; + # Read the data from MAT file and store it in a dict -vars = matread("burgers_data_R10.mat") +vars = matread("burgers_data_R10.mat") |> device # For trial purposes, we might want to train with different resolutions # So we sample only every n-th element subsample = 2^3; # create the x training array, according to our desired grid size -xtrain = vars["a"][1:1000, 1:subsample:end]; +xtrain = vars["a"][1:1000, 1:subsample:end] |> device; # create the x test array -xtest = vars["a"][end-99:end, 1:subsample:end]; +xtest = vars["a"][end-99:end, 1:subsample:end] |> device; # Create the y training array -ytrain = vars["u"][1:1000, 1:subsample:end]; +ytrain = vars["u"][1:1000, 1:subsample:end] |> device; # Create the y test array -ytest = vars["u"][end-99:end, 1:subsample:end]; +ytest = vars["u"][end-99:end, 1:subsample:end] |> device; # The data is missing grid data, so we create it # `collect` converts data type `range` into an array -grid = collect(range(0, 1, length=length(xtrain[1,:]))) +grid = collect(range(0, 1, length=length(xtrain[1,:]))) |> device # Merge the created grid with the data # Output has the dims: batch x grid points x 2 (a(x), x) @@ -29,40 +31,41 @@ grid = collect(range(0, 1, length=length(xtrain[1,:]))) # and concatenate them along the newly created 3rd dim xtrain = cat(reshape(xtrain,(1000,1024,1)), reshape(repeat(grid,1000),(1000,1024,1)); - dims=3) + dims=3) |> device ytrain = cat(reshape(ytrain,(1000,1024,1)), reshape(repeat(grid,1000),(1000,1024,1)); - dims=3) + dims=3) |> device # Same treatment with the test data xtest = cat(reshape(xtest,(100,1024,1)), reshape(repeat(grid,100),(100,1024,1)); - dims=3) + dims=3) |> device ytest = cat(reshape(ytest,(100,1024,1)), reshape(repeat(grid,100),(100,1024,1)); - dims=3) + dims=3) |> device -# Our net wants the input in the form (2,batch,grid), though, +# Our net wants the input in the form (2,grid,batch), though, # So we permute -xtrain, xtest = permutedims(xtrain,(3,1,2)), permutedims(xtest,(3,1,2)) -ytrain, ytest = permutedims(ytrain,(3,1,2)), permutedims(ytest,(3,1,2)) +xtrain, xtest = permutedims(xtrain,(3,2,1)), permutedims(xtest,(3,2,1)) |> device +ytrain, ytest = 
+ytrain, ytest = permutedims(ytrain,(3,2,1)), permutedims(ytest,(3,2,1)) |> device
 
 # Pass the data to the Flux DataLoader and give it a batch of 20
-train_loader = Flux.Data.DataLoader((data=xtrain, label=ytrain), batchsize=20)
-test_loader = Flux.Data.DataLoader((data=xtest, label=ytest), batchsize=20)
+train_loader = Flux.Data.DataLoader((xtrain, ytrain), batchsize=20, shuffle=true) |> device
+test_loader = Flux.Data.DataLoader((xtest, ytest), batchsize=20, shuffle=true) |> device
 
 # Set up the Fourier Layer
-# 128 in- and outputs, batch size 20 as given above, grid size 1024
-# 16 modes to keep, σ activation on the gpu
-layer = FourierLayer(128,128,20,1024,16,σ)
+# 128 in- and outputs, grid size 1024
+# 16 modes to keep, GELU activation, no bias on the Fourier path
+layer = FourierLayer(128,128,1024,16,gelu,bias_fourier=false) |> device
 
 # The whole architecture
 # linear transform into the latent space, 4 Fourier Layers,
 # then transform it back
 model = Chain(Dense(2,128;bias=false), layer, layer, layer, layer,
-                Dense(128,2;bias=false))
+                Dense(128,2;bias=false)) |> device
 
 # We use the ADAM optimizer for training
-opt = ADAM()
+learning_rate = 0.001
+opt = ADAM(learning_rate)
 
 # Specify the model parameters
 parameters = params(model)
@@ -71,4 +74,7 @@ parameters = params(model)
 loss(x,y) = Flux.Losses.mse(model(x),y)
 
 # Define a callback function that gives some output during training
-evalcb() = @show(loss(x,y))
\ No newline at end of file
+evalcb() = @show(loss(x,y))
+
+# Do the training loop
+Flux.@epochs 500 train!(loss, parameters, train_loader, opt, cb = evalcb)
\ No newline at end of file
diff --git a/test/fourierlayer.jl b/test/fourierlayer.jl
index 0b48159..6b1773e 100644
--- a/test/fourierlayer.jl
+++ b/test/fourierlayer.jl
@@ -2,14 +2,14 @@ using Test, Random, Flux
 
 @testset "FourierLayer" begin
     # Test the proper construction
-    @test size(FourierLayer(128, 64, 200, 100, 20).weight_f) == (128, 64, 51)
-    @test size(FourierLayer(128, 64, 200, 100, 20).weight_l) == (64, 128)
+    @test size(FourierLayer(128, 64, 100, 20).Wf) == (128, 64, 51)
+    @test size(FourierLayer(128, 64, 100, 20).Wl) == (64, 128)
     #@test size(FourierLayer(10, 100).bias_f) == (51,)
     #@test size(FourierLayer(10, 100).bias_l) == (100,)
 
     # Accept only Int as architecture parameters
-    @test_throws MethodError FourierLayer(128.5, 64, 200, 100, 20)
-    @test_throws MethodError FourierLayer(128.5, 64, 200, 100, 20, tanh)
-    @test_throws AssertionError FourierLayer(100, 100, 100, 100, 60, σ)
-    @test_throws AssertionError FourierLayer(100, 100, 100, 100, 60)
+    @test_throws MethodError FourierLayer(128.5, 64, 100, 20)
+    @test_throws MethodError FourierLayer(128.5, 64, 100, 20, tanh)
+    @test_throws AssertionError FourierLayer(100, 100, 100, 60, σ)
+    @test_throws AssertionError FourierLayer(100, 100, 100, 60)
 end
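
Usage sketch (illustration only, not part of the diff): assuming the new constructor signature FourierLayer(in, out, grid, modes, σ; ...) and the in x grid x batch layout described in the updated docstring, a forward pass through a small model would look roughly like the following. The channel, grid, and batch sizes here are made up for the example.

using NeuralOperator, Flux

# 2 input channels (a(x) and the grid coordinate), 64 grid points, 200 samples per batch
x = randn(Float32, 2, 64, 200)

# lift to 128 channels, keep 16 Fourier modes, GELU activation
layer = FourierLayer(128, 128, 64, 16, gelu)

# Dense acts on the first (channel) dimension, the FourierLayer on the grid dimension
model = Chain(Dense(2, 128; bias=false), layer, Dense(128, 2; bias=false))

y = model(x)    # expected shape: (2, 64, 200), i.e. out x grid x batch

Training would then proceed as in test/burgers.jl, with a DataLoader that keeps the batch dimension last.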