Skip to content

Commit

Permalink
[LinearAlgebra] Support more env variables to set OpenBLAS threads (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
giordano committed Jul 21, 2022
1 parent d75843d commit 6009ae9
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -555,14 +555,29 @@ function versioninfo(io::IO=stdout)
"JULIA_NUM_THREADS",
"MKL_DYNAMIC",
"MKL_NUM_THREADS",
"OPENBLAS_NUM_THREADS",
# OpenBLAS has a hierarchy of environment variables for setting the
# number of threads, see
# https://github.com/xianyi/OpenBLAS/blob/c43ec53bdd00d9423fc609d7b7ecb35e7bf41b85/README.md#setting-the-number-of-threads-using-environment-variables
("OPENBLAS_NUM_THREADS", "GOTO_NUM_THREADS", "OMP_NUM_THREADS"),
]
printed_at_least_one_env_var = false
print_var(io, indent, name) = println(io, indent, name, " = ", ENV[name])
for name in env_var_names
if haskey(ENV, name)
value = ENV[name]
println(io, indent, name, " = ", value)
printed_at_least_one_env_var = true
if name isa Tuple
# If `name` is a Tuple, then find the first environment which is
# defined, and disregard the following ones.
for nm in name
if haskey(ENV, nm)
print_var(io, indent, nm)
printed_at_least_one_env_var = true
break
end
end
else
if haskey(ENV, name)
print_var(io, indent, name)
printed_at_least_one_env_var = true
end
end
end
if !printed_at_least_one_env_var
Expand All @@ -581,7 +596,8 @@ function __init__()
# register a hook to disable BLAS threading
Base.at_disable_library_threading(() -> BLAS.set_num_threads(1))

if !haskey(ENV, "OPENBLAS_NUM_THREADS")
# https://github.com/xianyi/OpenBLAS/blob/c43ec53bdd00d9423fc609d7b7ecb35e7bf41b85/README.md#setting-the-number-of-threads-using-environment-variables
if !haskey(ENV, "OPENBLAS_NUM_THREADS") && !haskey(ENV, "GOTO_NUM_THREADS") && !haskey(ENV, "OMP_NUM_THREADS")
@static if Sys.isapple() && Base.BinaryPlatforms.arch(Base.BinaryPlatforms.HostPlatform()) == "aarch64"
BLAS.set_num_threads(max(1, Sys.CPU_THREADS))
else
Expand Down

0 comments on commit 6009ae9

Please sign in to comment.