fix the train function: the `loss` output of the training step is not the loss function itself
v1kko committed Sep 5, 2024
1 parent 6bfb21b · commit 665ad2f
Showing 2 changed files with 4 additions and 4 deletions.
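For context on the commit message: the document's own call sites destructure `Lux.Training.single_train_step!` as `_, loss, stats, tstate = ...`, i.e. the second element of the returned tuple is the scalar loss value. The old `train` used the same name `loss` for its loss-function argument, so the first loop iteration rebound `loss` from a callable to a `Float64`, and the second iteration failed. A minimal standalone sketch of that shadowing bug, with a hypothetical `fake_step` standing in for `single_train_step!` (no Lux required):

```julia
# fake_step stands in for Lux.Training.single_train_step!, whose return is
# destructured above as (grads, loss, stats, tstate): the second element
# is a scalar loss value, not a function.
fake_step(ad, lf, data, tstate) = (nothing, lf(data), (;), tstate)

function buggy_train(loss; nepochs = 2)   # `loss` starts out as a function
    tstate = nothing
    for epoch in 1:nepochs
        # rebinds `loss` to the Float64 returned by the step, so the
        # next iteration passes a number where a callable is expected
        _, loss, stats, tstate = fake_step(:zygote, loss, 2.0, tstate)
    end
    loss, tstate
end

buggy_train(abs2)   # MethodError in epoch 2: a Float64 is not callable
```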
simulations/NavierStokes_2D/scripts/NeuralClosure+SciML.jl — 2 additions & 2 deletions

```diff
@@ -160,10 +160,10 @@ tstate = Lux.Training.TrainState(closure, θ, st, OptimizationOptimisers.Adam(0.
 _, loss, stats, tstate = Lux.Training.single_train_step!(
     Optimization.AutoZygote(), loss_lux_style, train_data, tstate)

-tstate = train(closure, θ, st, dataloader, loss_lux_style;
+train(closure, θ, st, dataloader, loss_lux_style;
     nepochs = 100, ad_type = Optimization.AutoZygote(),
     alg = OptimizationOptimisers.Adam(0.1), cpu = true)
 # still a problem with the train function


 # * A posteriori dataloader
 # indeed the ioarrays are not useful here, what a bummer! We should come up with a format that would be useful for both a-priori and a-posteriori training.
```
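Since the fixed `train` in src/train.jl (second diff below) ends by returning `loss, tstate`, the script could also capture both values instead of discarding them. A possible variant of the call above, sketched using only names already present in the script:

```julia
# sketch: capture the last-epoch loss and the trained state
final_loss, tstate = train(closure, θ, st, dataloader, loss_lux_style;
    nepochs = 100, ad_type = Optimization.AutoZygote(),
    alg = OptimizationOptimisers.Adam(0.1), cpu = true)
```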
src/train.jl — 2 additions & 2 deletions

```diff
@@ -44,7 +44,7 @@ function train(dataloaders,
     (; optstate, θ, callbackstate)
 end

-function train(model, ps, st, train_dataloader, loss;
+function train(model, ps, st, train_dataloader, loss_function;
         nepochs = 100, ad_type = Optimization.AutoZygote(),
         alg = OptimizationOptimisers.Adam(0.1), cpu::Bool = false, kwargs...)
     dev = cpu ? Lux.cpu_device() : Lux.gpu_device()
@@ -55,7 +55,7 @@ function train(model, ps, st, train_dataloader, loss;
         #y = dev(y)
         data = train_dataloader()
         _, loss, stats, tstate = Lux.Training.single_train_step!(
-            ad_type, loss, data, tstate)
+            ad_type, loss_function, data, tstate)
     end
     loss, tstate
 end
```
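With the argument renamed to `loss_function`, the `loss` bound by the destructuring inside the loop only ever holds the latest scalar, which is what makes the final `loss, tstate` return line well-defined. A standalone mirror of the corrected pattern, reusing the hypothetical `fake_step` stand-in from the sketch above:

```julia
fake_step(ad, lf, data, tstate) = (nothing, lf(data), (;), tstate)

function fixed_train(loss_function; nepochs = 3)
    tstate = nothing
    local loss
    for epoch in 1:nepochs
        # `loss_function` stays callable; `loss` carries the scalar out
        _, loss, stats, tstate = fake_step(:zygote, loss_function, 2.0, tstate)
    end
    loss, tstate
end

fixed_train(abs2)   # returns (4.0, nothing): no MethodError this time
```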
