Skip to content

Commit

Permalink
Make partitions and (semi)standard_tableaux return iterators
Browse files Browse the repository at this point in the history
... the quick and dirty way to have a stable API
  • Loading branch information
joschmitt committed Feb 14, 2024
1 parent 695603e commit 01c7419
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 88 deletions.
84 changes: 47 additions & 37 deletions src/Combinatorics/EnumerativeCombinatorics/partitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,24 +151,29 @@ end
@doc raw"""
partitions(n::IntegerUnion)
Return a list of all partitions of a non-negative integer `n`, produced in
lexicographically *descending* order.
Return an iterator over all partitions of a non-negative integer `n`, produced
in lexicographically *descending* order.
Using a smaller integer type for `n` (e.g. `Int8`) may increase performance.
The algorithm used is "Algorithm ZS1" by [ZS98](@cite). This algorithm is also
discussed in [Knu11](@cite), Algorithm P (page 392).
# Examples
```jldoctest
julia> partitions(4)
julia> p = partitions(4);
julia> first(p)
[4]
julia> collect(p)
5-element Vector{Partition{Int64}}:
[4]
[3, 1]
[2, 2]
[2, 1, 1]
[1, 1, 1, 1]
julia> partitions(Int8(4)) # using less memory
julia> collect(partitions(Int8(4))) # using less memory
5-element Vector{Partition{Int8}}:
Int8[4]
Int8[3, 1]
Expand All @@ -184,9 +189,9 @@ function partitions(n::T) where T <: IntegerUnion

# Some trivial cases
if n == 0
return Partition{T}[ partition(T[], check = false) ]
return (p for p in Partition{T}[ partition(T[], check = false) ])
elseif n == 1
return Partition{T}[ partition(T[1], check = false) ]
return (p for p in Partition{T}[ partition(T[1], check = false) ])
end

# Now, the algorithm starts
Expand Down Expand Up @@ -222,7 +227,7 @@ function partitions(n::T) where T <: IntegerUnion
end
push!(P, partition(d[1:k], check = false))
end
return P
return (p for p in P)
end

################################################################################
Expand Down Expand Up @@ -287,9 +292,10 @@ end
partitions(m::IntegerUnion, n::IntegerUnion; only_distinct_parts::Bool = false)
partitions(m::IntegerUnion, n::IntegerUnion, l1::IntegerUnion, l2::IntegerUnion; only_distinct_parts::Bool = false)
Return all partitions of a non-negative integer `m` into `n >= 0` parts.
Optionally, a lower bound `l1 >= 0` and an upper bound `l2` for the parts can be
supplied. In this case, the partitions are produced in *decreasing* order.
Return an iterator over all partitions of a non-negative integer `m` into
`n >= 0` parts. Optionally, a lower bound `l1 >= 0` and an upper bound `l2` for
the parts can be supplied. In this case, the partitions are produced in
*decreasing* order.
There are two choices for the parameter `only_distinct_parts`:
* `false`: no further restriction (*default*);
Expand All @@ -300,7 +306,7 @@ The implemented algorithm is "parta" in [RJ76](@cite).
# Examples
All partitions of 7 into 3 parts:
```jldoctest
julia> partitions(7, 3)
julia> collect(partitions(7, 3))
4-element Vector{Partition{Int64}}:
[5, 1, 1]
[4, 2, 1]
Expand All @@ -309,15 +315,15 @@ julia> partitions(7, 3)
```
All partitions of 7 into 3 parts where all parts are between 1 and 4:
```jldoctest
julia> partitions(7, 3, 1, 4)
julia> collect(partitions(7, 3, 1, 4))
3-element Vector{Partition{Int64}}:
[4, 2, 1]
[3, 3, 1]
[3, 2, 2]
```
Same as above but requiring all parts to be distinct:
```jldoctest
julia> partitions(7, 3, 1, 4; only_distinct_parts = true)
julia> collect(partitions(7, 3, 1, 4; only_distinct_parts = true))
1-element Vector{Partition{Int64}}:
[4, 2, 1]
```
Expand All @@ -339,15 +345,15 @@ function partitions(m::T, n::IntegerUnion, l1::IntegerUnion, l2::IntegerUnion; o

# Some trivial cases
if m == 0 && n == 0
return Partition{T}[ partition(T[], check = false) ]
return (p for p in Partition{T}[ partition(T[], check = false) ])
end

if n == 0 || n > m
return Partition{T}[]
return (p for p in Partition{T}[])
end

if l2 < l1
return Partition{T}[]
return (p for p in Partition{T}[])
end

# If l1 == 0 the algorithm parta will actually create lists containing the
Expand Down Expand Up @@ -411,7 +417,7 @@ function partitions(m::T, n::IntegerUnion, l1::IntegerUnion, l2::IntegerUnion; o
end
end

return P
return (p for p in P)
end

function partitions(m::T, n::IntegerUnion; only_distinct_parts::Bool = false) where T <: IntegerUnion
Expand All @@ -420,14 +426,14 @@ function partitions(m::T, n::IntegerUnion; only_distinct_parts::Bool = false) wh

# Special cases
if m == n
return [ partition(T[ 1 for i in 1:m], check = false) ]
return (p for p in [ partition(T[ 1 for i in 1:m], check = false) ])
elseif m < n || n == 0
return Partition{T}[]
return (p for p in Partition{T}[])
elseif n == 1
return [ partition(T[m], check = false) ]
return (p for p in [ partition(T[m], check = false) ])
end

return partitions(m, n, 1, m; only_distinct_parts = only_distinct_parts)
return (p for p in partitions(m, n, 1, m; only_distinct_parts = only_distinct_parts))
end

function partitions(m::T, n::IntegerUnion, v::Vector{T}, mu::Vector{S}) where {T <: IntegerUnion, S <: IntegerUnion}
Expand All @@ -451,15 +457,17 @@ function partitions(m::T, n::IntegerUnion, v::Vector{T}, mu::Vector{S}) where {T

# Special cases
if n == 0
# TODO: I don't understand this distinction here
# (it also makes the function tabe instable)
if m == 0
return [ Partition{T}[] ]
return (p for p in [ Partition{T}[] ])
else
return Partition{T}[]
return (p for p in Partition{T}[])
end
end

if isempty(mu)
return Partition{T}[]
return (p for p in Partition{T}[])
end

#This will be the list of all partitions found.
Expand Down Expand Up @@ -523,7 +531,7 @@ function partitions(m::T, n::IntegerUnion, v::Vector{T}, mu::Vector{S}) where {T

# This is a necessary condition for existence of a partition
if m < 0 || m > n * (lr - ll)
return P #goto b3
return (p for p in P) #goto b3
end

# The following is a condition for when only a single partition
Expand All @@ -534,7 +542,7 @@ function partitions(m::T, n::IntegerUnion, v::Vector{T}, mu::Vector{S}) where {T
# Noticed on Mar 23, 2023.
if m == 0 && x[1] != 0
push!(P, partition(copy(x), check = false))
return P
return (p for p in P)
end

# Now, the actual algorithm starts
Expand Down Expand Up @@ -625,7 +633,7 @@ function partitions(m::T, n::IntegerUnion, v::Vector{T}, mu::Vector{S}) where {T
end #if gotob2
end #while

return P
return (p for p in P)
end

function partitions(m::T, v::Vector{T}, mu::Vector{S}) where {T <: IntegerUnion, S <: IntegerUnion}
Expand All @@ -645,11 +653,12 @@ function partitions(m::T, v::Vector{T}, mu::Vector{S}) where {T <: IntegerUnion,
res = Partition{T}[]

if isempty(v)
return res
return (p for p in res)
end

if m == 0
return [ Partition{T}[] ]
# TODO: I don't understand this return (and it is type instable)
return (p for p in [ Partition{T}[] ])
end

# We will loop over the number of parts.
Expand Down Expand Up @@ -680,16 +689,16 @@ function partitions(m::T, v::Vector{T}, mu::Vector{S}) where {T <: IntegerUnion,
append!(res, partitions(m, n, v, mu))
end

return res
return (p for p in res)
end

@doc raw"""
partitions(m::T, v::Vector{T}) where T <: IntegerUnion
partitions(m::T, v::Vector{T}, mu::Vector{<:IntegerUnion}) where T <: IntegerUnion
partitions(m::T, n::IntegerUnion, v::Vector{T}, mu::Vector{<:IntegerUnion}) where T <: IntegerUnion
Return all partitions of a non-negative integer `m` where each part is an element
in the vector `v` of positive integers.
Return an iterator over all partitions of a non-negative integer `m` where each
part is an element in the vector `v` of positive integers.
It is assumed that the entries in `v` are strictly increasing.
If the optional vector `mu` is supplied, then each `v[i]` occurs a maximum of
Expand All @@ -710,7 +719,7 @@ julia> length(partitions(100, [1, 2, 5, 10, 20, 50]))
All partitions of 100 where the parts are from {1, 2, 5, 10, 20, 50} and each
part is allowed to occur at most twice:
```jldoctest
julia> partitions(100, [1, 2, 5, 10, 20, 50], [2, 2, 2, 2, 2, 2])
julia> collect(partitions(100, [1, 2, 5, 10, 20, 50], [2, 2, 2, 2, 2, 2]))
6-element Vector{Partition{Int64}}:
[50, 50]
[50, 20, 20, 10]
Expand All @@ -722,7 +731,7 @@ julia> partitions(100, [1, 2, 5, 10, 20, 50], [2, 2, 2, 2, 2, 2])
The partitions of 100 into seven parts, where the parts are required to be
elements from {1, 2, 5, 10, 20, 50} and each part is allowed to occur at most twice.
```jldoctest
julia> partitions(100, 7, [1, 2, 5, 10, 20, 50], [2, 2, 2, 2, 2, 2])
julia> collect(partitions(100, 7, [1, 2, 5, 10, 20, 50], [2, 2, 2, 2, 2, 2]))
1-element Vector{Partition{Int64}}:
[50, 20, 20, 5, 2, 2, 1]
```
Expand All @@ -735,11 +744,12 @@ function partitions(m::T, v::Vector{T}) where T <: IntegerUnion
res = Partition{T}[]

if isempty(v)
return res
return (p for p in res)
end

if m == 0
return [ Partition{T}[] ]
# TODO: I don't understand this return (and it is type instable)
return (p for p in [ Partition{T}[] ])
end

# We will loop over the number of parts.
Expand All @@ -755,7 +765,7 @@ function partitions(m::T, v::Vector{T}) where T <: IntegerUnion
append!(res, partitions(m, n, v, mu))
end

return res
return (p for p in res)
end

################################################################################
Expand Down
37 changes: 18 additions & 19 deletions src/Combinatorics/EnumerativeCombinatorics/tableaux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ end
semistandard_tableaux(shape::Partition{T}, max_val::T = sum(shape)) where T <: Integer
semistandard_tableaux(shape::Vector{T}, max_val::T = sum(shape)) where T <: Integer
Return all semistandard Young tableaux of given shape `shape` and filling elements
bounded by `max_val`.
Return an iterator over all semistandard Young tableaux of given shape `shape`
and filling elements bounded by `max_val`.
By default, `max_val` is equal to the sum of the shape partition (the number of
boxes in the Young diagram).
Expand All @@ -355,10 +355,10 @@ function semistandard_tableaux(shape::Partition{T}, max_val::T = sum(shape)) whe
SST = Vector{YoungTableau{T}}()
len = length(shape)
if max_val < len
return SST
return (t for t in SST)
elseif len == 0
push!(SST, young_tableau(Vector{T}[], check = false))
return SST
return (t for t in SST)
end
tab = [Array{T}(fill(i, shape[i])) for i = 1:len]
m = len
Expand All @@ -377,7 +377,7 @@ function semistandard_tableaux(shape::Partition{T}, max_val::T = sum(shape)) whe
m -= 1
n = shape[m]
else
return SST
return (t for t in SST)
end
end

Expand Down Expand Up @@ -418,14 +418,14 @@ end
@doc raw"""
semistandard_tableaux(box_num::T, max_val::T = box_num) where T <: Integer
Return all semistandard Young tableaux consisting of `box_num` boxes and
filling elements bounded by `max_val`.
Return an iterator over all semistandard Young tableaux consisting of `box_num`
boxes and filling elements bounded by `max_val`.
"""
function semistandard_tableaux(box_num::T, max_val::T = box_num) where T <: Integer
@req box_num >= 0 "box_num >= 0 required"
SST = Vector{YoungTableau{T}}()
if max_val <= 0
return SST
return (t for t in SST)
end
shapes = partitions(box_num)

Expand All @@ -435,16 +435,15 @@ function semistandard_tableaux(box_num::T, max_val::T = box_num) where T <: Inte
end
end

return SST
return (t for t in SST)
end


@doc raw"""
semistandard_tableaux(s::Partition{T}, weight::Vector{T}) where T <: Integer
semistandard_tableaux(s::Vector{T}, weight::Vector{T}) where T <: Integer
Return all semistandard Young tableaux with shape `s` and given weight. This
requires that `sum(s) = sum(weight)`.
Return an iterator over all semistandard Young tableaux with shape `s` and given
weight. This requires that `sum(s) = sum(weight)`.
"""
function semistandard_tableaux(s::Vector{T}, weight::Vector{T}) where T <: Integer
n_max = sum(s)
Expand All @@ -453,7 +452,7 @@ function semistandard_tableaux(s::Vector{T}, weight::Vector{T}) where T <: Integ
tabs = Vector{YoungTableau}()
if isempty(s)
push!(tabs, young_tableau(Vector{Int}[], check = false))
return tabs
return (t for t in tabs)
end
ls = length(s)

Expand Down Expand Up @@ -540,7 +539,7 @@ function semistandard_tableaux(s::Vector{T}, weight::Vector{T}) where T <: Integ
end #rec_sst!()

rec_sst!(1)
return tabs
return (t for t in tabs)
end

function semistandard_tableaux(s::Partition{T}, weight::Partition{T}) where T <: Integer
Expand Down Expand Up @@ -620,13 +619,13 @@ end
standard_tableaux(s::Partition)
standard_tableaux(s::Vector{Integer})
Return all standard Young tableaux of a given shape `s`.
Return an iterator over all standard Young tableaux of a given shape `s`.
"""
function standard_tableaux(s::Partition)
tabs = Vector{YoungTableau}()
if isempty(s)
push!(tabs, young_tableau(Vector{Int}[], check = false))
return tabs
return (t for t in tabs)
end
n_max = sum(s)
ls = length(s)
Expand Down Expand Up @@ -669,7 +668,7 @@ function standard_tableaux(s::Partition)
end
end

return tabs
return (t for t in tabs)
end

function standard_tableaux(s::Vector{T}) where T <: Integer
Expand All @@ -679,15 +678,15 @@ end
@doc raw"""
standard_tableaux(n::Integer)
Return all standard Young tableaux with `n` boxes.
Return an iterator over all standard Young tableaux with `n` boxes.
"""
function standard_tableaux(n::Integer)
@req n >= 0 "n >= 0 required"
ST = Vector{YoungTableau}()
for s in partitions(n)
append!(ST, standard_tableaux(s))
end
return ST
return (t for t in ST)
end

################################################################################
Expand Down
Loading

0 comments on commit 01c7419

Please sign in to comment.