Skip to content

Commit

Permalink
Merge pull request #1550 from stanfordnlp/interpreter_patches
Browse files Browse the repository at this point in the history
patches python interpreter to support additional syntax
  • Loading branch information
okhat authored Sep 27, 2024
2 parents b1631ae + fe7f3a0 commit b3ecf87
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 31 deletions.
1 change: 0 additions & 1 deletion dspy/predict/program_of_thought.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ def execute_code(self, code):
interpreter = PythonInterpreter(action_space={"print": print}, import_white_list=self.import_white_list)
try:
output = str(code_prompt.execute(interpreter=interpreter)[0])
print
return code, output, None
except Exception as e:
return code, None, str(e)
Expand Down
237 changes: 207 additions & 30 deletions dspy/primitives/python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(self, action_space: Dict[str, Any],
self.action_space = action_space
self.state = self.action_space.copy()
self.fuzz_state: Dict[str, Any] = {}
self.import_white_list = import_white_list or []
self.import_white_list = import_white_list or ["math", "random", "datetime", "time", "string", "collections", "itertools", "functools", "typing", "enum", "json", "ast"] #default imports

def execute(self, code: str, state: Optional[Dict[str, Any]] = None,
fuzz_state: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -183,6 +183,8 @@ def _execute_ast(self, expression: ast.AST) -> Any:
elif isinstance(expression, ast.BinOp):
# Binary Operator -> return the result value
return self._execute_binop(expression)
elif isinstance(expression, ast.BoolOp):
return self._execute_condition(expression)
elif isinstance(expression, ast.Call):
# Function call -> return the value of the function call
return self._execute_call(expression)
Expand Down Expand Up @@ -212,9 +214,13 @@ def _execute_ast(self, expression: ast.AST) -> Any:
elif isinstance(expression, ast.FunctionDef):
self.state[expression.name] = expression
return None
elif isinstance(expression, ast.GeneratorExp):
return self._execute_generatorexp(expression)
elif isinstance(expression, ast.If):
# If -> execute the right branch
return self._execute_if(expression)
elif isinstance(expression, ast.IfExp):
return self._execute_ifexp(expression)
elif isinstance(expression, ast.Import):
# Import -> add imported names in self.state and return None.
self._execute_import(expression)
Expand All @@ -228,6 +234,8 @@ def _execute_ast(self, expression: ast.AST) -> Any:
elif isinstance(expression, ast.JoinedStr):
return "".join(
[str(self._execute_ast(v)) for v in expression.values])
elif isinstance(expression, ast.Lambda):
return self._execute_lambda(expression)
elif isinstance(expression, ast.List):
# List -> evaluate all elements
return [self._execute_ast(elt) for elt in expression.elts]
Expand All @@ -242,8 +250,27 @@ def _execute_ast(self, expression: ast.AST) -> Any:
elif isinstance(expression, ast.Tuple):
return tuple([self._execute_ast(elt) for elt in expression.elts])
elif isinstance(expression, ast.UnaryOp):
# Binary Operator -> return the result value
return self._execute_unaryop(expression)
elif isinstance(expression, ast.While):
return self._execute_while(expression)
elif isinstance(expression, ast.ListComp):
return self._execute_listcomp(expression)
elif isinstance(expression, ast.DictComp):
return self._execute_dictcomp(expression)
elif isinstance(expression, ast.SetComp):
return self._execute_setcomp(expression)
elif isinstance(expression, ast.Break):
raise BreakException()
elif isinstance(expression, ast.Continue):
raise ContinueException()
elif isinstance(expression, ast.Try):
return self._execute_try(expression)
elif isinstance(expression, ast.Raise):
return self._execute_raise(expression)
elif isinstance(expression, ast.Pass):
return None
elif isinstance(expression, ast.Assert):
return self._execute_assert(expression)
else:
# For now we refuse anything else. Let's add things as we need
# them.
Expand Down Expand Up @@ -353,39 +380,47 @@ def _execute_condition(self, condition):
elif isinstance(condition.op, ast.Or):
results = [self._execute_ast(value) for value in condition.values]
return any(results)
else: #TODO - add any other BoolOps missing
else:
raise InterpreterError(f"Boolean operator {condition.op} is not supported")
elif isinstance(condition, ast.Compare):
if len(condition.ops) > 1:
raise InterpreterError("Cannot evaluate conditions with multiple operators")
if len(condition.ops) > 1:
raise InterpreterError(
"Cannot evaluate conditions with multiple operators")
left = self._execute_ast(condition.left)
comparator = condition.ops[0]
right = self._execute_ast(condition.comparators[0])
if isinstance(comparator, ast.Eq):
return left == right
elif isinstance(comparator, ast.NotEq):
return left != right
elif isinstance(comparator, ast.Lt):
return left < right
elif isinstance(comparator, ast.LtE):
return left <= right
elif isinstance(comparator, ast.Gt):
return left > right
elif isinstance(comparator, ast.GtE):
return left >= right
elif isinstance(comparator, ast.Is):
return left is right
elif isinstance(comparator, ast.IsNot):
return left is not right
elif isinstance(comparator, ast.In):
return left in right
elif isinstance(comparator, ast.NotIn):
return left not in right
left = self._execute_ast(condition.left)
comparator = condition.ops[0]
right = self._execute_ast(condition.comparators[0])
if isinstance(comparator, ast.Eq):
return left == right
elif isinstance(comparator, ast.NotEq):
return left != right
elif isinstance(comparator, ast.Lt):
return left < right
elif isinstance(comparator, ast.LtE):
return left <= right
elif isinstance(comparator, ast.Gt):
return left > right
elif isinstance(comparator, ast.GtE):
return left >= right
elif isinstance(comparator, ast.Is):
return left is right
elif isinstance(comparator, ast.IsNot):
return left is not right
elif isinstance(comparator, ast.In):
return left in right
elif isinstance(comparator, ast.NotIn):
return left not in right
else:
raise InterpreterError("Unsupported comparison operator")
elif isinstance(condition, ast.UnaryOp):
return self._execute_unaryop(condition)
elif isinstance(condition, ast.Name):
return bool(self._execute_ast(condition))
elif isinstance(condition, ast.Call):
return bool(self._execute_ast(condition))
elif isinstance(condition, ast.Constant):
return bool(condition.value)
else:
raise InterpreterError("Unsupported condition type")
raise InterpreterError(f"Unsupported condition type: {type(condition).__name__}")


def _execute_if(self, if_statement: ast.If):
result = None
Expand All @@ -400,6 +435,13 @@ def _execute_if(self, if_statement: ast.If):
if line_result is not None:
result = line_result
return result

def _execute_ifexp(self, ifexp: ast.IfExp) -> Any:
test_result = self._execute_condition(ifexp.test)
if test_result:
return self._execute_ast(ifexp.body)
else:
return self._execute_ast(ifexp.orelse)

def _execute_for(self, for_statement: ast.For):
result = None
Expand Down Expand Up @@ -427,6 +469,16 @@ def _execute_import_from(self, import_from: ast.ImportFrom):
imported_module = importlib.import_module(import_from.module)
alias = import_name.asname or import_name.name
self.state[alias] = getattr(imported_module, import_name.name)

def _execute_lambda(self, lambda_node: ast.Lambda) -> Any:
def lambda_function(*args):
old_state = self.state.copy()
for param, arg in zip(lambda_node.args.args, args):
self.state[param.arg] = arg
result = self._execute_ast(lambda_node.body)
self.state = old_state # Restore the state
return result
return lambda_function

def _validate_import(self, full_name: str):
tmp_name = ""
Expand Down Expand Up @@ -465,6 +517,12 @@ def _execute_binop(self, binop: ast.BinOp):
return left << right
elif isinstance(operator, ast.RShift):
return left >> right
elif isinstance(operator, ast.BitAnd):
return left & right
elif isinstance(operator, ast.BitOr):
return left | right
elif isinstance(operator, ast.BitXor):
return left ^ right
elif isinstance(operator, ast.MatMult):
return left @ right
else:
Expand All @@ -480,8 +538,127 @@ def _execute_unaryop(self, unaryop: ast.UnaryOp):
return -operand
elif isinstance(operator, ast.Not):
return not operand
elif isinstance(operator, ast.Invert):
return ~operand
else:
raise InterpreterError(f"Operator not supported: {operator}")

def _execute_listcomp(self, comp: ast.ListComp):
return [self._execute_comp(comp.elt, comp.generators)]

def _execute_dictcomp(self, comp: ast.DictComp):
return {self._execute_comp(comp.key, comp.generators): self._execute_comp(comp.value, comp.generators)}

def _execute_setcomp(self, comp: ast.SetComp):
return {self._execute_comp(comp.elt, comp.generators)}

def _execute_comp(self, elt, generators):
if not generators:
return self._execute_ast(elt)
gen = generators[0]
result = []
for value in self._execute_ast(gen.iter):
self._assign(gen.target, value)
if all(self._execute_condition(if_cond) for if_cond in gen.ifs):
result.extend(self._execute_comp(elt, generators[1:]))
return result

def _execute_generatorexp(self, genexp: ast.GeneratorExp):
def generator():
for value in self._execute_comp(genexp.elt, genexp.generators):
yield value
return generator()

def _execute_while(self, while_statement: ast.While):
result = None
while self._execute_condition(while_statement.test):
for line in while_statement.body:
line_result = self._execute_ast(line)
if line_result is not None:
result = line_result
if isinstance(line, (ast.Break, ast.Continue)):
break
else:
continue
break
return result

def _execute_for(self, for_statement: ast.For):
class BreakException(Exception):
pass

class ContinueException(Exception):
pass
result = None
try:
for value in self._execute_ast(for_statement.iter):
self._assign(for_statement.target, value)
try:
for line in for_statement.body:
line_result = self._execute_ast(line)
if line_result is not None:
result = line_result
except ContinueException:
continue
except BreakException:
pass
return result

def _execute_while(self, while_statement: ast.While):
class BreakException(Exception):
pass

class ContinueException(Exception):
pass
result = None
try:
while self._execute_condition(while_statement.test):
try:
for line in while_statement.body:
line_result = self._execute_ast(line)
if line_result is not None:
result = line_result
except ContinueException:
continue
except BreakException:
pass
return result

def _execute_try(self, try_statement: ast.Try):
try:
for line in try_statement.body:
self._execute_ast(line)
except Exception as e:
handled = False
for handler in try_statement.handlers:
if handler.type is None or isinstance(e, self._execute_ast(handler.type)):
if handler.name:
self.state[handler.name.id] = e
for line in handler.body:
self._execute_ast(line)
handled = True
break
if not handled:
raise
finally:
for line in try_statement.finalbody:
self._execute_ast(line)

def _execute_raise(self, raise_statement: ast.Raise):
if raise_statement.exc:
exception = self._execute_ast(raise_statement.exc)
raise exception
else:
raise

def _execute_assert(self, assert_statement: ast.Assert):
test_result = self._execute_condition(assert_statement.test)
if not test_result:
if assert_statement.msg:
msg = self._execute_ast(assert_statement.msg)
raise AssertionError(msg)
else:
raise AssertionError

def _get_value_from_state(self, key: str) -> Any:
if key in self.state:
Expand Down

0 comments on commit b3ecf87

Please sign in to comment.