Skip to content

Commit

Permalink
Following up on PR 1296, implement both APIs for search_sorted.
Browse files Browse the repository at this point in the history
Context: #1296 (comment)

The more basic API, still named `search_sorted`, returns a `Post n`.
The idea is that the elements of `xs` are fence sections, and we find
the position between them (inclusive on either end) where `x` falls in
the ordering.

In terms of this, we now define `search_sorted_exact` (better name?),
which returns a `Maybe n`, which is the index of an element of `xs`
that equals `x` exactly, or `Nothing` if such does not exist.

Also reorder the prelude slightly to try to both maintain semantic
groupings and respect name resolution dependencies.
  • Loading branch information
axch committed Jun 23, 2023
1 parent e446273 commit 6d23e46
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 38 deletions.
99 changes: 64 additions & 35 deletions lib/prelude.dx
Original file line number Diff line number Diff line change
Expand Up @@ -591,34 +591,6 @@ def i_to_n(x:Int) -> Maybe Nat =
then Nothing
else Just $ unsafe_i_to_n x

'## Fencepost index sets

struct Post(segment:Type) =
val : Nat

instance Ix(Post segment) given (segment|Ix)
def size'() = size segment + 1
def ordinal(i) = i.val
def unsafe_from_ordinal(i) = Post(i)

def left_post(i:n) -> Post n given (n|Ix) =
unsafe_from_ordinal(n=Post n, ordinal i)

def right_post(i:n) -> Post n given (n|Ix) =
unsafe_from_ordinal(n=Post n, ordinal i + 1)

interface NonEmpty(n|Ix)
first_ix : n

def last_ix() ->> n given (n|NonEmpty) =
unsafe_from_ordinal(unsafe_i_to_n(n_to_i(size n) - 1))

instance NonEmpty(Post n) given (n|Ix)
first_ix = unsafe_from_ordinal(n=Post n, 0)

instance NonEmpty(())
first_ix = unsafe_from_ordinal(0)

