-
Notifications
You must be signed in to change notification settings - Fork 0
/
stats.jl
75 lines (63 loc) · 2.74 KB
/
stats.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
module stats
using Plots
"""
    LearningStats()

Mutable record of a model's learning curves plus the best validation scores seen so far.

The zero-argument constructor starts every history vector empty (typed `Vector{Float64}`),
`bestValAcc` at `-Inf` and `bestValLoss` at `Inf` — sentinels that any finite value
improves on. Nothing in this module appends to the histories or updates the bests; that
is done by the caller.
"""
mutable struct LearningStats
    # Metric histories. The plotting functions in this module use the vector index as the
    # "Epochs" axis, so presumably one entry per epoch — TODO confirm with the trainer.
    trainAcc::Vector{Float64}
    trainLoss::Vector{Float64}
    valAcc::Vector{Float64}
    valLoss::Vector{Float64}
    testAcc::Vector{Float64}
    testLoss::Vector{Float64}
    # Best validation scores; initialised to -Inf / Inf, so presumably higher accuracy
    # and lower loss count as better.
    bestValAcc::Float64
    bestValLoss::Float64
end

# `Float64[]` creates correctly typed empty vectors directly instead of `Vector{Any}`
# literals that the default constructor would have to convert field by field.
LearningStats() = LearningStats(
    Float64[], Float64[], Float64[], Float64[], Float64[], Float64[], -Inf, Inf,
)
"""
    GANStats()

Mutable record of critic and generator loss histories for GAN training.

The zero-argument constructor starts both histories as empty `Vector{Float64}`s.
"""
mutable struct GANStats
    # Loss histories; `plotGANStats` uses the index as its "Epochs" axis, so presumably
    # one entry per epoch — TODO confirm with the training loop.
    cLoss::Vector{Float64}  # critic loss
    gLoss::Vector{Float64}  # generator loss
end

# Typed empty vectors avoid building `Vector{Any}` literals that must then be converted.
GANStats() = GANStats(Float64[], Float64[])
"""
    plotLearningStats(stats::LearningStats, name::String, isClassification::Bool)

Plot the learning curves stored in `stats` against an "Epochs" axis and save the figure
as `<name>.png`.

- `isClassification == true`: train loss, validation loss and validation accuracy.
- `isClassification == false`: train loss and validation loss only.

The plotted vectors are combined with `hcat`, so they must all have the same length
(otherwise `hcat` throws `DimensionMismatch`). Returns the value of `savefig`.
"""
function plotLearningStats(stats::LearningStats, name::String, isClassification::Bool)
    epochs = 1:length(stats.trainLoss)
    if isClassification
        series = hcat(stats.trainLoss, stats.valLoss, stats.valAcc)
        # Row (1×n) form: Plots assigns one entry per series.
        labels = ["Train Loss" "Val. Loss" "Val. Accuracy"]
        # Accuracy shares the axis with the losses, so label it as both.
        yaxis = "Loss/Accuracy"
    else
        # Previously read the non-existent field `stats.train`, which threw on every call.
        series = hcat(stats.trainLoss, stats.valLoss)
        labels = ["Train Loss" "Val. Loss"]
        yaxis = "Loss"
    end
    plt = plot(epochs, series, label = labels, xlabel = "Epochs", ylabel = yaxis)
    return savefig(plt, "$(name).png")
