diff --git a/examples/parsing.py b/examples/parsing.py
index 3f070c470..bee8d1926 100644
--- a/examples/parsing.py
+++ b/examples/parsing.py
@@ -1,6 +1,7 @@
 """An example illustrating parser-based masking."""
 import math
 import time
+import urllib.request
 
 import torch
 from lark import Lark
@@ -26,23 +27,40 @@
     checkpoint, trust_remote_code=True, revision=revision
 ).to(device)
 
-input_text = "def "
+input_text = "SELECT "
 inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
 
 
+# The grammar is pinned to a specific commit so the download is reproducible.
+sql_grammar_url = "https://github.com/zbrookle/sql_to_ibis/raw/0e9226da42065940ce21439d490f9fcacadc7f92/sql_to_ibis/grammar/sql.lark"
+# Read the whole body at once: iterating the response yields lines that keep
+# their trailing newline, so joining them with "\n" would double every break.
+with urllib.request.urlopen(sql_grammar_url) as response:
+    sql_grammar = response.read().decode("utf-8")
+with open("sql_grammar.lark", "w", encoding="utf-8") as f:
+    f.write(sql_grammar)
+
+sqlparser = Lark.open(
+    "sql_grammar.lark",
+    parser="lalr",
+)
+
+pyparser = Lark.open_from_package(
+    "lark",
+    "python.lark",
+    ["grammars"],
+    parser="lalr",
+    postlex=PartialPythonIndenter(),
+    start="file_input",
+)
+
+
 class ParserLogitsProcessor(LogitsProcessor):
     """Bias invalid token scores according to a running parse state."""
 
-    def __init__(self):
-        pyparser = Lark.open_from_package(
-            "lark",
-            "python.lark",
-            ["grammars"],
-            parser="lalr",
-            postlex=PartialPythonIndenter(),
-            start="file_input",
-        )
-        ip = pyparser.parse_interactive("")
+    def __init__(self, parser):
+        """Start an interactive parse of the empty string with `parser`."""
+        ip = parser.parse_interactive("")
         self.parser_state = ip.parser_state
         self.states_stack = [self.parser_state]
         self.token_seq = None
@@ -97,11 +115,13 @@ def __call__(
 
 
 set_seed(20399)
 
+parser = sqlparser
+
 outputs = model.generate(
     inputs,
     max_length=100,
     temperature=0.1,
-    logits_processor=LogitsProcessorList([ParserLogitsProcessor()]),
+    logits_processor=LogitsProcessorList([ParserLogitsProcessor(parser)]),
     renormalize_logits=True,
 )