diff --git a/piptools/scripts/compile.py b/piptools/scripts/compile.py index eac222698..b880b80eb 100755 --- a/piptools/scripts/compile.py +++ b/piptools/scripts/compile.py @@ -20,7 +20,13 @@ from ..repositories import LocalRequirementsRepository, PyPIRepository from ..repositories.base import BaseRepository from ..resolver import Resolver -from ..utils import UNSAFE_PACKAGES, dedup, is_pinned_requirement, key_from_ireq +from ..utils import ( + UNSAFE_PACKAGES, + dedup, + drop_extras, + is_pinned_requirement, + key_from_ireq, +) from ..writer import OutputWriter DEFAULT_REQUIREMENTS_FILE = "requirements.in" @@ -402,6 +408,8 @@ def cli( ) constraints = [req for req in constraints if req.match_markers(extras)] + for req in constraints: + drop_extras(req) log.debug("Using indexes:") with log.indentation(): diff --git a/piptools/utils.py b/piptools/utils.py index d58b5c067..2e92772a9 100644 --- a/piptools/utils.py +++ b/piptools/utils.py @@ -6,6 +6,7 @@ Dict, Iterable, Iterator, + List, Optional, Set, Tuple, @@ -210,6 +211,56 @@ def dedup(iterable: Iterable[_T]) -> Iterable[_T]: return iter(dict.fromkeys(iterable)) +def drop_extras(ireq: InstallRequirement) -> None: + """Remove "extra" markers (PEP-508) from requirement.""" + if ireq.markers is None: + return + ireq.markers._markers = _drop_extras(ireq.markers._markers) + if not ireq.markers._markers: + ireq.markers = None + + +def _drop_extras(markers: List[_T]) -> List[_T]: + # drop `extra` tokens + to_remove: List[int] = [] + for i, token in enumerate(markers): + # operator (and/or) + if isinstance(token, str): + continue + # sub-expression (inside braces) + if isinstance(token, list): + markers[i] = _drop_extras(token) # type: ignore + if not markers[i]: + to_remove.append(i) + continue + # test expression (like `extra == "dev"`) + assert isinstance(token, tuple) + if token[0].value == "extra": + to_remove.append(i) + for i in reversed(to_remove): + markers.pop(i) + + # drop duplicate bool operators (and/or) + to_remove = [] + for i, (token1, token2) in enumerate(zip(markers, markers[1:])): + if not isinstance(token1, str): + continue + if not isinstance(token2, str): + continue + if token1 == "and": + to_remove.append(i) + else: + to_remove.append(i + 1) + for i in reversed(to_remove): + markers.pop(i) + if markers and isinstance(markers[0], str): + markers.pop(0) + if markers and isinstance(markers[-1], str): + markers.pop(-1) + + return markers + + def get_hashes_from_ireq(ireq: InstallRequirement) -> Set[str]: """ Given an InstallRequirement, return a set of string hashes in the format diff --git a/tests/test_cli_compile.py b/tests/test_cli_compile.py index 206f2a9a9..f52281657 100644 --- a/tests/test_cli_compile.py +++ b/tests/test_cli_compile.py @@ -1767,6 +1767,7 @@ def test_input_formats(fake_dists, runner, make_module, fname, content): assert "small-fake-d" not in out.stderr assert "small-fake-e" not in out.stderr assert "small-fake-f" not in out.stderr + assert "extra ==" not in out.stderr @pytest.mark.network @@ -1786,6 +1787,7 @@ def test_one_extra(fake_dists, runner, make_module, fname, content): assert "small-fake-d==0.4" in out.stderr assert "small-fake-e" not in out.stderr assert "small-fake-f" not in out.stderr + assert "extra ==" not in out.stderr @pytest.mark.network @@ -1815,6 +1817,7 @@ def test_multiple_extras(fake_dists, runner, make_module, fname, content): assert "small-fake-d==0.4" in out.stderr assert "small-fake-e==0.5" in out.stderr assert "small-fake-f==0.6" in out.stderr + assert "extra ==" not in out.stderr def test_extras_fail_with_requirements_in(runner, tmpdir): diff --git a/tests/test_utils.py b/tests/test_utils.py index d6f245c3d..becfc878f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,6 +9,7 @@ from piptools.utils import ( as_tuple, dedup, + drop_extras, flat_map, format_requirement, format_specifier, @@ -367,3 +368,54 @@ def test_lookup_table_from_tuples_with_empty_values(): def test_lookup_table_with_empty_values(): assert lookup_table((), operator.itemgetter(0)) == {} + + +@pytest.mark.parametrize( + ("given", "expected"), + ( + ("", None), + ("extra == 'dev'", None), + ("extra == 'dev' or extra == 'test'", None), + ("os_name == 'nt' and extra == 'dev'", "os_name == 'nt'"), + ("extra == 'dev' and os_name == 'nt'", "os_name == 'nt'"), + ("os_name == 'nt' or extra == 'dev'", "os_name == 'nt'"), + ("extra == 'dev' or os_name == 'nt'", "os_name == 'nt'"), + ("(extra == 'dev') or os_name == 'nt'", "os_name == 'nt'"), + ("os_name == 'nt' and (extra == 'dev' or extra == 'test')", "os_name == 'nt'"), + ("os_name == 'nt' or (extra == 'dev' or extra == 'test')", "os_name == 'nt'"), + ("(extra == 'dev' or extra == 'test') or os_name == 'nt'", "os_name == 'nt'"), + ("(extra == 'dev' or extra == 'test') and os_name == 'nt'", "os_name == 'nt'"), + ( + "os_name == 'nt' or (os_name == 'unix' and extra == 'test')", + "os_name == 'nt' or os_name == 'unix'", + ), + ( + "(os_name == 'unix' and extra == 'test') or os_name == 'nt'", + "os_name == 'unix' or os_name == 'nt'", + ), + ( + "(os_name == 'unix' or extra == 'test') and os_name == 'nt'", + "os_name == 'unix' and os_name == 'nt'", + ), + ( + "(os_name == 'unix' and extra == 'test' or python_version < '3.5')" + " or os_name == 'nt'", + "(os_name == 'unix' or python_version < '3.5') or os_name == 'nt'", + ), + ( + "os_name == 'unix' and extra == 'test' or os_name == 'nt'", + "os_name == 'unix' or os_name == 'nt'", + ), + ( + "os_name == 'unix' or extra == 'test' and os_name == 'nt'", + "os_name == 'unix' or os_name == 'nt'", + ), + ), +) +def test_drop_extras(from_line, given, expected): + ireq = from_line(f"test;{given}") + drop_extras(ireq) + if expected is None: + assert ireq.markers is None + else: + assert str(ireq.markers).replace("'", '"') == expected.replace("'", '"')