Skip to content

Commit

Permalink
feat(pip_repository): Enable PyPi dep cycles
Browse files Browse the repository at this point in the history
This patch adjusts the pip_repository interface to accept a new
parameter: `composite_libs`, a dictionary mapping a composite name to the
list of PyPI package names which form a cycle and must be installed together.

The intuition behind this design is that a dependency cycle {a <-> b}
is satisfied simply by installing both a and b at once. Hence a
dependency graph {c -> a, c -> b} has the same effect.

If we modify the installation of a and b to remove their mutual
dependency, and generate a c which dominates a and b, we can then modify
the `requirement()` and `whl_requirement()` helper functions to
recognize the requirements a and b and provide a reference to c instead.
  • Loading branch information
arrdem committed Apr 12, 2023
1 parent c72c7bc commit a03018d
Show file tree
Hide file tree
Showing 9 changed files with 153 additions and 39 deletions.
93 changes: 74 additions & 19 deletions python/pip_install/pip_repository.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,6 @@ def _create_repository_execution_environment(rctx):

return env

_BUILD_FILE_CONTENTS = """\
package(default_visibility = ["//visibility:public"])
# Ensure the `requirements.bzl` source can be accessed by stardoc, since users load() from it
exports_files(["requirements.bzl"])
"""

def locked_requirements_label(ctx, attr):
"""Get the preferred label for a locked requirements file based on platform.
Expand Down Expand Up @@ -360,15 +353,16 @@ def _pip_repository_bzlmod_impl(rctx):

repo_name = rctx.attr.name.split("~")[-1]

build_contents = _BUILD_FILE_CONTENTS

if rctx.attr.incompatible_generate_aliases:
_pkg_aliases(rctx, repo_name, bzl_packages)
build_footer = ""
else:
build_contents += _bzlmod_pkg_aliases(repo_name, bzl_packages)
build_footer = _bzlmod_pkg_aliases(repo_name, bzl_packages)

rctx.file("BUILD.bazel", build_contents)
rctx.template("requirements.bzl", rctx.attr._template, substitutions = {
# BUILD.bazel is rendered from a template label, so `rctx.template` (which
# accepts `substitutions`) must be used here; `rctx.file` only writes literal
# content and has no `substitutions` parameter.
# NOTE(review): the BUILD template loads ":lib.bzl" -- confirm this rule also
# generates that file, otherwise loading the rendered BUILD.bazel will fail.
rctx.template("BUILD.bazel", rctx.attr._build_template, substitutions = {
    "%%FOOTER%%": build_footer,
})
rctx.template("requirements.bzl", rctx.attr._requirements_template, substitutions = {
"%%ALL_REQUIREMENTS%%": _format_repr_list([
"@{}//{}".format(repo_name, p) if rctx.attr.incompatible_generate_aliases else "@{}_{}//:pkg".format(rctx.attr.name, p)
for p in bzl_packages
Expand Down Expand Up @@ -406,9 +400,12 @@ wheels are fetched/built only for the targets specified by 'build/run/test'.
allow_single_file = True,
doc = "Override the requirements_lock attribute when the host platform is Windows",
),
"_template": attr.label(
"_requirements_template": attr.label(
default = ":pip_repository_requirements_bzlmod.bzl.tmpl",
),
"_build_template": attr.label(
default = ":pip_repository_build.bazel.tmpl",
),
}

pip_repository_bzlmod = repository_rule(
Expand All @@ -422,11 +419,46 @@ def _pip_repository_impl(rctx):
content = rctx.read(requirements_txt)
parsed_requirements_txt = parse_requirements(content)

packages = [(_clean_pkg_name(name), requirement) for name, requirement in parsed_requirements_txt.requirements]
# Apply name normalizations to the composite libs def once
composite_libs = {
_clean_pkg_name(name): [_clean_pkg_name(it) for it in components]
for name, components in rctx.attr.composite_libs.items()
}

bzl_packages = sorted([name for name, _ in packages])
# Ditto for requirements defs
requirements = {
_clean_pkg_name(name): requirement
for name, requirement in parsed_requirements_txt.requirements
}

# Map normalized package names to a composite
composite_mapping = {
name: composite_name
for composite_name, names in composite_libs.items()
for name in names
}

# Normal packages are defined by a single requirement.
# We will deal with composites shortly.
normal_packages = [
(name, requirement)
for name, requirement in requirements.items()
if name not in composite_mapping
]

# Composite packages are a cluster which can only be depended on together
# Iterate the normalized `composite_libs` rather than the raw rule attribute:
# `requirements` is keyed by normalized names, and the raw component names
# have not been passed through `_clean_pkg_name`, so indexing with them can
# miss. The keys of `composite_libs` are already normalized as well.
composite_packages = {
    composite_name: [
        (rname, requirements[rname])
        for rname in composite_components
    ]
    for composite_name, composite_components in composite_libs.items()
}

bzl_packages = sorted([name for name, _ in requirements.items()])

imports = [
'load("@rules_python//python:defs.bzl", "py_library")',
'load("@rules_python//python/pip_install:pip_repository.bzl", "whl_library")',
]

Expand Down Expand Up @@ -463,8 +495,14 @@ def _pip_repository_impl(rctx):
if rctx.attr.incompatible_generate_aliases:
_pkg_aliases(rctx, rctx.attr.name, bzl_packages)

rctx.file("BUILD.bazel", _BUILD_FILE_CONTENTS)
rctx.template("requirements.bzl", rctx.attr._template, substitutions = {
rctx.template("lib.bzl", rctx.attr._lib_template, substitutions = {
"%%NAME%%": rctx.attr.name,
})
rctx.template("BUILD.bazel", rctx.attr._build_template, substitutions = {
"%%NAME%%": rctx.attr.name,
"%%FOOTER%%": "",
})
rctx.template("requirements.bzl", rctx.attr._requirements_template, substitutions = {
"%%ALL_REQUIREMENTS%%": _format_repr_list([
"@{}//{}".format(rctx.attr.name, p) if rctx.attr.incompatible_generate_aliases else "@{}_{}//:pkg".format(rctx.attr.name, p)
for p in bzl_packages
Expand All @@ -475,13 +513,15 @@ def _pip_repository_impl(rctx):
]),
"%%ANNOTATIONS%%": _format_dict(_repr_dict(annotations)),
"%%CONFIG%%": _format_dict(_repr_dict(config)),
"%%CLUSTERS%%": _format_dict(_repr_dict(composite_packages)),
"%%CLUSTER_MAPPINGS%%": _format_dict(_repr_dict(composite_mapping)),
"%%EXTRA_PIP_ARGS%%": json.encode(options),
"%%IMPORTS%%": "\n".join(sorted(imports)),
"%%NAME%%": rctx.attr.name,
"%%PACKAGES%%": _format_repr_list(
[
("{}_{}".format(rctx.attr.name, p), r)
for p, r in packages
for p, r in normal_packages
],
),
"%%REQUIREMENTS_LOCK%%": str(requirements_txt),
Expand Down Expand Up @@ -602,9 +642,18 @@ wheels are fetched/built only for the targets specified by 'build/run/test'.
allow_single_file = True,
doc = "Override the requirements_lock attribute when the host platform is Windows",
),
"_template": attr.label(
"composite_libs": attr.string_list_dict(
doc = "Groups of requirements which represent dependency cycles and must be treated as composites.",
),
"_requirements_template": attr.label(
default = ":pip_repository_requirements.bzl.tmpl",
),
"_build_template": attr.label(
default = ":pip_repository_build.bazel.tmpl",
),
"_lib_template": attr.label(
default = ":pip_repository_lib.bzl.tmpl",
),
}

pip_repository_attrs.update(**common_attrs)
Expand Down Expand Up @@ -673,6 +722,8 @@ def _whl_library_impl(rctx):
"--annotation",
rctx.path(rctx.attr.annotation),
])
for d in rctx.attr.skip_deps:
args.extend(["--skip", d])

args = _parse_optional_attrs(rctx, args)

Expand Down Expand Up @@ -705,6 +756,10 @@ whl_library_attrs = {
mandatory = True,
doc = "Python requirement string describing the package to make available",
),
"skip_deps": attr.string_list(
doc = "List of requirements to skip due to clustering",
default = [],
),
}

whl_library_attrs.update(**common_attrs)
Expand Down
10 changes: 10 additions & 0 deletions python/pip_install/pip_repository_build.bazel.tmpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# `install_clusters` comes from the sibling `lib.bzl` file in this repository.
load(":lib.bzl", "install_clusters")

package(default_visibility = ["//visibility:public"])

# Ensure the `requirements.bzl` source can be accessed by stardoc, since users load() from it
exports_files(["requirements.bzl"])

# Declare one `py_library` and one `whl_`-prefixed `filegroup` per requirement cluster.
install_clusters()

# Placeholder substituted by the repository rule with extra BUILD content (may be empty).
%%FOOTER%%
14 changes: 14 additions & 0 deletions python/pip_install/pip_repository_lib.bzl.tmpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
load("@rules_python//python:defs.bzl", "py_library")
load(":requirements.bzl", "requirement", "whl_requirement", "requirement_clusters")


def install_clusters():
    """Declare the aggregate targets for every requirement cluster.

    For each entry of `requirement_clusters` this emits a `py_library` named
    after the cluster and a `filegroup` named `whl_<cluster>`. Both depend on
    the individual (non-cluster) targets of the cluster's member packages.
    """
    for cluster_name, members in requirement_clusters.items():
        # Each member is a (package name, requirement spec) pair; only the
        # package name is needed to build the labels.
        member_names = [pkg for pkg, _spec in members]

        # `use_clusters=False` makes the helpers return the per-package
        # labels instead of pointing back at the cluster target itself.
        py_library(
            name = cluster_name,
            deps = [requirement(pkg, use_clusters=False) for pkg in member_names],
        )
        native.filegroup(
            name = "whl_" + cluster_name,
            data = [whl_requirement(pkg, use_clusters=False) for pkg in member_names],
        )
42 changes: 34 additions & 8 deletions python/pip_install/pip_repository_requirements.bzl.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,40 @@ all_requirements = %%ALL_REQUIREMENTS%%
all_whl_requirements = %%ALL_WHL_REQUIREMENTS%%

_packages = %%PACKAGES%%
_cluster_mappings = %%CLUSTER_MAPPINGS%%
requirement_clusters = %%CLUSTERS%%
_config = %%CONFIG%%
_annotations = %%ANNOTATIONS%%

def _clean_name(name):
    """Return `name` lower-cased with every "-" and "." replaced by "_"."""
    cleaned = name
    for separator in ("-", "."):
        cleaned = cleaned.replace(separator, "_")
    return cleaned.lower()

def requirement(name):
return "@%%NAME%%_" + _clean_name(name) + "//:pkg"
def requirement(name, use_clusters=True):
    """Return the library label to depend on for the package `name`.

    Args:
        name: package name; it is normalized with `_clean_name` first.
        use_clusters: when true and the package belongs to a cluster, the
            cluster's target in the hub repository is returned instead of
            the package's own `//:pkg` target.

    Returns:
        A label string.
    """
    pkg = _clean_name(name)
    if use_clusters and pkg in _cluster_mappings:
        return "@%%NAME%%//:" + _cluster_mappings[pkg]
    return "@%%NAME%%_" + pkg + "//:pkg"

def whl_requirement(name):
return "@%%NAME%%_" + _clean_name(name) + "//:whl"
def whl_requirement(name, use_clusters=True):
    """Return the wheel label for the package `name`.

    Args:
        name: package name; it is normalized with `_clean_name` first.
        use_clusters: when true and the package belongs to a cluster, the
            cluster's `whl_<cluster>` target is returned instead of the
            package's own `//:whl` target.

    Returns:
        A label string.
    """
    pkg = _clean_name(name)
    in_cluster = use_clusters and pkg in _cluster_mappings
    if not in_cluster:
        return "@%%NAME%%_" + pkg + "//:whl"
    return "@%%NAME%%//:whl_" + _cluster_mappings[pkg]

def data_requirement(name):
    """Return the `//:data` label in the per-package repository for `name`.

    The name is normalized with `_clean_name`. This helper does not consult
    the cluster mappings; it always points at the individual package.
    """
    cname = _clean_name(name)

    # Reuse the normalized name instead of normalizing a second time.
    return "@%%NAME%%_" + cname + "//:data"

def dist_info_requirement(name):
    """Return the `//:dist_info` label in the per-package repository for `name`.

    The name is normalized with `_clean_name`. This helper does not consult
    the cluster mappings; it always points at the individual package.
    """
    cname = _clean_name(name)

    # Reuse the normalized name instead of normalizing a second time.
    return "@%%NAME%%_" + cname + "//:dist_info"

def entry_point(pkg, script = None):
    """Return the label of a wheel entry-point target for `pkg`.

    Args:
        pkg: package name; it is normalized with `_clean_name` to locate the
            per-package repository.
        script: entry-point script name. When empty or None, the original
            (un-normalized) `pkg` value is used as the script name.

    Returns:
        A label string ending in `rules_python_wheel_entry_point_<script>`.
    """
    cname = _clean_name(pkg)
    if not script:
        script = pkg

    # Single return built from the already-normalized name; a second,
    # unreachable return that duplicated this expression has been dropped.
    return "@%%NAME%%_" + cname + "//:rules_python_wheel_entry_point_" + script

def _get_annotation(requirement):
# This expects to parse `setuptools==58.2.0 --hash=sha256:2551203ae6955b9876741a26ab3e767bb3242dafe86a32a749ea0d78b6792f11`
Expand All @@ -43,10 +55,24 @@ def _get_annotation(requirement):
def install_deps(**whl_library_kwargs):
whl_config = dict(_config)
whl_config.update(whl_library_kwargs)
for name, requirement in _packages:
# Install normal requirements
for name, spec in _packages:
whl_library(
name = name,
requirement = requirement,
annotation = _get_annotation(requirement),
requirement = spec,
annotation = _get_annotation(spec),
**whl_config
)
# And deal with requirement_clusters
for cname, components in requirement_clusters.items():
# Generate the component libraries
cnames = [c[0] for c in components]
for rname, spec in components:
name = "%%NAME%%_" + rname
whl_library(
name = name,
requirement = spec,
annotation = _get_annotation(spec),
skip_deps = cnames,
**whl_config,
)
1 change: 0 additions & 1 deletion python/pip_install/tools/dependency_resolver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

1 change: 0 additions & 1 deletion python/pip_install/tools/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

1 change: 0 additions & 1 deletion python/pip_install/tools/lib/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class Annotation(OrderedDict):
"""A python representation of `@rules_python//python:pip.bzl%package_annotation`"""

def __init__(self, content: Dict[str, Any]) -> None:

missing = []
ordered_content = OrderedDict()
for field in (
Expand Down
1 change: 0 additions & 1 deletion python/pip_install/tools/lib/annotations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@


class AnnotationsTestCase(unittest.TestCase):

maxDiff = None

def test_annotations_constructor(self) -> None:
Expand Down
29 changes: 21 additions & 8 deletions python/pip_install/tools/wheel_installer/wheel_installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def _generate_copy_commands(src, dest, is_executable=False) -> str:

def _generate_build_file_contents(
name: str,
repo_prefix: str,
dependencies: List[str],
whl_file_deps: List[str],
data_exclude: List[str],
Expand Down Expand Up @@ -241,6 +242,7 @@ def _generate_build_file_contents(
"""\
load("@rules_python//python:defs.bzl", "py_library", "py_binary")
load("@bazel_skylib//rules:copy_file.bzl", "copy_file")
load("@{repo_prefix}//:requirements.bzl", "requirement", "whl_requirement")
package(default_visibility = ["//visibility:public"])
Expand Down Expand Up @@ -272,6 +274,7 @@ def _generate_build_file_contents(
)
""".format(
name=name,
repo_prefix=repo_prefix.rstrip("_"),
dependencies=",".join(sorted(dependencies)),
data_exclude=json.dumps(sorted(data_exclude)),
whl_file_label=bazel.WHEEL_FILE_LABEL,
Expand All @@ -297,6 +300,7 @@ def _extract_wheel(
repo_prefix: str,
installation_dir: Path = Path("."),
annotation: Optional[annotation.Annotation] = None,
skip_deps: List[str] = [],
) -> None:
"""Extracts wheel into given directory and creates py_library and filegroup targets.
Expand All @@ -318,15 +322,21 @@ def _extract_wheel(
extras_requested = extras[whl.name] if whl.name in extras else set()
# Packages may create dependency cycles when specifying optional-dependencies / 'extras'.
# Example: github.com/google/etils/blob/a0b71032095db14acf6b33516bca6d885fe09e35/pyproject.toml#L32.
self_edge_dep = set([whl.name])
whl_deps = sorted(whl.dependencies(extras_requested) - self_edge_dep)
to_skip = {bazel.sanitise_name(it, "") for it in [whl.name] + skip_deps}
deps = {bazel.sanitise_name(it, "") for it in whl.dependencies(extras_requested)}
whl_deps = sorted(deps - to_skip)
print(
"While building %s\n\tDeps: %r\n\tSkipping: %r\n\tEffective: %r"
% (
whl.name,
deps,
to_skip,
whl_deps,
)
)

sanitised_dependencies = [
bazel.sanitised_repo_library_label(d, repo_prefix=repo_prefix) for d in whl_deps
]
sanitised_wheel_file_dependencies = [
bazel.sanitised_repo_file_label(d, repo_prefix=repo_prefix) for d in whl_deps
]
sanitised_dependencies = ["requirement(%r)" % d for d in whl_deps]
sanitised_wheel_file_dependencies = ["whl_requirement(%r)" % d for d in whl_deps]

entry_points = []
for name, (module, attribute) in sorted(whl.entry_points().items()):
Expand Down Expand Up @@ -370,6 +380,7 @@ def _extract_wheel(

contents = _generate_build_file_contents(
name=bazel.PY_LIBRARY_LABEL,
repo_prefix=repo_prefix,
dependencies=sanitised_dependencies,
whl_file_deps=sanitised_wheel_file_dependencies,
data_exclude=data_exclude,
Expand All @@ -396,6 +407,7 @@ def main() -> None:
type=annotation.annotation_from_str_path,
help="A json encoded file containing annotations for rendered packages.",
)
parser.add_argument("--skip", action="append", dest="skip_deps", default=[])
arguments.parse_common_args(parser)
args = parser.parse_args()
deserialized_args = dict(vars(args))
Expand Down Expand Up @@ -443,6 +455,7 @@ def main() -> None:
enable_implicit_namespace_pkgs=args.enable_implicit_namespace_pkgs,
repo_prefix=args.repo_prefix,
annotation=args.annotation,
skip_deps=args.skip_deps,
)


Expand Down

0 comments on commit a03018d

Please sign in to comment.