Skip to content

Commit

Permalink
gh-104683: Modernise Argument Clinic parameter state machine (#106362)
Browse files Browse the repository at this point in the history
Use enums and pattern matching to make the code more readable.

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
  • Loading branch information
erlend-aasland and AlexWaygood committed Jul 3, 2023
1 parent 7709037 commit 71b4044
Showing 1 changed file with 78 additions and 51 deletions.
129 changes: 78 additions & 51 deletions Tools/clinic/clinic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4294,6 +4294,37 @@ def dedent(self, line):
StateKeeper = Callable[[str | None], None]
ConverterArgs = dict[str, Any]

class ParamState(enum.IntEnum):
"""Parameter parsing state.
[ [ a, b, ] c, ] d, e, f=3, [ g, h, [ i ] ] <- line
01 2 3 4 5 6 <- state transitions
"""
# Before we've seen anything.
# Legal transitions: to LEFT_SQUARE_BEFORE or REQUIRED
START = 0

# Left square backets before required params.
LEFT_SQUARE_BEFORE = 1

# In a group, before required params.
GROUP_BEFORE = 2

# Required params, positional-or-keyword or positional-only (we
# don't know yet). Renumber left groups!
REQUIRED = 3

# Positional-or-keyword or positional-only params that now must have
# default values.
OPTIONAL = 4

# In a group, after required params.
GROUP_AFTER = 5

# Right square brackets after required params.
RIGHT_SQUARE_AFTER = 6


class DSLParser:
function: Function | None
state: StateKeeper
Expand Down Expand Up @@ -4331,7 +4362,7 @@ def reset(self) -> None:
self.keyword_only = False
self.positional_only = False
self.group = 0
self.parameter_state = self.ps_start
self.parameter_state: ParamState = ParamState.START
self.seen_positional_with_default = False
self.indent = IndentStack()
self.kind = CALLABLE
Expand Down Expand Up @@ -4726,22 +4757,8 @@ def state_modulename_name(self, line: str | None) -> None:
#
# These rules are enforced with a single state variable:
# "parameter_state". (Previously the code was a miasma of ifs and
# separate boolean state variables.) The states are:
#
# [ [ a, b, ] c, ] d, e, f=3, [ g, h, [ i ] ] <- line
# 01 2 3 4 5 6 <- state transitions
#
# 0: ps_start. before we've seen anything. legal transitions are to 1 or 3.
# 1: ps_left_square_before. left square brackets before required parameters.
# 2: ps_group_before. in a group, before required parameters.
# 3: ps_required. required parameters, positional-or-keyword or positional-only
# (we don't know yet). (renumber left groups!)
# 4: ps_optional. positional-or-keyword or positional-only parameters that
# now must have default values.
# 5: ps_group_after. in a group, after required parameters.
# 6: ps_right_square_after. right square brackets after required parameters.
ps_start, ps_left_square_before, ps_group_before, ps_required, \
ps_optional, ps_group_after, ps_right_square_after = range(7)
# separate boolean state variables.) The states are defined in the
# ParamState class.

def state_parameters_start(self, line: str | None) -> None:
if not self.valid_line(line):
Expand All @@ -4759,8 +4776,8 @@ def to_required(self):
"""
Transition to the "required" parameter state.
"""
if self.parameter_state != self.ps_required:
self.parameter_state = self.ps_required
if self.parameter_state is not ParamState.REQUIRED:
self.parameter_state = ParamState.REQUIRED
for p in self.function.parameters.values():
p.group = -p.group

Expand Down Expand Up @@ -4793,17 +4810,18 @@ def state_parameter(self, line):
self.parse_special_symbol(line)
return

if self.parameter_state in (self.ps_start, self.ps_required):
self.to_required()
elif self.parameter_state == self.ps_left_square_before:
self.parameter_state = self.ps_group_before
elif self.parameter_state == self.ps_group_before:
if not self.group:
match self.parameter_state:
case ParamState.START | ParamState.REQUIRED:
self.to_required()
elif self.parameter_state in (self.ps_group_after, self.ps_optional):
pass
else:
fail("Function " + self.function.name + " has an unsupported group configuration. (Unexpected state " + str(self.parameter_state) + ".a)")
case ParamState.LEFT_SQUARE_BEFORE:
self.parameter_state = ParamState.GROUP_BEFORE
case ParamState.GROUP_BEFORE:
if not self.group:
self.to_required()
case ParamState.GROUP_AFTER | ParamState.OPTIONAL:
pass
case st:
fail(f"Function {self.function.name} has an unsupported group configuration. (Unexpected state {st}.a)")

# handle "as" for parameters too
c_name = None
Expand Down Expand Up @@ -4863,8 +4881,9 @@ def state_parameter(self, line):
name, legacy, kwargs = self.parse_converter(parameter.annotation)

if not default:
if self.parameter_state == self.ps_optional:
fail("Can't have a parameter without a default (" + repr(parameter_name) + ")\nafter a parameter with a default!")
if self.parameter_state is ParamState.OPTIONAL:
fail(f"Can't have a parameter without a default ({parameter_name!r})\n"
"after a parameter with a default!")
if is_vararg:
value = NULL
kwargs.setdefault('c_default', "NULL")
Expand All @@ -4876,8 +4895,8 @@ def state_parameter(self, line):
if is_vararg:
fail("Vararg can't take a default value!")

if self.parameter_state == self.ps_required:
self.parameter_state = self.ps_optional
if self.parameter_state is ParamState.REQUIRED:
self.parameter_state = ParamState.OPTIONAL
default = default.strip()
bad = False
ast_input = f"x = {default}"
Expand Down Expand Up @@ -5001,22 +5020,22 @@ def bad_node(self, node):

if isinstance(converter, self_converter):
if len(self.function.parameters) == 1:
if (self.parameter_state != self.ps_required):
if self.parameter_state is not ParamState.REQUIRED:
fail("A 'self' parameter cannot be marked optional.")
if value is not unspecified:
fail("A 'self' parameter cannot have a default value.")
if self.group:
fail("A 'self' parameter cannot be in an optional group.")
kind = inspect.Parameter.POSITIONAL_ONLY
self.parameter_state = self.ps_start
self.parameter_state = ParamState.START
self.function.parameters.clear()
else:
fail("A 'self' parameter, if specified, must be the very first thing in the parameter block.")

if isinstance(converter, defining_class_converter):
_lp = len(self.function.parameters)
if _lp == 1:
if (self.parameter_state != self.ps_required):
if self.parameter_state is not ParamState.REQUIRED:
fail("A 'defining_class' parameter cannot be marked optional.")
if value is not unspecified:
fail("A 'defining_class' parameter cannot have a default value.")
Expand Down Expand Up @@ -5065,12 +5084,13 @@ def parse_special_symbol(self, symbol):
fail("Function " + self.function.name + " uses '*' more than once.")
self.keyword_only = True
elif symbol == '[':
if self.parameter_state in (self.ps_start, self.ps_left_square_before):
self.parameter_state = self.ps_left_square_before
elif self.parameter_state in (self.ps_required, self.ps_group_after):
self.parameter_state = self.ps_group_after
else:
fail("Function " + self.function.name + " has an unsupported group configuration. (Unexpected state " + str(self.parameter_state) + ".b)")
match self.parameter_state:
case ParamState.START | ParamState.LEFT_SQUARE_BEFORE:
self.parameter_state = ParamState.LEFT_SQUARE_BEFORE
case ParamState.REQUIRED | ParamState.GROUP_AFTER:
self.parameter_state = ParamState.GROUP_AFTER
case st:
fail(f"Function {self.function.name} has an unsupported group configuration. (Unexpected state {st}.b)")
self.group += 1
self.function.docstring_only = True
elif symbol == ']':
Expand All @@ -5079,20 +5099,27 @@ def parse_special_symbol(self, symbol):
if not any(p.group == self.group for p in self.function.parameters.values()):
fail("Function " + self.function.name + " has an empty group.\nAll groups must contain at least one parameter.")
self.group -= 1
if self.parameter_state in (self.ps_left_square_before, self.ps_group_before):
self.parameter_state = self.ps_group_before
elif self.parameter_state in (self.ps_group_after, self.ps_right_square_after):
self.parameter_state = self.ps_right_square_after
else:
fail("Function " + self.function.name + " has an unsupported group configuration. (Unexpected state " + str(self.parameter_state) + ".c)")
match self.parameter_state:
case ParamState.LEFT_SQUARE_BEFORE | ParamState.GROUP_BEFORE:
self.parameter_state = ParamState.GROUP_BEFORE
case ParamState.GROUP_AFTER | ParamState.RIGHT_SQUARE_AFTER:
self.parameter_state = ParamState.RIGHT_SQUARE_AFTER
case st:
fail(f"Function {self.function.name} has an unsupported group configuration. (Unexpected state {st}.c)")
elif symbol == '/':
if self.positional_only:
fail("Function " + self.function.name + " uses '/' more than once.")
self.positional_only = True
# ps_required and ps_optional are allowed here, that allows positional-only without option groups
# REQUIRED and OPTIONAL are allowed here, that allows positional-only without option groups
# to work (and have default values!)
if (self.parameter_state not in (self.ps_required, self.ps_optional, self.ps_right_square_after, self.ps_group_before)) or self.group:
fail("Function " + self.function.name + " has an unsupported group configuration. (Unexpected state " + str(self.parameter_state) + ".d)")
allowed = {
ParamState.REQUIRED,
ParamState.OPTIONAL,
ParamState.RIGHT_SQUARE_AFTER,
ParamState.GROUP_BEFORE,
}
if (self.parameter_state not in allowed) or self.group:
fail(f"Function {self.function.name} has an unsupported group configuration. (Unexpected state {self.parameter_state}.d)")
if self.keyword_only:
fail("Function " + self.function.name + " mixes keyword-only and positional-only parameters, which is unsupported.")
# fixup preceding parameters
Expand Down

0 comments on commit 71b4044

Please sign in to comment.