Skip to content

Commit

Permalink
add remove clause + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
FissoreD committed Sep 17, 2024
1 parent c19dd60 commit ea1b3f6
Show file tree
Hide file tree
Showing 17 changed files with 177 additions and 50 deletions.
1 change: 1 addition & 0 deletions src/API.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1381,6 +1381,7 @@ module Utils = struct
| Some (`After,x) -> [After x]
| Some (`Before,x) -> [Before x]
| Some (`Replace,x) -> [Replace x]
| Some (`Remove,x) -> [Remove x]
| None -> []) in
[Program.Clause {
Clause.loc = loc;
Expand Down
2 changes: 1 addition & 1 deletion src/API.mli
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,7 @@ module Utils : sig

(** Hackish, in particular the output should be a compiled program *)
val clause_of_term :
?name:string -> ?graft:([`After | `Before | `Replace] * string) ->
?name:string -> ?graft:([`After | `Before | `Replace | `Remove] * string) ->
depth:int -> Ast.Loc.t -> Data.term -> Ast.program

(** Lifting/restriction/beta (LOW LEVEL, don't use) *)
Expand Down
19 changes: 19 additions & 0 deletions src/bl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ module Array = struct
in
shift t (i+len-1)

let shift_left t i len =
let rec shift t j =
if j = len then t
else shift (set t j (get t (j+1))) (j + 1)
in
shift t i

let rec length t = match !t with Diff(_,_,x) -> length x | Array a -> Array.length a

let of_list l = ref @@ Array (Array.of_list l)
Expand Down Expand Up @@ -147,6 +154,18 @@ let rec replace f x = function
a (* bleah *)
in
aux 0
let rec remove f = function
| BCons (head,tail) when f head -> tail
| BCons (head, tail) -> BCons (head, remove f tail)
| BArray { len; data } as a ->
let rec aux i =
if i < len then
if f data.(i) then BArray { len = len-1; data = Array.shift_left data i len }
else aux (i+1)
else
a (* bleah *)
in
aux 0

let rec insert f x = function
| BCons (head, tail) when f head <= 0 -> BCons (head, BCons(x,tail))
Expand Down
3 changes: 2 additions & 1 deletion src/bl.mli
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ val rcons : 'a -> 'a t -> 'a t
(* O(n) space and time *)
val copy : 'a t -> 'a t

(* These 2 are O(n) time, O(1) space. The test must succeed once *)
(* These 3 are O(n) time, O(1) space. The test must succeed once *)
val replace : ('a -> bool) -> 'a -> 'a t -> 'a t
val remove : ('a -> bool) -> 'a t -> 'a t
val insert : ('a -> int) -> 'a -> 'a t -> 'a t

type 'a scan
Expand Down
30 changes: 17 additions & 13 deletions src/compiler.ml
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,9 @@ end = struct (* {{{ *)
| Replace s :: rest ->
if r.insertion <> None then duplicate_err "insertion";
aux_attrs { r with insertion = Some (Replace s) } rest
| Remove s :: rest ->
if r.insertion <> None then duplicate_err "insertion";
aux_attrs { r with insertion = Some (Remove s) } rest
| If s :: rest ->
if r.ifexpr <> None then duplicate_err "if";
aux_attrs { r with ifexpr = Some s } rest
Expand All @@ -693,7 +696,7 @@ end = struct (* {{{ *)
| If s :: rest ->
if r.cifexpr <> None then duplicate_err "if";
aux_chr { r with cifexpr = Some s } rest
| (Before _ | After _ | Replace _ | External | Index _) as a :: _ -> illegal_err a
| (Before _ | After _ | Replace _ | Remove _ | External | Index _) as a :: _ -> illegal_err a
in
let cid = Loc.show loc in
{ c with Chr.attributes = aux_chr { cid; cifexpr = None } attributes }
Expand Down Expand Up @@ -724,7 +727,7 @@ end = struct (* {{{ *)
| Some (Structured.Index _) -> duplicate_err "index"
| Some _ -> error ~loc "external predicates cannot be indexed"
end
| (Before _ | After _ | Replace _ | Name _ | If _) as a :: _ -> illegal_err a
| (Before _ | After _ | Replace _ | Remove _ | Name _ | If _) as a :: _ -> illegal_err a
in
let attributes = aux_tatt None attributes in
let attributes =
Expand Down Expand Up @@ -2080,7 +2083,7 @@ module Assemble : sig
end = struct (* {{{ *)

let compile_clause_attributes ({ Ast.Clause.attributes = { Ast.Structured.id }} as c) =
{ c with Ast.Clause.attributes = { Assembled.id }}
{ c with attributes = { Assembled.id }}

let sort_insertion ~old_rev ~extra:l =
let add s { Ast.Clause.attributes = { Assembled.id }; loc } =
Expand All @@ -2096,18 +2099,19 @@ let compile_clause_attributes ({ Ast.Clause.attributes = { Ast.Structured.id }}
match l, loc_name with
| [],_ -> error ~loc:c.Ast.Clause.loc ("unable to graft this clause: no clause named " ^
match loc_name with
| Ast.Structured.Replace x -> x
| Ast.Structured.After x -> x
| Ast.Structured.Before x -> x)
| { Ast.Clause.attributes = { Assembled.id = Some n }} :: xs,
Ast.Structured.Replace name when n = name ->
c :: xs
| { Ast.Clause.attributes = { Assembled.id = Some n }} as x :: xs,
| Replace x | After x | Before x | Remove x -> x)
| { Ast.Clause.attributes = { Assembled.id = Some n }} as x :: xs, (* AFTER *)
Ast.Structured.After name when n = name ->
c :: x :: xs
| { Ast.Clause.attributes = { Assembled.id = Some n }} as x :: xs,
Ast.Structured.Before name when n = name ->
| { attributes = { Assembled.id = Some n }} as x :: xs, (* BEFORE *)
Before name when n = name ->
x :: c :: xs
| { attributes = { id = Some n }} :: xs, (* REPLACE *)
Replace name when n = name ->
c :: xs
| { attributes = { id = Some n }} :: xs, (* REMOVE *)
Remove name when n = name ->
c :: xs
| x :: xs, _ -> x :: insert loc_name c xs in
let rec aux_sort seen acc = function
| [] -> acc
Expand Down Expand Up @@ -2264,7 +2268,7 @@ let rec constants_of acc = function
let w_symbol_table s f x =
let table = Symbols.compile_table @@ State.get Symbols.table s in
let pp_ctx = { table; uv_names = ref (IntMap.empty,0) } in
Util.set_spaghetti_printer Util.pp_const (R.Pp.pp_constant ~pp_ctx);
Util.set_spaghetti_printer pp_const (R.Pp.pp_constant ~pp_ctx);
f x

(* Compiler passes *)
Expand Down
10 changes: 10 additions & 0 deletions src/discrimination_tree.ml
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,15 @@ module Trie = struct
map = map |> Ptmap.map (replace p x);
}

let rec remove f = function
| Node { data; other; listTailVariable; map } ->
Node {
data = data |> List.filter (fun x -> not (f x));
other = other |> Option.map (remove f);
listTailVariable = listTailVariable |> Option.map (remove f);
map = map |> Ptmap.map (remove f);
}

let add (a : Path.t) v t =
let max = ref 0 in
let rec ins ~pos = let x = Path.get a pos in function
Expand Down Expand Up @@ -392,6 +401,7 @@ let retrieve cmp_data path { t } =
Bl.of_list @@ List.sort cmp_data r

let replace p x i = { i with t = Trie.replace p x i.t }
let remove keep dt = { dt with t = Trie.remove keep dt.t}

module Internal = struct
let kConstant = kConstant
Expand Down
1 change: 1 addition & 0 deletions src/discrimination_tree.mli
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ val empty_dt : 'b list -> 'a t
val retrieve : ('a -> 'a -> int) -> Path.t -> 'a t -> 'a Bl.scan

val replace : ('a -> bool) -> 'a -> 'a t -> 'a t
val remove : ('a -> bool) -> 'a t -> 'a t

(***********************************************************)
(* Printers *)
Expand Down
3 changes: 2 additions & 1 deletion src/parser/ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ type raw_attribute =
| After of string
| Before of string
| Replace of string
| Remove of string
| External
| Index of int list * string option
[@@deriving show]
Expand Down Expand Up @@ -318,7 +319,7 @@ and attribute = {
id : string option;
ifexpr : string option;
}
and insertion = Before of string | After of string | Replace of string
and insertion = Before of string | After of string | Replace of string | Remove of string
and tattribute =
| External
| Index of int list * tindex option
Expand Down
3 changes: 2 additions & 1 deletion src/parser/ast.mli
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ type raw_attribute =
| After of string
| Before of string
| Replace of string
| Remove of string
| External
| Index of int list * string option
[@@ deriving show]
Expand Down Expand Up @@ -213,7 +214,7 @@ and attribute = {
id : string option;
ifexpr : string option;
}
and insertion = Before of string | After of string | Replace of string
and insertion = Before of string | After of string | Replace of string | Remove of string
and cattribute = {
cid : string;
cifexpr : string option
Expand Down
2 changes: 2 additions & 0 deletions src/parser/grammar.mly
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ attribute:
| AFTER; s = STRING { After s }
| BEFORE; s = STRING { Before s }
| REPLACE; s = STRING { Replace s }
| REMOVE; s = STRING { Remove s }
| EXTERNAL { External }
| INDEX; LPAREN; l = nonempty_list(indexing) ; RPAREN; o = option(STRING) { Index (l,o) }

Expand Down Expand Up @@ -400,6 +401,7 @@ constant:
| BEFORE { Func.from_string "before" }
| AFTER { Func.from_string "after" }
| REPLACE { Func.from_string "replace" }
| REMOVE { Func.from_string "remove" }
| INDEX { Func.from_string "index" }
| c = IO { Func.from_string @@ String.make 1 c }
| CUT { Func.cutf }
Expand Down
1 change: 1 addition & 0 deletions src/parser/lexer.mll.in
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ and token = parse
| "after" { AFTER }
| "before" { BEFORE }
| "replace" { REPLACE }
| "remove" { REMOVE }
| "name" { NAME }
| "if" { IF }
| "index" { INDEX }
Expand Down
1 change: 1 addition & 0 deletions src/parser/test_lexer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type t = Tokens.token =
| RULE
| RPAREN
| REPLACE
| REMOVE
| RCURLY
| RBRACKET
| QUOTED of ( string )
Expand Down
1 change: 1 addition & 0 deletions src/parser/tokens.mly
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
%token BEFORE
%token AFTER
%token REPLACE
%token REMOVE
%token NAME
%token INDEX
%token CONS
Expand Down
35 changes: 27 additions & 8 deletions src/runtime.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2625,17 +2625,20 @@ let timestamp_clause clause times time graft =
match graft with
| None -> clause.timestamp <- [time]; []
| Some (Elpi_parser.Ast.Structured.Before x) -> let reference = reference x in clause.timestamp <- reference @ [-time]; reference
| Some (Elpi_parser.Ast.Structured.After x) -> let reference = reference x in clause.timestamp <- reference @ [time]; reference
| Some (Elpi_parser.Ast.Structured.Replace x) -> clause.timestamp <- [time]; let reference = reference x in reference
| Some (After x) -> let reference = reference x in clause.timestamp <- reference @ [time]; reference
| Some (Replace x) -> clause.timestamp <- [time]; let reference = reference x in reference
| Some (Remove x) -> let reference = reference x in reference

let postpend graft reference clause clauses =
match graft with
| None -> Bl.rcons clause clauses
| Some (Elpi_parser.Ast.Structured.Before _)
| Some (Elpi_parser.Ast.Structured.After _) ->
| Some (After _) ->
Bl.insert (fun x -> lex_insertion x.timestamp clause.timestamp) clause clauses
| Some (Elpi_parser.Ast.Structured.Replace _) ->
| Some (Replace _) ->
Bl.replace (fun x -> x.timestamp = reference) clause clauses
| Some (Remove _) -> (* TODO: in this case the clause is ignored... *)
Bl.remove (fun x -> x.timestamp = reference) clauses

let add1clause2 ~depth ~insert ~empty ~copy m graft reference predicate clause = function
| TwoLevelIndex { all_clauses; argno; mode; flex_arg_clauses; arg_idx; } ->
Expand Down Expand Up @@ -2697,7 +2700,7 @@ let add1clause ~depth { idx; time; times } ~time_dir ~insert ~empty ~cons ~copy
let grafting_reference = timestamp_clause clause times time graft in
let times = (* TODO: do this only at compile time *)
match graft with
| Some (Elpi_parser.Ast.Structured.Replace oid) -> StrMap.remove oid times
| Some (Replace oid) -> StrMap.remove oid times
| _ -> times
in
let times =
Expand All @@ -2712,7 +2715,7 @@ let add1clause ~depth { idx; time; times } ~time_dir ~insert ~empty ~cons ~copy
try
(* TODO: do this only at compile time *)
match graft with
| Some (Elpi_parser.Ast.Structured.Replace _) ->
| Some (Replace _) ->
Ptmap.map (function
| TwoLevelIndex {
argno; mode;
Expand All @@ -2725,8 +2728,24 @@ let add1clause ~depth { idx; time; times } ~time_dir ~insert ~empty ~cons ~copy
flex_arg_clauses = insert graft grafting_reference clause flex_arg_clauses;
arg_idx = Ptmap.map (fun l -> insert graft grafting_reference clause l) arg_idx;
}
| BitHash { mode; args; args_idx } -> BitHash { mode; args; args_idx = Ptmap.map (fun l -> insert graft grafting_reference clause l) args_idx }
| IndexWithDiscriminationTree {mode; arg_depths; args_idx; } -> IndexWithDiscriminationTree {mode; arg_depths; args_idx = Discrimination_tree.replace (fun x -> x.timestamp = grafting_reference) clause args_idx; }
| BitHash hash -> BitHash { hash with args_idx = Ptmap.map (fun l -> insert graft grafting_reference clause l) hash.args_idx }
| IndexWithDiscriminationTree dt -> IndexWithDiscriminationTree {dt with args_idx = Discrimination_tree.replace (fun x -> x.timestamp = grafting_reference) clause dt.args_idx; }
) idx
| Some (Remove _) ->
Ptmap.map (function
| TwoLevelIndex {
argno; mode;
all_clauses;
flex_arg_clauses;
arg_idx;
} -> TwoLevelIndex {
argno; mode;
all_clauses = insert graft grafting_reference clause all_clauses;
flex_arg_clauses = insert graft grafting_reference clause flex_arg_clauses;
arg_idx = Ptmap.map (fun l -> insert graft grafting_reference clause l) arg_idx;
}
| BitHash hash -> BitHash { hash with args_idx = Ptmap.map (fun l -> insert graft grafting_reference clause l) hash.args_idx }
| IndexWithDiscriminationTree dt -> IndexWithDiscriminationTree {dt with args_idx = Discrimination_tree.remove (fun x -> x.timestamp = grafting_reference) dt.args_idx; }
) idx
| _ -> add1clause2 ~depth idx ~insert ~empty ~copy graft grafting_reference predicate clause (Ptmap.find predicate idx);
with
Expand Down
45 changes: 43 additions & 2 deletions src/test_discrimination_tree.ml
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
open Elpi.Internal.Discrimination_tree
module DT = Elpi.Internal.Discrimination_tree
open Internal

let test ~expected found =
if expected <> found then failwith (Format.asprintf "Test DT error: Expected %d, but found %d" expected found)


let () = assert (k_of (mkConstant ~safe:false ~data:~-17 ~arity:0) == kConstant)
let () = assert (k_of mkVariable == kVariable)
let () = assert (k_of mkLam == kOther)
Expand Down Expand Up @@ -54,7 +59,7 @@ let () =
Format.printf " Retrived clause number is %d\n%!" retrived_nb;
(* let pp_sep = fun f _ -> Format.pp_print_string f " " in *)
(* Format.printf " Found instances are %a\n%!" (Format.pp_print_list ~pp_sep Format.pp_print_int) retrived; *)
if retrived_nb <> nb then failwith (Format.asprintf "Test DT error: Expected %d clauses, %d found" nb retrived_nb);
test retrived_nb nb;
if (Elpi.Internal.Bl.to_list retrived |> List.sort Int.compare |> List.rev) <> (retrived |> Elpi.Internal.Bl.to_list) then failwith "Test DT error: resultin list is not correctly ordered"
in

Expand All @@ -70,4 +75,40 @@ let () =
test [p2; p3; p4; p5; p6] p1 mkOutputMode 3;
test [p2; p3; p4; p5; p6] p1 mkInputMode 1;
test [p1; p2; p3; p4; p5; p6; p8] p7 mkOutputMode 3;
test [p1; p2; p3; p4; p5; p6; p8] p7 mkInputMode 2;
test [p1; p2; p3; p4; p5; p6; p8] p7 mkInputMode 2

let () =
let get_length dt path = DT.retrieve compare path !dt |> Elpi.Internal.Bl.length in
let remove dt e = dt := DT.remove (fun x -> x = e) !dt in
let index dt path v = dt := DT.index !dt path v in

let constA = mkConstant ~safe:false ~data:~-1 ~arity:~-0 in (* a *)
let p1 = [mkListHead; constA; mkListTailVariable; constA] in
let p2 = [mkListHead; constA; mkListTailVariable; constA; constA] in

let p1Index = Path.of_list p1 in (* path for indexing *)
let p1Retr = mkInputMode :: p1 |> Path.of_list in (* path for retrival *)

let p2Index = Path.of_list p2 in (* path for indexing *)
let p2Retr = mkInputMode :: p2 |> Path.of_list in (* path for retrival *)

let dt = DT.empty_dt (List.init 0 Fun.id) |> ref in
index dt p1Index 100; index dt p1Index 200;
index dt p2Index 200; index dt p2Index 200;

Printf.printf "Test remove 1\n";
test ~expected:2 (get_length dt p1Retr);
test ~expected:2 (get_length dt p2Retr);

Printf.printf "Test remove 2\n";
remove dt 100;
test ~expected:1 (get_length dt p1Retr);
test ~expected:2 (get_length dt p2Retr);

Printf.printf "Test remove 3\n";
remove dt 100;
test ~expected:1 (get_length dt p1Retr);

Printf.printf "Test remove 4\n";
remove dt 200;
test ~expected:0 (get_length dt p1Retr)
Loading

0 comments on commit ea1b3f6

Please sign in to comment.