Skip to content

Commit

Permalink
Extend the parsing example to include SQL-guided generation
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jul 6, 2023
1 parent 6eac8ec commit 86b982d
Showing 1 changed file with 29 additions and 12 deletions.
41 changes: 29 additions & 12 deletions examples/parsing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""An example illustrating parser-based masking."""
import math
import time
import urllib.request

import torch
from lark import Lark
Expand All @@ -26,23 +27,37 @@
checkpoint, trust_remote_code=True, revision=revision
).to(device)

input_text = "def "
input_text = "SELECT "
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)


# Fetch the SQL grammar from a URL pinned to a specific commit hash, so the
# downloaded grammar does not change between runs of this example.
sql_grammar_url = "https://github.com/zbrookle/sql_to_ibis/raw/0e9226da42065940ce21439d490f9fcacadc7f92/sql_to_ibis/grammar/sql.lark"

# Use the response as a context manager so the HTTP connection is closed
# deterministically, and decode the payload in one step.  Iterating the
# response yields lines that still end in "\n", so joining them with "\n"
# (as was done previously) doubled every line break in the saved grammar.
with urllib.request.urlopen(sql_grammar_url) as response:
    sql_grammar = response.read().decode("utf-8")

# `Lark.open` takes a file path, so the grammar is written to a local file
# first.  Write it as UTF-8 explicitly rather than relying on the platform's
# default encoding.
with open("sql_grammar.lark", "w", encoding="utf-8") as f:
    f.write(sql_grammar)

# LALR parser for the downloaded SQL grammar.
sqlparser = Lark.open(
    "sql_grammar.lark",
    parser="lalr",
)

# LALR parser for Python source, built from the "python.lark" grammar found
# under the "grammars" directory of the installed `lark` package.  Parsing
# starts at the `file_input` rule, and `PartialPythonIndenter` is installed as
# the post-lexer (it is defined outside this view; presumably it handles
# indentation for incomplete input -- confirm against its definition).
pyparser = Lark.open_from_package(
"lark",
"python.lark",
["grammars"],
parser="lalr",
postlex=PartialPythonIndenter(),
start="file_input",
)


class ParserLogitsProcessor(LogitsProcessor):
"""Bias invalid token scores according to a running parse state."""

def __init__(self):
pyparser = Lark.open_from_package(
"lark",
"python.lark",
["grammars"],
parser="lalr",
postlex=PartialPythonIndenter(),
start="file_input",
)
ip = pyparser.parse_interactive("")
def __init__(self, parser):
ip = parser.parse_interactive("")
self.parser_state = ip.parser_state
self.states_stack = [self.parser_state]
self.token_seq = None
Expand Down Expand Up @@ -97,11 +112,13 @@ def __call__(

set_seed(20399)

parser = sqlparser

outputs = model.generate(
inputs,
max_length=100,
temperature=0.1,
logits_processor=LogitsProcessorList([ParserLogitsProcessor()]),
logits_processor=LogitsProcessorList([ParserLogitsProcessor(parser)]),
renormalize_logits=True,
)

Expand Down

0 comments on commit 86b982d

Please sign in to comment.