Partial fix for a_posteriori lux style
v1kko committed Sep 9, 2024
1 parent 665ad2f commit ba0a3e2
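The "lux style" referred to here is the objective-function contract used by Lux.Training: the loss must have the signature loss(model, ps, st, data) and return the triple (loss, updated_state, stats), so that single_train_step! can thread the model state through training. A self-contained sketch of that contract, not part of the commit (the model and data below are toy stand-ins):

using Lux, Random, Optimization, OptimizationOptimisers, Zygote

function demo_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)           # Lux models return (output, state)
    return sum(abs2, ŷ - y) / sum(abs2, y), st_, (;)
end

rng = Random.default_rng()
model = Lux.Dense(4 => 4)               # toy stand-in for the closure model
ps, st = Lux.setup(rng, model)
x, y = rand(rng, Float32, 4, 8), rand(rng, Float32, 4, 8)

tstate = Lux.Training.TrainState(model, ps, st, OptimizationOptimisers.Adam(0.1))
_, l, _, tstate = Lux.Training.single_train_step!(
    Optimization.AutoZygote(), demo_loss, (x, y), tstate)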
Showing 2 changed files with 15 additions and 3 deletions.
15 changes: 13 additions & 2 deletions simulations/NavierStokes_2D/scripts/NeuralClosure+SciML.jl
@@ -300,6 +300,14 @@ function loss_posteriori(model, p, st, data)
     return T(sum(abs2, y - pred[:, :, :, 1, 2:end]) / sum(abs2, y))
 end
 
+# A loss in the Lux format: Lux.Training expects an objective that
+# returns (loss, updated_state, stats), not just a scalar.
+function loss_posteriori_lux_style(model, ps, st, (x, y))
+    ŷ, st_ = model(x, ps, st)
+    loss = sum(abs2, ŷ - y) / sum(abs2, y)
+    return loss, st_, (; y_pred = ŷ)
+end
+
 # train a-posteriori: single data point
 train_data_posteriori = dataloader_luisa()
 optf = Optimization.OptimizationFunction(
@@ -316,10 +324,13 @@ result_posteriori = Optimization.solve(
 )
 θ_posteriori = result_posteriori.u
 
+loss_posteriori_lux_style(closure, θ_posteriori, st, train_data_posteriori)
+
+
 # try with Lux
-tstate = Lux.Training.TrainState(closure, θ, st, OptimizationOptimisers.Adam(0.1))
+tstate = Lux.Training.TrainState(closure, θ_posteriori, st, OptimizationOptimisers.Adam(0.1))
 _, loss, stats, tstate = Lux.Training.single_train_step!(
-    Optimization.AutoZygote(), loss_posteriori, train_data_posteriori, tstate)
+    Optimization.AutoZygote(), loss_posteriori_lux_style, train_data_posteriori, tstate)
 
 Lux.Training.compute_gradients(
     Optimization.AutoZygote(), loss_posteriori, train_data_posteriori, tstate)
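The compute_gradients call can also be pointed at the new loss to inspect gradients without taking an optimizer step; a minimal sketch assuming the names above (closure, θ_posteriori, st, train_data_posteriori) are in scope:

tstate = Lux.Training.TrainState(closure, θ_posteriori, st, OptimizationOptimisers.Adam(0.1))
grads, l, stats, tstate = Lux.Training.compute_gradients(
    Optimization.AutoZygote(), loss_posteriori_lux_style, train_data_posteriori, tstate)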
3 changes: 2 additions & 1 deletion src/train.jl
@@ -49,12 +49,13 @@ function train(model, ps, st, train_dataloader, loss_function;
     alg = OptimizationOptimisers.Adam(0.1), cpu::Bool = false, kwargs...)
     dev = cpu ? Lux.cpu_device() : Lux.gpu_device()
     tstate = Lux.Training.TrainState(model, ps, st, alg)
+    loss::Float32 = 0 # declared here so the final loss is in scope after the loop
     for epoch in 1:nepochs
         #(x, y) = train_dataloader()
         #x = dev(x)
         #y = dev(y)
         data = train_dataloader()
-        _, loss, stats, tstate = Lux.Training.single_train_step!(
+        _, loss, _, tstate = Lux.Training.single_train_step!(
             ad_type, loss_function, data, tstate)
     end
     loss, tstate
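The predeclared loss::Float32 = 0 is not a no-op: in Julia, a variable first assigned inside a for loop body is local to that loop, so without the declaration the final loss, tstate would throw an UndefVarError. A self-contained illustration (scope_demo is a hypothetical name):

function scope_demo(n)
    total::Float32 = 0   # remove this line and `total` below is undefined
    for i in 1:n
        total += i
    end
    return total         # visible here only because it was declared above
end

scope_demo(5)  # 15.0f0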
