Skip to content
This repository has been archived by the owner on Sep 28, 2024. It is now read-only.

Commit

Permalink
fix DeepONet for CUDA
Browse files Browse the repository at this point in the history
  • Loading branch information
yuehhua committed Mar 8, 2022
1 parent cf8f4bd commit 65e0749
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 56 deletions.
40 changes: 21 additions & 19 deletions example/Burgers/src/Burgers_deeponet.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,34 @@
function train_don()
# if has_cuda()
# @info "CUDA is on"
# device = gpu
# CUDA.allowscalar(false)
# else
function train_don(; n=300, cuda=true, learning_rate=0.001, epochs=400)
if cuda && has_cuda()
@info "Training on GPU"
device = gpu
else
@info "Training on CPU"
device = cpu
# end
end

x, y = get_data_don(n=300)
xtrain = x[1:280, :]' |> device
xval = x[end-19:end, :]' |> device
x, y = get_data_don(n=n)

xtrain = x[1:280, :]'
ytrain = y[1:280, :]

ytrain = y[1:280, :] |> device
xval = x[end-19:end, :]' |> device
yval = y[end-19:end, :] |> device

grid = collect(range(0, 1, length=1024))' |> device
grid = collect(range(0, 1, length=1024)') |> device

learning_rate = 0.001
opt = ADAM(learning_rate)

m = DeepONet((1024,1024,1024),(1,1024,1024),gelu,gelu)
loss(xtrain,ytrain,sensor) = Flux.Losses.mse(m(xtrain,sensor),ytrain)
evalcb() = @show(loss(xval,yval,grid))
m = DeepONet((1024,1024,1024), (1,1024,1024), gelu, gelu) |> device

loss(X, y, sensor) = Flux.Losses.mse(m(X, sensor), y)
evalcb() = @show(loss(xval, yval, grid))

Flux.@epochs 400 Flux.train!(loss, params(m), [(xtrain,ytrain,grid)], opt, cb = evalcb)
ỹ = m(xval, grid)
data = [(xtrain, ytrain, grid)] |> device
Flux.@epochs epochs Flux.train!(loss, params(m), data, opt, cb=evalcb)
ỹ = m(xval |> device, grid |> device)

diffvec = vec(abs.((yval .- ỹ)))
diffvec = vec(abs.(cpu(yval) .- cpu(ỹ)))
mean_diff = sum(diffvec)/length(diffvec)
return mean_diff
end
2 changes: 1 addition & 1 deletion src/DeepONet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ function (a::DeepONet)(x::AbstractArray, y::AbstractVecOrMat)
However, we perform the transformations by the NNs always in the first dim
so we need to adjust (i.e. transpose) one of the inputs,
which we do on the branch input here =#
return Array(branch(x)') * trunk(y)
return branch(x)' * trunk(y)
end

# Sensors stay the same and shouldn't be batched
Expand Down
18 changes: 18 additions & 0 deletions test/cuda.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# GPU smoke test for DeepONet: moves the inputs and the model to the GPU with
# `gpu`, then checks the forward-pass output size and the number of entries in
# the gradient taken with respect to the model parameters.
@testset "CUDA" begin
@testset "DeepONet" begin
batch_size = 2
# Same 16 values as in test/deeponet.jl, where they are described as the
# first 16 datapoints from the Burgers' equation dataset.
a = [0.83541104, 0.83479851, 0.83404712, 0.83315711, 0.83212979, 0.83096755,
0.82967374, 0.82825263, 0.82670928, 0.82504949, 0.82327962, 0.82140651,
0.81943734, 0.81737952, 0.8152405, 0.81302771]
# Tile the 16-element vector along a second dimension: 16 × batch_size.
a = repeat(a, outer=(1, batch_size)) |> gpu
# 1 × 16 row of evenly spaced points in [0, 1] ...
sensors = collect(range(0, 1, length=16)')
# ... tiled along the first dimension: batch_size × 16.
sensors = repeat(sensors, outer=(batch_size, 1)) |> gpu
# NOTE(review): the trunk tuple starts with 2, which equals the number of rows
# of `sensors` here — presumably the trunk input width; confirm against the
# DeepONet constructor.
model = DeepONet((16, 22, 30), (2, 16, 24, 30), σ, tanh;
init_branch=Flux.glorot_normal, bias_trunk=false) |> gpu
y = model(a, sensors)
@test size(y) == (batch_size, 16)

# Gradient of the summed output w.r.t. the implicit parameters of `model`.
mgrad = Flux.Zygote.gradient(() -> sum(model(a, sensors)), Flux.params(model))
# NOTE(review): the expected count is pinned at 9 here, while the CPU test in
# test/deeponet.jl expects 7 for a similarly shaped model — confirm where the
# extra entries in `mgrad.grads` come from before changing this number.
@test length(mgrad.grads) == 9
end
end
57 changes: 25 additions & 32 deletions test/deeponet.jl
Original file line number Diff line number Diff line change
@@ -1,41 +1,34 @@
using Test, Flux

@testset "DeepONet" begin
@testset "dimensions" begin
# Test the proper construction
@testset "proper construction" begin
deeponet = DeepONet((32,64,72), (24,48,72), σ, tanh)
# Branch net
@test size(DeepONet((32,64,72), (24,48,72), σ, tanh).branch_net.layers[end].weight) == (72,64)
@test size(DeepONet((32,64,72), (24,48,72), σ, tanh).branch_net.layers[end].bias) == (72,)
@test size(deeponet.branch_net.layers[end].weight) == (72,64)
@test size(deeponet.branch_net.layers[end].bias) == (72,)
# Trunk net
@test size(DeepONet((32,64,72), (24,48,72), σ, tanh).trunk_net.layers[end].weight) == (72,48)
@test size(DeepONet((32,64,72), (24,48,72), σ, tanh).trunk_net.layers[end].bias) == (72,)
@test size(deeponet.trunk_net.layers[end].weight) == (72,48)
@test size(deeponet.trunk_net.layers[end].bias) == (72,)
end

# Accept only Int as architecture parameters
@test_throws MethodError DeepONet((32.5,64,72), (24,48,72), σ, tanh)
@test_throws MethodError DeepONet((32,64,72), (24.1,48,72))
end

#Just the first 16 datapoints from the Burgers' equation dataset
a = [0.83541104, 0.83479851, 0.83404712, 0.83315711, 0.83212979, 0.83096755, 0.82967374, 0.82825263, 0.82670928, 0.82504949, 0.82327962, 0.82140651, 0.81943734, 0.81737952, 0.8152405, 0.81302771]
sensors = collect(range(0, 1, length=16))'

model = DeepONet((16, 22, 30), (1, 16, 24, 30), σ, tanh; init_branch=Flux.glorot_normal, bias_trunk=false)

model(a,sensors)

#forward pass
@test size(model(a, sensors)) == (1, 16)

mgrad = Flux.Zygote.gradient((x,p)->sum(model(x,p)),a,sensors)

#gradients
@test !iszero(Flux.Zygote.gradient((x,p)->sum(model(x,p)),a,sensors)[1])
@test !iszero(Flux.Zygote.gradient((x,p)->sum(model(x,p)),a,sensors)[2])

#Output size of branch and trunk subnets should be same
branch = Chain(Dense(16, 22), Dense(22, 30))
trunk = Chain(Dense(1, 16), Dense(16, 24), Dense(24, 32))
m = DeepONet(branch, trunk)
@test_throws AssertionError DeepONet((32,64,70), (24,48,72), σ, tanh)
@test_throws DimensionMismatch m(a, sensors)
# Just the first 16 datapoints from the Burgers' equation dataset
a = [0.83541104, 0.83479851, 0.83404712, 0.83315711, 0.83212979, 0.83096755,
0.82967374, 0.82825263, 0.82670928, 0.82504949, 0.82327962, 0.82140651,
0.81943734, 0.81737952, 0.8152405, 0.81302771]
sensors = collect(range(0, 1, length=16)')
model = DeepONet((16, 22, 30), (1, 16, 24, 30), σ, tanh; init_branch=Flux.glorot_normal, bias_trunk=false)
y = model(a, sensors)
@test size(y) == (1, 16)

mgrad = Flux.Zygote.gradient(() -> sum(model(a, sensors)), Flux.params(model))
@test length(mgrad.grads) == 7

# Output size of branch and trunk subnets should be same
branch = Chain(Dense(16, 22), Dense(22, 30))
trunk = Chain(Dense(1, 16), Dense(16, 24), Dense(24, 32))
m = DeepONet(branch, trunk)
@test_throws AssertionError DeepONet((32,64,70), (24,48,72), σ, tanh)
@test_throws DimensionMismatch m(a, sensors)
end
27 changes: 23 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,31 @@
using NeuralOperators
using Test
using Flux
using CUDA

CUDA.allowscalar(false)

cuda_tests = [
"cuda",
]

tests = [
"Transform/Transform",
"operator_kernel",
"model",
"deeponet",
]

if CUDA.functional()
append!(tests, cuda_tests)
else
@warn "CUDA unavailable, not testing GPU support"
end

@testset "NeuralOperators.jl" begin
include("Transform/Transform.jl")
include("operator_kernel.jl")
include("model.jl")
include("deeponet.jl")
for t in tests
include("$(t).jl")
end
end

#=
Expand Down

0 comments on commit 65e0749

Please sign in to comment.