Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-106368: Argument clinic: add tests for more failure paths #107731

Merged
merged 3 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 94 additions & 14 deletions Lib/test/test_clinic.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _expect_failure(tc, parser, code, errmsg, *, filename=None, lineno=None):
tc.assertEqual(cm.exception.filename, filename)
if lineno is not None:
tc.assertEqual(cm.exception.lineno, lineno)
return cm.exception


class ClinicWholeFileTest(TestCase):
Expand Down Expand Up @@ -222,6 +223,15 @@ def test_directive_output_print(self):
last_line.startswith("/*[clinic end generated code: output=")
)

def test_directive_wrong_arg_number(self):
raw = dedent("""
/*[clinic input]
preserve foo bar baz eggs spam ham mushrooms
[clinic start generated code]*/
""")
err = "takes 1 positional argument but 8 were given"
self.expect_failure(raw, err)

def test_unknown_destination_command(self):
raw = """
/*[clinic input]
Expand Down Expand Up @@ -600,6 +610,31 @@ def test_directive_output_invalid_command(self):
self.expect_failure(block, err, lineno=2)


class ParseFileUnitTest(TestCase):
def expect_parsing_failure(
self, *, filename, expected_error, verify=True, output=None
):
errmsg = re.escape(dedent(expected_error).strip())
with self.assertRaisesRegex(clinic.ClinicError, errmsg):
clinic.parse_file(filename)

def test_parse_file_no_extension(self) -> None:
self.expect_parsing_failure(
filename="foo",
expected_error="Can't extract file type for file 'foo'"
)

def test_parse_file_strange_extension(self) -> None:
filenames_to_errors = {
"foo.rs": "Can't identify file type for file 'foo.rs'",
"foo.hs": "Can't identify file type for file 'foo.hs'",
"foo.js": "Can't identify file type for file 'foo.js'",
}
for filename, errmsg in filenames_to_errors.items():
with self.subTest(filename=filename):
self.expect_parsing_failure(filename=filename, expected_error=errmsg)


class ClinicGroupPermuterTest(TestCase):
def _test(self, l, m, r, output):
computed = clinic.permute_optional_groups(l, m, r)
Expand Down Expand Up @@ -794,8 +829,8 @@ def parse_function(self, text, signatures_in_block=2, function_index=1):
return s[function_index]

def expect_failure(self, block, err, *, filename=None, lineno=None):
_expect_failure(self, self.parse_function, block, err,
filename=filename, lineno=lineno)
return _expect_failure(self, self.parse_function, block, err,
filename=filename, lineno=lineno)

def checkDocstring(self, fn, expected):
self.assertTrue(hasattr(fn, "docstring"))
Expand Down Expand Up @@ -877,6 +912,41 @@ def test_param_default_expr_named_constant(self):
"""
self.expect_failure(block, err, lineno=2)

def test_param_with_bizarre_default_fails_correctly(self):
template = """
module os
os.access
follow_symlinks: int = {default}
"""
err = "Unsupported expression as default value"
for bad_default_value in (
"{1, 2, 3}",
"3 if bool() else 4",
"[x for x in range(42)]"
):
with self.subTest(bad_default=bad_default_value):
block = template.format(default=bad_default_value)
self.expect_failure(block, err, lineno=2)

def test_unspecified_not_allowed_as_default_value(self):
block = """
module os
os.access
follow_symlinks: int(c_default='MAXSIZE') = unspecified
"""
err = "'unspecified' is not a legal default value!"
exc = self.expect_failure(block, err, lineno=2)
self.assertNotIn('Malformed expression given as default value', str(exc))

def test_malformed_expression_as_default_value(self):
block = """
module os
os.access
follow_symlinks: int(c_default='MAXSIZE') = 1/0
"""
err = "Malformed expression given as default value"
self.expect_failure(block, err, lineno=2)

def test_param_default_expr_binop(self):
err = (
"When you specify an expression ('a + b') as your default value, "
Expand Down Expand Up @@ -1041,6 +1111,28 @@ def test_c_name(self):
""")
self.assertEqual("os_stat_fn", function.c_basename)

def test_base_invalid_syntax(self):
block = """
module os
os.stat
invalid syntax: int = 42
"""
err = dedent(r"""
Function 'stat' has an invalid parameter declaration:
\s+'invalid syntax: int = 42'
""").strip()
with self.assertRaisesRegex(clinic.ClinicError, err):
self.parse_function(block)

def test_param_default_invalid_syntax(self):
block = """
module os
os.stat
x: int = invalid syntax
"""
err = r"Syntax error: 'x = invalid syntax\n'"
self.expect_failure(block, err, lineno=2)

def test_cloning_nonexistent_function_correctly_fails(self):
block = """
cloned = fooooooooooooooooo
Expand Down Expand Up @@ -1414,18 +1506,6 @@ def test_parameters_required_after_star(self):
with self.subTest(block=block):
self.expect_failure(block, err)

def test_parameters_required_after_depr_star(self):
dataset = (
"module foo\nfoo.bar\n * [from 3.14]",
"module foo\nfoo.bar\n * [from 3.14]\nDocstring here.",
"module foo\nfoo.bar\n this: int\n * [from 3.14]",
"module foo\nfoo.bar\n this: int\n * [from 3.14]\nDocstring.",
)
err = "Function 'foo.bar' specifies '* [from 3.14]' without any parameters afterwards."
for block in dataset:
with self.subTest(block=block):
self.expect_failure(block, err)

def test_depr_star_invalid_format_1(self):
block = """
module foo
Expand Down
5 changes: 3 additions & 2 deletions Tools/clinic/clinic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5207,13 +5207,14 @@ def bad_node(self, node: ast.AST) -> None:
# but at least make an attempt at ensuring it's a valid expression.
try:
value = eval(default)
if value is unspecified:
fail("'unspecified' is not a legal default value!")
except NameError:
pass # probably a named constant
except Exception as e:
fail("Malformed expression given as default value "
f"{default!r} caused {e!r}")
else:
if value is unspecified:
fail("'unspecified' is not a legal default value!")
erlend-aasland marked this conversation as resolved.
Show resolved Hide resolved
if bad:
fail(f"Unsupported expression as default value: {default!r}")

Expand Down
Loading