Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Following up on PR 1296, implement both APIs for search_sorted. #1315

Merged
merged 1 commit into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 64 additions & 35 deletions lib/prelude.dx
Original file line number Diff line number Diff line change
Expand Up @@ -591,34 +591,6 @@ def i_to_n(x:Int) -> Maybe Nat =
then Nothing
else Just $ unsafe_i_to_n x

'## Fencepost index sets

struct Post(segment:Type) =
val : Nat

instance Ix(Post segment) given (segment|Ix)
def size'() = size segment + 1
def ordinal(i) = i.val
def unsafe_from_ordinal(i) = Post(i)

def left_post(i:n) -> Post n given (n|Ix) =
unsafe_from_ordinal(n=Post n, ordinal i)

def right_post(i:n) -> Post n given (n|Ix) =
unsafe_from_ordinal(n=Post n, ordinal i + 1)

interface NonEmpty(n|Ix)
first_ix : n

def last_ix() ->> n given (n|NonEmpty) =
unsafe_from_ordinal(unsafe_i_to_n(n_to_i(size n) - 1))

instance NonEmpty(Post n) given (n|Ix)
first_ix = unsafe_from_ordinal(n=Post n, 0)

instance NonEmpty(())
first_ix = unsafe_from_ordinal(0)

'### Monoid
A [monoid](https://en.wikipedia.org/wiki/Monoid) is a set of things with an associative binary operator and an identity element.
This is a very useful and general class of things.
Expand Down Expand Up @@ -901,6 +873,12 @@ instance Ix(Maybe a) given (a|Ix)
False -> Just $ unsafe_from_ordinal o
True -> Nothing

-- Index sets that provide a distinguished element `first_ix`.
-- Every instance visible here defines `first_ix` as the element with ordinal 0.
interface NonEmpty(n|Ix)
first_ix : n

-- The unit index set: `first_ix` is the element with ordinal 0.
instance NonEmpty(())
first_ix = unsafe_from_ordinal(0)

-- `Bool` as an index set: `first_ix` is the element with ordinal 0.
instance NonEmpty(Bool)
first_ix = unsafe_from_ordinal 0

Expand All @@ -918,6 +896,40 @@ instance NonEmpty(Either(a,b)) given (a|NonEmpty, b|Ix)
-- `Maybe a` only requires `a|Ix` (not `a|NonEmpty`); `first_ix` is the element
-- with ordinal 0.
instance NonEmpty(Maybe a) given (a|Ix)
first_ix = unsafe_from_ordinal 0

'## Fencepost index sets

-- A fencepost over the index set `segment`. The only stored data is `val`,
-- which the `Ix` instance uses directly as the post's ordinal.
struct Post(segment:Type) =
val : Nat

-- `Post segment` is an index set with exactly one more element than `segment`.
instance Ix(Post segment) given (segment|Ix)
-- One more post than there are segments.
def size'() = size segment + 1
-- The ordinal is the stored `val` field.
def ordinal(i) = i.val
-- Wraps the given ordinal without checking that it is within `size'()`.
def unsafe_from_ordinal(i) = Post(i)

-- Returns the post whose ordinal equals the ordinal of `i`.
-- Since `ordinal i < size n` and `Post n` has `size n + 1` elements, the
-- result is always a valid `Post n`.
def left_post(i:n) -> Post n given (n|Ix) =
unsafe_from_ordinal(n=Post n, ordinal i)

-- Returns the post whose ordinal is `ordinal i + 1`.
-- Since `ordinal i + 1 <= size n` and `Post n` has `size n + 1` elements, the
-- result is always a valid `Post n`.
def right_post(i:n) -> Post n given (n|Ix) =
unsafe_from_ordinal(n=Post n, ordinal i + 1)

-- Inverse of `right_post`: returns `Just i` where `ordinal i == ordinal p - 1`,
-- or `Nothing` when `p` has ordinal 0 (there is no element of `n` before it).
def left_fence(p:Post n) -> Maybe n given (n|Ix) =
ix = ordinal p
if ix == 0
then Nothing
-- `ix > 0` in this branch, so the subtraction cannot go below zero.
-- NOTE(review): `-|` is presumably the Nat subtraction operator -- confirm.
else Just $ unsafe_from_ordinal $ ix -| 1

-- Inverse of `left_post`: returns `Just i` where `ordinal i == ordinal p`,
-- or `Nothing` when `p` has ordinal `size n` (there is no element of `n`
-- after it).
def right_fence(p:Post n) -> Maybe n given (n|Ix) =
ix = ordinal p
if ix == size n
then Nothing
-- `ix < size n` in this branch (a `Post n` ordinal is at most `size n`).
else Just $ unsafe_from_ordinal ix

-- Returns the element of `n` with ordinal `size n - 1`.
-- The subtraction is done on `Int` and converted back with `unsafe_i_to_n`.
-- NOTE(review): this relies on `size n >= 1`, which is presumably what the
-- `NonEmpty` constraint is meant to guarantee -- confirm.
def last_ix() ->> n given (n|NonEmpty) =
unsafe_from_ordinal(unsafe_i_to_n(n_to_i(size n) - 1))

-- `Post n` has `size n + 1` elements, so it is non-empty for any `n|Ix`, even
-- an empty one. `first_ix` is the post with ordinal 0.
instance NonEmpty(Post n) given (n|Ix)
first_ix = unsafe_from_ordinal(n=Post n, 0)

def scan(
init:a,
body:(n, a)->(a,b)
Expand Down Expand Up @@ -2016,27 +2028,45 @@ instance Arbitrary(Fin n) given (n)

'### Searching

'returns the highest index `i` such that `xs.i <= x`
'Returns the bucket of `x` assuming boundaries `xs` as a `Post n`.
The boundaries must already be sorted, and are inclusive on the left.

'In other words, if there is an index `i` such that `xs.i <= x`,
returns the `right_post` of the highest such index; otherwise returns
`first_ix : Post n`, which is also the `left_post` of the minimum `i`.

def search_sorted(xs:n=>a, x:a) -> Maybe n given (n|Ix, a|Ord) =
'This is equivalent to the right-biased formulation: if an index `i`
exists such that `x < xs.i`, returns the `left_post` of the least such
`i`, otherwise returns `last_ix : Post n`, i.e., the `right_post` of
the maximum `i`.

def search_sorted(xs:n=>a, x:a) -> Post n given (n|Ix, a|Ord) =
if size n == 0
then Nothing
then first_ix
else if x < xs[from_ordinal 0]
then Nothing
then first_ix
else
low <- with_state(0)
high <- with_state(size n)
_ <- iter
numLeft = n_to_i (get high) - n_to_i (get low)
if numLeft == 1
then Done $ Just $ from_ordinal $ get low
then Done $ right_post $ from_ordinal $ get low
else
centerIx = get low + unsafe_i_to_n (numLeft `idiv` 2)
if x < xs[from_ordinal centerIx]
then high := centerIx
else low := centerIx
Continue

'If `i` exists such that `xs.i == x`, returns `Just` of the largest
such `i`, otherwise returns `Nothing`.

-- Exact-match lookup built on `search_sorted`.
-- `left_fence` turns the returned post back into an index of `xs`, or into
-- `Nothing` when the post has ordinal 0. The candidate index is accepted only
-- if the element stored there equals `x`.
def search_sorted_exact(xs:n=>a, x:a) -> Maybe n given (n|Ix, a|Ord) =
case left_fence(search_sorted(xs, x)) of
Just i -> if xs[i] == x then Just i else Nothing
Nothing -> Nothing

'### min / max etc

def min_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a) = select(f x < f y, x, y)
Expand Down Expand Up @@ -2318,8 +2348,7 @@ def lines(source:String) -> List String =
-- cdf should include 0.0 but not 1.0
def categorical_from_cdf(cdf: n=>Float, key: Key) -> n given (n|Ix) =
r = rand key
case search_sorted(cdf, r) of
Just(i) -> i
from_just $ left_fence $ search_sorted(cdf, r)

def normalize_pdf(xs: d=>Float) -> d=>Float given (d|Ix) = xs / sum xs

Expand Down
4 changes: 2 additions & 2 deletions lib/set.dx
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def set_intersect(
UnsafeAsSet(nx, xs) = sx
UnsafeAsSet(ny, ys) = sy
-- This could be done in O(nx + ny) instead of O(nx log ny).
isInYs = \x. case search_sorted ys x of
isInYs = \x. case search_sorted_exact ys x of
Just x -> True
Nothing -> False
AsList(n', intersection) = filter xs isInYs
Expand All @@ -100,7 +100,7 @@ struct Element(set:(Set a)) given (a|Ord) =
-- type), but maybe it's easier to read if it's explicit.
def member(x:a, set:(Set a)) -> Maybe (Element set) given (a|Ord) =
UnsafeAsSet(_, elts) = set
case search_sorted elts x of
case search_sorted_exact elts x of
Just n -> Just $ Element(ordinal n)
Nothing -> Nothing

Expand Down
2 changes: 1 addition & 1 deletion lib/stats.dx
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ instance OrderedDist(Binomial, Nat, Float)
lpdf = for i:(Fin tp1). ln $ density d (ordinal i)
cdf = cdf_for_categorical lpdf
mi = search_sorted cdf q
ordinal $ from_just mi
ordinal $ from_just $ left_fence mi


'### Exponential distribution
Expand Down
6 changes: 6 additions & 0 deletions tests/sort-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ import sort
:p is_sorted $ sort [9, 3, 7, 4, 6, 1, 9, 1, 9, -1, 10, 10, 100, 0]
> True

:p
xs = [1,2,4]
for i:(Fin 6).
search_sorted_exact(xs, ordinal i)
> [Nothing, (Just 0), (Just 1), Nothing, (Just 2), Nothing]

'### Lexical Sorting Tests

:p "aaa" < "bbb"
Expand Down