diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a7f19e5..bc327a1 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,6 +4,8 @@ Changelog * Add rule C418 to check for calls passing a dict literal or dict comprehension to ``dict()``. +* Add rule C419 to check for calls passing a list comprehension to ``any()``/``all()``. + 3.11.1 (2023-03-21) ------------------- diff --git a/README.rst b/README.rst index 0936f1d..51d1970 100644 --- a/README.rst +++ b/README.rst @@ -216,3 +216,14 @@ For example: * Rewrite ``dict({})`` as ``{}`` * Rewrite ``dict({"a": 1})`` as ``{"a": 1}`` + +C419 Unnecessary list comprehension in ````\() prevents short-circuiting - rewrite as a generator. +----------------------------------------------------------------------------------------------------------- + +Using a list comprehension inside a call to ``any()``/``all()`` prevents short-circuiting when a ``True`` / ``False`` value is found. +The whole list will be constructed before calling ``any()``/``all()``, potentially wasting work.part-way. +Rewrite to use a generator expression, which can stop part way. +For example: + +* Rewrite ``all([condition(x) for x in iterable])`` as ``all(condition(x) for x in iterable)`` +* Rewrite ``any([condition(x) for x in iterable])`` as ``any(condition(x) for x in iterable)`` diff --git a/src/flake8_comprehensions/__init__.py b/src/flake8_comprehensions/__init__.py index 5d1ede7..b18a8ba 100644 --- a/src/flake8_comprehensions/__init__.py +++ b/src/flake8_comprehensions/__init__.py @@ -47,6 +47,10 @@ def __init__(self, tree: ast.AST) -> None: "C418 Unnecessary {type} passed to dict() - " + "remove the outer call to dict()." ), + "C419": ( + "C419 Unnecessary list comprehension passed to {func}() prevents " + + "short-circuiting - rewrite as a generator." + ), } def run(self) -> Generator[tuple[int, int, str, type[Any]], None, None]: @@ -93,13 +97,19 @@ def run(self) -> Generator[tuple[int, int, str, type[Any]], None, None]: elif ( num_positional_args == 1 and isinstance(node.args[0], ast.ListComp) - and node.func.id in ("list", "set") + and node.func.id in ("list", "set", "any", "all") ): - msg_key = {"list": "C411", "set": "C403"}[node.func.id] + msg_key = { + "list": "C411", + "set": "C403", + "any": "C419", + "all": "C419", + }[node.func.id] + msg = self.messages[msg_key].format(func=node.func.id) yield ( node.lineno, node.col_offset, - self.messages[msg_key], + msg, type(self), ) diff --git a/tests/test_flake8_comprehensions.py b/tests/test_flake8_comprehensions.py index c41d280..5537cb8 100644 --- a/tests/test_flake8_comprehensions.py +++ b/tests/test_flake8_comprehensions.py @@ -938,3 +938,41 @@ def test_C418_fail(code, failures, flake8_path): (flake8_path / "example.py").write_text(dedent(code)) result = flake8_path.run_flake8() assert result.out_lines == failures + + +@pytest.mark.parametrize( + "code", + [ + "any(num == 3 for num in range(5))", + "all(num == 3 for num in range(5))", + ], +) +def test_C419_pass(code, flake8_path): + (flake8_path / "example.py").write_text(dedent(code)) + result = flake8_path.run_flake8() + assert result.out_lines == [] + + +@pytest.mark.parametrize( + "code,failures", + [ + ( + "any([num == 3 for num in range(5)])", + [ + "./example.py:1:1: C419 Unnecessary list comprehension passed " + + "to any() prevents short-circuiting - rewrite as a generator." + ], + ), + ( + "all([num == 3 for num in range(5)])", + [ + "./example.py:1:1: C419 Unnecessary list comprehension passed " + + "to all() prevents short-circuiting - rewrite as a generator." + ], + ), + ], +) +def test_C419_fail(code, failures, flake8_path): + (flake8_path / "example.py").write_text(dedent(code)) + result = flake8_path.run_flake8() + assert result.out_lines == failures