Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC/WIP: Conditional 1-D Generators #15023

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ unsafe_convert{T}(::Type{T}, x::T) = x
(::Type{Array{T}}){T}(m::Int, n::Int, o::Int) = Array{T,3}(m, n, o)

# TODO: possibly turn these into deprecations
Array{T,N}(::Type{T}, d::NTuple{N,Int}) = Array{T}(d)
Array{T}(::Type{T}, d::Int...) = Array{T}(d)
Array{T}(::Type{T}, m::Int) = Array{T,1}(m)
Array{T}(::Type{T}, m::Int,n::Int) = Array{T,2}(m,n)
Expand Down
25 changes: 25 additions & 0 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,28 @@ const (:) = Colon()
# For passing constants through type inference
immutable Val{T}
end

immutable Generator{F,I,C}
f::F
iter::I
cond::C
end

Generator{F,I}(f::F,iter::I) = Generator(f,iter,x->true)

start(g::Generator) = start(g.iter)
done(g::Generator, s) = done(g.iter, s)
function next(g::Generator, s)
v, s2 = next(g.iter, s)
v2 = g.f(v)
while !g.cond(v2) && !done(g.iter, s2)
v, s2 = next(g.iter, s2)
v2 = g.f(v)
end
v2, s2
end

function collect(g::Generator)
result = map(g.f, g.iter)
filter!(g.cond, result)
end
52 changes: 52 additions & 0 deletions base/iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,55 @@ eltype{I1,I2}(::Type{Prod{I1,I2}}) = tuple_type_cons(eltype(I1), eltype(I2))
x = prod_next(p, st)
((x[1][1],x[1][2]...), x[2])
end

immutable GeneratorND{F,I<:AbstractProdIterator}
f::F
iter::I

(::Type{GeneratorND}){F}(f::F, iters...) = (P = product(iters...); new{F,typeof(P)}(f, P))
end

start(g::GeneratorND) = start(g.iter)
done(g::GeneratorND, s) = done(g.iter, s)
function next(g::GeneratorND, s)
v, s2 = next(g.iter, s)
g.f(v...), s2
end

_size(p::Prod2) = (length(p.a), length(p.b))
_size(p::Prod) = (length(p.a), _size(p.b)...)

size(g::GeneratorND) = _size(g.iter)

function collect(g::GeneratorND)
sz = size(g)
if prod(sz) == 0
return Array(Union{}, sz)
end
st = start(g.iter)
A1, st = next(g.iter, st)
first = g.f(A1...)
dest = Array(typeof(first), sz)
dest[1] = first
return map_to!(xs->g.f(xs...), 2, st, dest, g.iter)
end

# special case for 2d
function collect{F,P<:Prod2}(g::GeneratorND{F,P})
f = g.f
a = g.iter.a
b = g.iter.b
sz = size(g)
if prod(sz) == 0
return Array(Union{}, sz)
end
fst = f(first(a), first(b)) # TODO: don't recompute this in the loop
dest = Array(typeof(fst), sz)
for j in b
for i in a
val = f(i, j) # TODO: handle type changes
@inbounds dest[i, j] = val
end
end
return dest
end
49 changes: 35 additions & 14 deletions src/julia-parser.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1397,22 +1397,23 @@
(define (parse-comma-separated-assignments s)
(parse-comma-separated s parse-eq*))

(define (parse-iteration-spec s)
(let ((r (parse-eq* s)))
(cond ((and (pair? r) (eq? (car r) '=)) r)
((eq? r ':) r)
((and (length= r 4) (eq? (car r) 'comparison)
(or (eq? (caddr r) 'in) (eq? (caddr r) '∈)))
`(= ,(cadr r) ,(cadddr r)))
(else
(error "invalid iteration specification")))))

; as above, but allows both "i=r" and "i in r"
(define (parse-comma-separated-iters s)
(let loop ((ranges '()))
(let ((r (parse-eq* s)))
(let ((r (cond ((and (pair? r) (eq? (car r) '=))
r)
((eq? r ':)
r)
((and (length= r 4) (eq? (car r) 'comparison)
(or (eq? (caddr r) 'in) (eq? (caddr r) '∈)))
`(= ,(cadr r) ,(cadddr r)))
(else
(error "invalid iteration specification")))))
(case (peek-token s)
((#\,) (take-token s) (loop (cons r ranges)))
(else (reverse! (cons r ranges))))))))
(let ((r (parse-iteration-spec s)))
(case (peek-token s)
((#\,) (take-token s) (loop (cons r ranges)))
(else (reverse! (cons r ranges)))))))

(define (parse-space-separated-exprs s)
(with-space-sensitive
Expand Down Expand Up @@ -1469,6 +1470,13 @@
(begin (take-token s) (loop (cons nxt lst))))
((eqv? c #\;) (loop (cons nxt lst)))
((equal? c closer) (loop (cons nxt lst)))
((eq? c 'for)
(take-token s)
(let ((gen (parse-generator s nxt #f)))
(if (eqv? (require-token s) #\,)
(take-token s))
(loop (cons gen lst))))
((eq? c 'for) (take-token s) (parse-generator s t closer))
;; newline character isn't detectable here
#;((eqv? c #\newline)
(error "unexpected line break in argument list"))
Expand Down Expand Up @@ -1513,7 +1521,7 @@
(define (parse-comprehension s first closer)
(let ((r (parse-comma-separated-iters s)))
(if (not (eqv? (require-token s) closer))
(error (string "expected " closer))
(error (string "expected \"" closer "\""))
(take-token s))
`(comprehension ,first ,@r)))

Expand All @@ -1523,6 +1531,12 @@
`(dict_comprehension ,@(cdr c))
(error "invalid dict comprehension"))))

(define (parse-generator s first allow-comma)
(let ((r (if allow-comma
(parse-comma-separated-iters s)
(list (parse-iteration-spec s)))))
`(generator ,first ,@r)))

(define (parse-matrix s first closer gotnewline)
(define (fix head v) (cons head (reverse v)))
(define (update-outer v outer)
Expand Down Expand Up @@ -1951,6 +1965,13 @@
`(tuple ,ex)
;; value in parentheses (x)
ex))
((eq? t 'for)
(take-token s)
(let ((gen (parse-generator s ex #t)))
(if (eqv? (require-token s) #\) )
(take-token s)
(error "expected \")\""))
gen))
(else
;; tuple (x,) (x,y) (x...) etc.
(if (eqv? t #\, )
Expand Down
17 changes: 17 additions & 0 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1846,6 +1846,23 @@
(lower-ccall name RT (cdr argtypes) args))))
e))

'generator
(lambda (e)
(let ((expr (cadr e))
(vars (map cadr (cddr e)))
(ranges (map caddr (cddr e))))
(let* ((names (map (lambda (v) (if (symbol? v) v (gensy))) vars))
(stmts (apply append
(map (lambda (v arg) (if (symbol? v)
'()
`((= ,v ,arg))))
vars names))))
(expand-forms
(expand-binding-forms
`(call (top ,(if (length> ranges 1) 'GeneratorND 'Generator))
(-> (tuple ,@names) (block ,@stmts ,expr))
,@ranges))))))

'comprehension
(lambda (e)
(expand-forms (lower-comprehension #f (cadr e) (cddr e))))
Expand Down