end
"""
    plotCompareModelsTrainAndVal(stats, modelNames, plotName = "model-comparison")

Overlay the training loss and validation accuracy of several models in one figure and
save it as `<plotName>.png`. Returns the value of `savefig`.

Colour identifies the metric (red = train loss, blue = validation accuracy). Line style
identifies the model (solid, dot, dash, dashdot, dashdotdot, then cycling), so the
figure stays readable for any number of models.

# Throws
- `ArgumentError` if `stats` is empty or `stats` and `modelNames` differ in length.
- `DimensionMismatch` (from `hcat`) if the histories have different lengths.
"""
function plotCompareModelsTrainAndVal(stats::Array{LearningStats}, modelNames::Array{String}, plotName::String = "model-comparison")
    isempty(stats) && throw(ArgumentError("need at least one model to compare"))
    length(stats) == length(modelNames) || throw(ArgumentError(
        "got $(length(stats)) stats entries but $(length(modelNames)) model names"))

    # One line style per model; wraps around if there are more models than styles.
    linestyles = (:solid, :dot, :dash, :dashdot, :dashdotdot)

    plottables = Vector{Vector{Float64}}()
    labels = String[]
    styles = Symbol[]
    for (i, (stat, name)) in enumerate(zip(stats, modelNames))
        style = linestyles[mod1(i, length(linestyles))]
        push!(plottables, stat.trainLoss, stat.valAcc)
        push!(labels, "$(name) Train Loss", "$(name) Val. Accuracy")
        push!(styles, style, style)
    end
    # Each model contributes a (train loss, val. accuracy) pair -> (red, blue).
    colors = repeat([:red, :blue], length(stats))

    # Plots applies a 1×n row of attribute values one-per-series, hence `permutedims`.
    plt = plot(
        1:length(stats[1].trainLoss),
        reduce(hcat, plottables),
        label = permutedims(labels),
        linecolor = permutedims(colors),
        linestyle = permutedims(styles),
        xlabel = "Epochs", ylabel = "Loss/Accuracy",
    )
    return savefig(plt, "$(plotName).png")
end
"""
    plotCompareModelsTrain(stats, modelNames, plotName = "model-comparison")

Overlay the training-loss curves of several models in one figure and save it as
`<plotName>.png`. Returns the value of `savefig`.

All curves are blue. Line style identifies the model (solid, dot, dash, dashdot,
dashdotdot, then cycling), so any number of models can be compared.

# Throws
- `ArgumentError` if `stats` is empty or `stats` and `modelNames` differ in length.
- `DimensionMismatch` (from `hcat`) if the training histories have different lengths.
"""
function plotCompareModelsTrain(stats::Array{LearningStats}, modelNames::Array{String}, plotName::String = "model-comparison")
    isempty(stats) && throw(ArgumentError("need at least one model to compare"))
    length(stats) == length(modelNames) || throw(ArgumentError(
        "got $(length(stats)) stats entries but $(length(modelNames)) model names"))

    # One line style per model; wraps around if there are more models than styles.
    linestyles = (:solid, :dot, :dash, :dashdot, :dashdotdot)

    plottables = [stat.trainLoss for stat in stats]
    labels = ["$(name) Train Loss" for name in modelNames]
    styles = [linestyles[mod1(i, length(linestyles))] for i in eachindex(plottables)]

    # Row (1×n) attribute values are applied one-per-series by Plots; a scalar colour
    # applies to every series.
    plt = plot(
        1:length(stats[1].trainLoss),
        reduce(hcat, plottables),
        label = permutedims(labels),
        linecolor = :blue,
        linestyle = permutedims(styles),
        xlabel = "Epochs", ylabel = "Loss",
    )
    return savefig(plt, "$(plotName).png")
end
"""
    plotCompareModels(stats, modelNames, plotName = "model-comparison"; trainOnly = false)

Compare several models' learning curves in a single saved figure.

Forwards all positional arguments to `plotCompareModelsTrain` when `trainOnly` is true,
and to `plotCompareModelsTrainAndVal` otherwise. Returns whatever the chosen function
returns.
"""
function plotCompareModels(stats::Array{LearningStats}, modelNames::Array{String}, plotName::String = "model-comparison"; trainOnly = false)
    plotter = trainOnly ? plotCompareModelsTrain : plotCompareModelsTrainAndVal
    return plotter(stats, modelNames, plotName)
end
"""
    plotGANStats(stats::GANStats, name::String)

Plot the critic and generator loss histories against an "Epochs" axis and save the
figure as `<name>.png`. Returns the value of `savefig`.

Both histories are combined with `hcat`, so they must have the same length (otherwise
`hcat` throws `DimensionMismatch`).
"""
function plotGANStats(stats::GANStats, name::String)
    plt = plot(
        1:length(stats.cLoss),
        hcat(stats.cLoss, stats.gLoss),
        # Row (1×n) form: Plots assigns one label per series.
        label = ["Critic Loss" "Generator Loss"],
        xlabel = "Epochs", ylabel = "Loss",
    )
    return savefig(plt, "$(name).png")
end
export LearningStats, GANStats, plotLearningStats, plotCompareModels, plotGANStats
end # module stats