Skip to content

Commit

Permalink
Merge pull request #304 from MilesCranmer/more-parallelism
Browse files Browse the repository at this point in the history
40% speedup (for default settings) via more parallelism inside workers
  • Loading branch information
MilesCranmer committed Apr 25, 2024
2 parents cc136e7 + a34f40a commit c32dfb5
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 23 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ jobs:
- "1"
os:
- ubuntu-latest
- macOS-latest
include:
- os: windows-latest
julia-version: "1"
- os: macOS-latest
julia-version: "1"
- os: ubuntu-latest
julia-version: "~1.11.0-0"

Expand Down
8 changes: 6 additions & 2 deletions src/Population.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,12 @@ function Population(
)
end

function Base.copy(pop::P)::P where {P<:Population}
return Population([copy(pm) for pm in pop.members])
function Base.copy(pop::P)::P where {T,L,N,P<:Population{T,L,N}}
copied_members = Vector{PopMember{T,L,N}}(undef, pop.n)
Threads.@threads for i in 1:(pop.n)
copied_members[i] = copy(pop.members[i])
end
return Population(copied_members)
end

# Sample random members of the population, and make a new one
Expand Down
46 changes: 44 additions & 2 deletions src/SearchUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Printf: @printf, @sprintf
using Distributed
using StatsBase: mean

using DynamicExpressions: AbstractExpressionNode
using DynamicExpressions: AbstractExpressionNode, string_tree
using ..UtilsModule: subscriptify
using ..CoreModule: Dataset, Options, MAX_DEGREE, RecordType
using ..ComplexityModule: compute_complexity
Expand Down Expand Up @@ -270,7 +270,7 @@ function get_load_string(; head_node_occupation::Float64, parallelism=:serial)
parallelism == :serial && return ""
out = @sprintf("Head worker occupation: %.1f%%", head_node_occupation * 100)

raise_usage_warning = head_node_occupation > 0.2
raise_usage_warning = head_node_occupation > 0.4
if raise_usage_warning
out *= "."
out *= " This is high, and will prevent efficient resource usage."
Expand Down Expand Up @@ -405,6 +405,48 @@ Base.@kwdef struct SearchState{
record::Base.RefValue{RecordType}
end

function save_to_file(
dominating, nout::Integer, j::Integer, dataset::Dataset{T,L}, options::Options
) where {T,L}
output_file = options.output_file
if nout > 1
output_file = output_file * ".out$j"
end
dominating_n = length(dominating)

complexities = Vector{Int}(undef, dominating_n)
losses = Vector{L}(undef, dominating_n)
strings = Vector{String}(undef, dominating_n)

Threads.@threads for i in 1:dominating_n
member = dominating[i]
complexities[i] = compute_complexity(member, options)
losses[i] = member.loss
strings[i] = string_tree(
member.tree, options; variable_names=dataset.variable_names
)
end

s = let
tmp_io = IOBuffer()

println(tmp_io, "Complexity,Loss,Equation")
for i in 1:dominating_n
println(tmp_io, "$(complexities[i]),$(losses[i]),\"$(strings[i])\"")
end

String(take!(tmp_io))
end

# Write file twice in case exit in middle of filewrite
for out_file in (output_file, output_file * ".bkup")
open(out_file, "w") do io
write(io, s)
end
end
return nothing
end

"""
get_cur_maxsize(; options, total_cycles, cycles_remaining)
Expand Down
3 changes: 2 additions & 1 deletion src/SingleIteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using DynamicExpressions:
string_tree,
simplify_tree!,
combine_operators
using ..UtilsModule: @threads_if
using ..CoreModule: Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE
using ..ComplexityModule: compute_complexity
using ..PopMemberModule: PopMember, generate_reference
Expand Down Expand Up @@ -108,7 +109,7 @@ function optimize_and_simplify_population(
)::Tuple{P,Float64} where {T,L,D<:Dataset{T,L},P<:Population{T,L}}
array_num_evals = zeros(Float64, pop.n)
do_optimization = rand(pop.n) .< options.optimizer_probability
for j in 1:(pop.n)
@threads_if !(options.deterministic) for j in 1:(pop.n)
if options.should_simplify
tree = pop.members[j].tree
tree = simplify_tree!(tree, options.operators)
Expand Down
19 changes: 2 additions & 17 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ using .SearchUtilsModule:
load_saved_hall_of_fame,
load_saved_population,
construct_datasets,
save_to_file,
get_cur_maxsize,
update_hall_of_fame!

Expand Down Expand Up @@ -916,23 +917,7 @@ function _main_search_loop!(
dominating = calculate_pareto_frontier(state.halls_of_fame[j])

if options.save_to_file
output_file = options.output_file
if nout > 1
output_file = output_file * ".out$j"
end
# Write file twice in case exit in middle of filewrite
for out_file in (output_file, output_file * ".bkup")
open(out_file, "w") do io
println(io, "Complexity,Loss,Equation")
for member in dominating
println(
io,
"$(compute_complexity(member, options)),$(member.loss),\"" *
"$(string_tree(member.tree, options, variable_names=dataset.variable_names))\"",
)
end
end
end
save_to_file(dominating, nout, j, dataset, options)
end
###################################################################
# Migration #######################################################
Expand Down
10 changes: 10 additions & 0 deletions src/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,16 @@ function poisson_sample(λ::T) where {T}
return k - 1
end

macro threads_if(flag, ex)
return quote
if $flag
Threads.@threads $ex
else
$ex
end
end |> esc
end

"""
@save_kwargs variable function ... end
Expand Down

0 comments on commit c32dfb5

Please sign in to comment.