'### Monoid
A [monoid](https://en.wikipedia.org/wiki/Monoid) is a things that have an associative binary operator and an identity element.
This is a very useful and general calls of things.
Expand Down Expand Up @@ -901,6 +873,12 @@ instance Ix(Maybe a) given (a|Ix)
False -> Just $ unsafe_from_ordinal o
True -> Nothing

interface NonEmpty(n|Ix)
first_ix : n

instance NonEmpty(())
first_ix = unsafe_from_ordinal(0)

instance NonEmpty(Bool)
first_ix = unsafe_from_ordinal 0

Expand All @@ -918,6 +896,40 @@ instance NonEmpty(Either(a,b)) given (a|NonEmpty, b|Ix)
instance NonEmpty(Maybe a) given (a|Ix)
first_ix = unsafe_from_ordinal 0

'## Fencepost index sets

struct Post(segment:Type) =
val : Nat

instance Ix(Post segment) given (segment|Ix)
def size'() = size segment + 1
def ordinal(i) = i.val
def unsafe_from_ordinal(i) = Post(i)

def left_post(i:n) -> Post n given (n|Ix) =
unsafe_from_ordinal(n=Post n, ordinal i)

def right_post(i:n) -> Post n given (n|Ix) =
unsafe_from_ordinal(n=Post n, ordinal i + 1)

def left_fence(p:Post n) -> Maybe n given (n|Ix) =
ix = ordinal p
if ix == 0
then Nothing
else Just $ unsafe_from_ordinal $ ix -| 1

def right_fence(p:Post n) -> Maybe n given (n|Ix) =
ix = ordinal p
if ix == size n
then Nothing
else Just $ unsafe_from_ordinal ix

def last_ix() ->> n given (n|NonEmpty) =
unsafe_from_ordinal(unsafe_i_to_n(n_to_i(size n) - 1))

instance NonEmpty(Post n) given (n|Ix)
first_ix = unsafe_from_ordinal(n=Post n, 0)

def scan(
init:a,
body:(n, a)->(a,b)
Expand Down Expand Up @@ -2016,27 +2028,45 @@ instance Arbitrary(Fin n) given (n)

'### Searching

'returns the highest index `i` such that `xs.i <= x`
'Returns the bucket of `x` assuming boundaries `xs` as a `Post n`.
The boundaries must already be sorted, and are inclusive on the left.

'In other words, if there is an index `i` such that `xs.i <= x`,
returns the `right_post` of the highest such index; otherwise returns
`first_ix : Post n`, which is also the `left_post` of the minimum `i`.

def search_sorted(xs:n=>a, x:a) -> Maybe n given (n|Ix, a|Ord) =
'This is equivalent to the right-biased formulation: if an index `i`
exists such that `x < xs.i`, returns the `left_post` of the least such
`i`, otherwise returns `last_ix : Post n`, i.e., the `right_post` of
the maximum `i`.

def search_sorted(xs:n=>a, x:a) -> Post n given (n|Ix, a|Ord) =
if size n == 0
then Nothing
then first_ix
else if x < xs[from_ordinal 0]
then Nothing
then first_ix
else
low <- with_state(0)
high <- with_state(size n)
_ <- iter
numLeft = n_to_i (get high) - n_to_i (get low)
if numLeft == 1
then Done $ Just $ from_ordinal $ get low
then Done $ right_post $ from_ordinal $ get low
else
centerIx = get low + unsafe_i_to_n (numLeft `idiv` 2)
if x < xs[from_ordinal centerIx]
then high := centerIx
else low := centerIx
Continue

'If `i` exists such that `xs.i == x`, returns `Just` of the largest
such `i`, otherwise returns `Nothing`.

def search_sorted_exact(xs:n=>a, x:a) -> Maybe n given (n|Ix, a|Ord) =
case left_fence(search_sorted(xs, x)) of
Just i -> if xs[i] == x then Just i else Nothing
Nothing -> Nothing

'### min / max etc

def min_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a) = select(f x < f y, x, y)
Expand Down Expand Up @@ -2318,8 +2348,7 @@ def lines(source:String) -> List String =
-- cdf should include 0.0 but not 1.0
def categorical_from_cdf(cdf: n=>Float, key: Key) -> n given (n|Ix) =
r = rand key
case search_sorted(cdf, r) of
Just(i) -> i
from_just $ left_fence $ search_sorted(cdf, r)

def normalize_pdf(xs: d=>Float) -> d=>Float given (d|Ix) = xs / sum xs

Expand Down
4 changes: 2 additions & 2 deletions lib/set.dx
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def set_intersect(
UnsafeAsSet(nx, xs) = sx
UnsafeAsSet(ny, ys) = sy
-- This could be done in O(nx + ny) instead of O(nx log ny).
isInYs = \x. case search_sorted ys x of
isInYs = \x. case search_sorted_exact ys x of
Just x -> True
Nothing -> False
AsList(n', intersection) = filter xs isInYs
Expand All @@ -100,7 +100,7 @@ struct Element(set:(Set a)) given (a|Ord) =
-- type), but maybe it's easier to read if it's explicit.
def member(x:a, set:(Set a)) -> Maybe (Element set) given (a|Ord) =
UnsafeAsSet(_, elts) = set
case search_sorted elts x of
case search_sorted_exact elts x of
Just n -> Just $ Element(ordinal n)
Nothing -> Nothing

Expand Down
2 changes: 1 addition & 1 deletion lib/stats.dx
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ instance OrderedDist(Binomial, Nat, Float)
lpdf = for i:(Fin tp1). ln $ density d (ordinal i)
cdf = cdf_for_categorical lpdf
mi = search_sorted cdf q
ordinal $ from_just mi
ordinal $ from_just $ left_fence mi


'### Exponential distribution
Expand Down
6 changes: 6 additions & 0 deletions tests/sort-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ import sort
:p is_sorted $ sort [9, 3, 7, 4, 6, 1, 9, 1, 9, -1, 10, 10, 100, 0]
> True

:p
xs = [1,2,4]
for i:(Fin 6).
search_sorted_exact(xs, ordinal i)
> [Nothing, (Just 0), (Just 1), Nothing, (Just 2), Nothing]

'### Lexical Sorting Tests

:p "aaa" < "bbb"
Expand Down

0 comments on commit 6d23e46

Please sign in to comment.