Skip to content

Commit

Permalink
trade entrypoints for importlib_metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
minrk committed Jul 2, 2024
1 parent 0bbefbc commit e7a69c7
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 25 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ repos:
- id: ruff
args:
- "--fix"
- "--unsafe-fixes"
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v4.0.0-alpha.8
Expand Down
15 changes: 10 additions & 5 deletions ipyparallel/cluster/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
import sys
from functools import partial

import entrypoints
try:
from importlib.metadata import entry_points
except ImportError:
from importlib_metadata import entry_points

import zmq
from IPython.core.profiledir import ProfileDir
from traitlets import Bool, CaselessStrEnum, Dict, Integer, List, default
Expand Down Expand Up @@ -339,13 +343,14 @@ def _classes_default(self):
launcher_classes = []
for kind in ('controller', 'engine'):
group_name = f'ipyparallel.{kind}_launchers'
group = entrypoints.get_group_named(group_name)
for key, value in group.items():
group = entry_points(group=group_name)
for entrypoint in group:
key = entrypoint.name
try:
cls = value.load()
cls = entrypoint.load()
except Exception as e:
self.log.error(
f"Failed to load entrypoint {group_name}: {key} = {value}\n{e}"
f"Failed to load entrypoint {group_name}: {key} = {entrypoint.value}\n{e}"
)
else:
launcher_classes.append(cls)
Expand Down
20 changes: 12 additions & 8 deletions ipyparallel/cluster/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
from tempfile import TemporaryDirectory
from textwrap import indent

import entrypoints
try:
from importlib.metadata import entry_points
except ImportError:
from importlib_metadata import entry_points

import psutil
from IPython.utils.path import ensure_dir_exists, get_home_dir
from IPython.utils.text import EvalFormatter
Expand Down Expand Up @@ -2533,25 +2537,25 @@ def find_launcher_class(name, kind):
group_name = 'ipyparallel.controller_launchers'
else:
raise ValueError(f"kind must be 'engine' or 'controller', not {kind!r}")
group = entrypoints.get_group_named(group_name)
group = entry_points(group=group_name)
# make it case-insensitive
registry = {key.lower(): value for key, value in group.items()}
registry = {entrypoint.name.lower(): entrypoint for entrypoint in group}
return registry[name.lower()].load()


@lru_cache
def abbreviate_launcher_class(cls):
"""Abbreviate a launcher class back to its entrypoint name"""
cls_key = f"{cls.__module__}.{cls.__name__}"
cls_key = f"{cls.__module__}:{cls.__name__}"
# allow entrypoint_name attribute in case the definition module
# is not the same as the 'import' module
if getattr(cls, 'entrypoint_name', None):
return getattr(cls, 'entrypoint_name')

for kind in ('controller', 'engine'):
group_name = f'ipyparallel.{kind}_launchers'
group = entrypoints.get_group_named(group_name)
for key, value in group.items():
if f"{value.module_name}.{value.object_name}" == cls_key:
return key.lower()
group = entry_points(group=group_name)
for entrypoint in group:
if entrypoint.value == cls_key:
return entrypoint.name.lower()
return cls_key
11 changes: 8 additions & 3 deletions ipyparallel/tests/test_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
import time
from subprocess import Popen

import entrypoints
try:
from importlib.metadata import entry_points
except ImportError:
from importlib_metadata import entry_points

import pytest
from traitlets.config import Config

Expand Down Expand Up @@ -156,9 +160,10 @@ def _wait_one(timeout):
@pytest.mark.parametrize("kind", ("controller", "engine"))
def test_entrypoints(kind):
group_name = f"ipyparallel.{kind}_launchers"
group = entrypoints.get_group_named(group_name)
group = entry_points(group=group_name)
assert len(group) > 2
for key, entrypoint in group.items():
for entrypoint in group:
key = entrypoint.name
# verify entrypoints are valid
cls = entrypoint.load()

Expand Down
18 changes: 10 additions & 8 deletions ipyparallel/traitlets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"""Custom ipyparallel trait types"""

import entrypoints
try:
from importlib.metadata import entry_points
except ImportError:
from importlib_metadata import entry_points

from traitlets import List, TraitError, Type


Expand All @@ -24,9 +28,7 @@ def help(self):
chunks = [self._original_help]
chunks.append("Currently installed: ")
for key, entry_point in self.load_entry_points().items():
chunks.append(
f" - {key}: {entry_point.module_name}.{entry_point.object_name}"
)
chunks.append(f" - {key}: {entry_point.value}")
return '\n'.join(chunks)

@help.setter
Expand All @@ -35,10 +37,10 @@ def help(self, value):

def load_entry_points(self):
"""Load my entry point group"""
# load the group
group = entrypoints.get_group_named(self.entry_point_group)
# make it case-insensitive
return {key.lower(): value for key, value in group.items()}
return {
entry_point.name.lower(): entry_point
for entry_point in entry_points(group=self.entry_point_group)
}

def validate(self, obj, value):
if isinstance(value, str):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ classifiers = [
urls = {Homepage = "https://ipython.org"}
requires-python = ">=3.8"
dependencies = [
"entrypoints",
"importlib_metadata>=3.6; python_version < '3.10'",
"decorator",
"pyzmq>=18",
"traitlets>=4.3",
Expand Down

0 comments on commit e7a69c7

Please sign in to comment.