Validate Dask Cluster Names #871

Merged (3 commits) on Mar 1, 2024
21 changes: 20 additions & 1 deletion dask_kubernetes/common/objects.py
@@ -1,14 +1,20 @@
"""
Convenience functions for creating pod templates.
"""

import copy
import json
from collections import namedtuple

from kubernetes import client
from kubernetes.client.configuration import Configuration

from dask_kubernetes.constants import KUBECLUSTER_CONTAINER_NAME
from dask_kubernetes.constants import (
    KUBECLUSTER_CONTAINER_NAME,
    MAX_CLUSTER_NAME_LEN,
    VALID_CLUSTER_NAME,
)
from dask_kubernetes.exceptions import ValidationError

_FakeResponse = namedtuple("_FakeResponse", ["data"])

@@ -365,3 +371,16 @@ def clean_pdb_template(pdb_template):
    pdb_template.spec.selector = client.V1LabelSelector()

    return pdb_template


def validate_cluster_name(cluster_name: str) -> None:
    """Raise an exception if the cluster name is too long or has invalid characters"""
    if not VALID_CLUSTER_NAME.match(cluster_name):
        raise ValidationError(
            message=(
                f"The DaskCluster {cluster_name} is invalid: a lowercase RFC 1123 label must "
                "consist of lowercase alphanumeric characters or '-', must start "
                "and end with an alphanumeric character, and must be no longer than "
                f"{MAX_CLUSTER_NAME_LEN} characters."
            )
        )
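
A quick usage sketch of the new helper (not part of the diff; the example names are made up). A conforming name passes silently, anything else raises the new ValidationError:

    from dask_kubernetes.common.objects import validate_cluster_name
    from dask_kubernetes.exceptions import ValidationError

    # A conforming name passes silently; the function returns None.
    validate_cluster_name("my-cluster")

    # Over-long names or names with characters outside [a-z0-9-] raise.
    try:
        validate_cluster_name("Not_A_Valid_Name")
    except ValidationError as exc:
        print(exc.message)
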
26 changes: 24 additions & 2 deletions dask_kubernetes/common/tests/test_objects.py
@@ -1,5 +1,8 @@
from dask_kubernetes.common.objects import make_pod_from_dict
from dask_kubernetes.constants import KUBECLUSTER_CONTAINER_NAME
import pytest

from dask_kubernetes.common.objects import make_pod_from_dict, validate_cluster_name
from dask_kubernetes.constants import KUBECLUSTER_CONTAINER_NAME, MAX_CLUSTER_NAME_LEN
from dask_kubernetes.exceptions import ValidationError


def test_make_pod_from_dict():
@@ -64,3 +67,22 @@ def test_make_pod_from_dict_default_container_name():
    assert pod.spec.containers[0].name == "dask-0"
    assert pod.spec.containers[1].name == "sidecar"
    assert pod.spec.containers[2].name == "dask-2"


@pytest.mark.parametrize(
    "cluster_name",
    [
        (MAX_CLUSTER_NAME_LEN + 1) * "a",
        "invalid.chars.in.name",
    ],
)
def test_validate_cluster_name_raises_on_invalid_name(cluster_name):
    with pytest.raises(ValidationError):
        validate_cluster_name(cluster_name)


def test_validate_cluster_name_success_on_valid_name():
    assert validate_cluster_name("valid-cluster-name-123") is None
10 changes: 10 additions & 0 deletions dask_kubernetes/constants.py
@@ -1 +1,11 @@
import re

KUBECLUSTER_CONTAINER_NAME = "dask-container"
KUBERNETES_MAX_RESOURCE_NAME_LENGTH = 63
SCHEDULER_NAME_TEMPLATE = "{cluster_name}-scheduler"
MAX_CLUSTER_NAME_LEN = KUBERNETES_MAX_RESOURCE_NAME_LENGTH - len(
    SCHEDULER_NAME_TEMPLATE.format(cluster_name="")
)
VALID_CLUSTER_NAME = re.compile(
    rf"^(?=.{{,{MAX_CLUSTER_NAME_LEN}}}$)[a-z0-9]([-a-z0-9]*[a-z0-9])?$"
)
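
To make the arithmetic concrete: Kubernetes caps resource names at 63 characters, and the scheduler Deployment/Service appends "-scheduler" to the cluster name, which leaves 53 characters for the cluster name itself. A small sketch (illustrative only; it re-derives the constants defined above with lowercase local names):

    import re

    scheduler_name_template = "{cluster_name}-scheduler"
    # 63 minus len("-scheduler") leaves 53 characters for the cluster name.
    max_cluster_name_len = 63 - len(scheduler_name_template.format(cluster_name=""))
    assert max_cluster_name_len == 53

    valid_cluster_name = re.compile(
        rf"^(?=.{{,{max_cluster_name_len}}}$)[a-z0-9]([-a-z0-9]*[a-z0-9])?$"
    )
    assert valid_cluster_name.match("valid-cluster-name-123")
    assert valid_cluster_name.match("a" * 54) is None  # longer than 53 characters
    assert valid_cluster_name.match("has.dots") is None  # '.' is not allowed
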
7 changes: 7 additions & 0 deletions dask_kubernetes/exceptions.py
@@ -4,3 +4,10 @@ class CrashLoopBackOffError(Exception):

class SchedulerStartupError(Exception):
    """Scheduler failed to start."""


class ValidationError(Exception):
    """Manifest validation exception."""

    def __init__(self, message: str) -> None:
        self.message = message
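
The controller below forwards exc.message into kopf.PermanentError. Since __init__ does not call super().__init__(message), the reason travels on the .message attribute rather than str(exc) when the exception is raised with message= as in this PR. A minimal sketch:

    from dask_kubernetes.exceptions import ValidationError

    try:
        raise ValidationError(message="example: cluster name is too long")
    except ValidationError as exc:
        print(exc.message)  # -> "example: cluster name is too long"
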
15 changes: 12 additions & 3 deletions dask_kubernetes/operator/controller/controller.py
@@ -14,6 +14,9 @@
from importlib_metadata import entry_points
from kr8s.asyncio.objects import Deployment, Pod, Service

from dask_kubernetes.common.objects import validate_cluster_name
from dask_kubernetes.constants import SCHEDULER_NAME_TEMPLATE
from dask_kubernetes.exceptions import ValidationError
from dask_kubernetes.operator._objects import (
DaskAutoscaler,
DaskCluster,
@@ -75,7 +78,7 @@ def build_scheduler_deployment_spec(
}
)
metadata = {
"name": f"{cluster_name}-scheduler",
"name": SCHEDULER_NAME_TEMPLATE.format(cluster_name=cluster_name),
"labels": labels,
"annotations": annotations,
}
@@ -107,7 +110,7 @@ def build_scheduler_service_spec(cluster_name, spec, annotations, labels):
"apiVersion": "v1",
"kind": "Service",
"metadata": {
"name": f"{cluster_name}-scheduler",
"name": SCHEDULER_NAME_TEMPLATE.format(cluster_name=cluster_name),
"labels": labels,
"annotations": annotations,
},
@@ -273,6 +276,12 @@ async def daskcluster_create(name, namespace, logger, patch, **kwargs):
    This allows us to track that the operator is running.
    """
    logger.info(f"DaskCluster {name} created in {namespace}.")
    try:
        validate_cluster_name(name)
    except ValidationError as validation_exc:
        patch.status["phase"] = "Error"
        raise kopf.PermanentError(validation_exc.message)

    patch.status["phase"] = "Created"


@@ -599,7 +608,7 @@ async def daskworkergroup_replica_update(
if workers_needed < 0:
worker_ids = await retire_workers(
n_workers=-workers_needed,
scheduler_service_name=f"{cluster_name}-scheduler",
scheduler_service_name=SCHEDULER_NAME_TEMPLATE.format(cluster_name=cluster_name),
worker_group_name=name,
namespace=namespace,
logger=logger,
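
One detail worth noting about the call sites above: SCHEDULER_NAME_TEMPLATE contains a named placeholder, so it has to be formatted with the cluster_name keyword; passing the value positionally raises a KeyError. A small sketch of the difference (the "demo" name is made up):

    SCHEDULER_NAME_TEMPLATE = "{cluster_name}-scheduler"

    # The named field must be filled by keyword.
    assert SCHEDULER_NAME_TEMPLATE.format(cluster_name="demo") == "demo-scheduler"

    # A positional argument cannot satisfy a named field.
    try:
        SCHEDULER_NAME_TEMPLATE.format("demo")
    except KeyError as exc:
        print(exc)  # 'cluster_name'
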
63 changes: 59 additions & 4 deletions dask_kubernetes/operator/controller/tests/test_controller.py
@@ -11,6 +11,7 @@
from dask.distributed import Client
from kr8s.asyncio.objects import Deployment, Pod, Service

from dask_kubernetes.constants import MAX_CLUSTER_NAME_LEN
from dask_kubernetes.operator._objects import DaskCluster, DaskJob, DaskWorkerGroup
from dask_kubernetes.operator.controller import (
KUBERNETES_DATETIME_FORMAT,
@@ -22,17 +23,32 @@

_EXPECTED_ANNOTATIONS = {"test-annotation": "annotation-value"}
_EXPECTED_LABELS = {"test-label": "label-value"}
DEFAULT_CLUSTER_NAME = "simple"


@pytest.fixture()
def gen_cluster(k8s_cluster, ns):
def gen_cluster_manifest(tmp_path):
def factory(cluster_name=DEFAULT_CLUSTER_NAME):
original_manifest_path = os.path.join(DIR, "resources", "simplecluster.yaml")
with open(original_manifest_path, "r") as original_manifest_file:
manifest = yaml.safe_load(original_manifest_file)

manifest["metadata"]["name"] = cluster_name
new_manifest_path = tmp_path / "cluster.yaml"
new_manifest_path.write_text(yaml.safe_dump(manifest))
return tmp_path

return factory


@pytest.fixture()
def gen_cluster(k8s_cluster, ns, gen_cluster_manifest):
"""Yields an instantiated context manager for creating/deleting a simple cluster."""

@asynccontextmanager
async def cm():
cluster_path = os.path.join(DIR, "resources", "simplecluster.yaml")
cluster_name = "simple"
async def cm(cluster_name=DEFAULT_CLUSTER_NAME):

cluster_path = gen_cluster_manifest(cluster_name)
# Create cluster resource
k8s_cluster.kubectl("apply", "-n", ns, "-f", cluster_path)
while cluster_name not in k8s_cluster.kubectl(
@@ -687,3 +703,42 @@ async def test_object_dask_job(k8s_cluster, kopf_runner, gen_job):

cluster = await job.cluster()
assert isinstance(cluster, DaskCluster)


async def _get_cluster_status(k8s_cluster, ns, cluster_name):
    """
    Loop until the cluster reports a non-empty (truthy) status phase.
    Make sure any test that calls this has a timeout.
    """
    while True:
        cluster_status = k8s_cluster.kubectl(
            "get",
            "-n",
            ns,
            "daskcluster.kubernetes.dask.org",
            cluster_name,
            "-o",
            "jsonpath='{.status.phase}'",
        ).strip("'")
        if cluster_status:
            return cluster_status
        await asyncio.sleep(0.1)


@pytest.mark.timeout(180)
@pytest.mark.anyio
@pytest.mark.parametrize(
    "cluster_name,expected_status",
    [
        ("valid-name", "Created"),
        ((MAX_CLUSTER_NAME_LEN + 1) * "a", "Error"),
        ("invalid.chars.in.name", "Error"),
    ],
)
async def test_create_cluster_validates_name(
    cluster_name, expected_status, k8s_cluster, kopf_runner, gen_cluster
):
    with kopf_runner:
        async with gen_cluster(cluster_name=cluster_name) as (_, ns):
            actual_status = await _get_cluster_status(k8s_cluster, ns, cluster_name)
            assert expected_status == actual_status
2 changes: 2 additions & 0 deletions dask_kubernetes/operator/kubecluster/kubecluster.py
@@ -29,6 +29,7 @@
from rich.table import Table
from tornado.ioloop import IOLoop

from dask_kubernetes.common.objects import validate_cluster_name
from dask_kubernetes.exceptions import CrashLoopBackOffError, SchedulerStartupError
from dask_kubernetes.operator._objects import (
DaskAutoscaler,
@@ -258,6 +259,7 @@ def __init__(
        name = name.format(
            user=getpass.getuser(), uuid=str(uuid.uuid4())[:10], **os.environ
        )
        validate_cluster_name(name)
        self._instances.add(self)
        self._rich_spinner = Spinner("dots", speed=0.5)
        self._startup_component_status: dict = {}
20 changes: 19 additions & 1 deletion dask_kubernetes/operator/kubecluster/tests/test_kubecluster.py
@@ -2,7 +2,8 @@
from dask.distributed import Client
from distributed.utils import TimeoutError

from dask_kubernetes.exceptions import SchedulerStartupError
from dask_kubernetes.constants import MAX_CLUSTER_NAME_LEN
from dask_kubernetes.exceptions import SchedulerStartupError, ValidationError
from dask_kubernetes.operator import KubeCluster, make_cluster_spec


@@ -202,3 +203,20 @@ def test_typo_resource_limits(ns):
},
namespace=ns,
)


@pytest.mark.parametrize(
    "cluster_name",
    [
        (MAX_CLUSTER_NAME_LEN + 1) * "a",
        "invalid.chars.in.name",
    ],
)
def test_invalid_cluster_name_fails(cluster_name, kopf_runner, docker_image, ns):
    with kopf_runner:
        with pytest.raises(ValidationError):
            KubeCluster(
                name=cluster_name,
                namespace=ns,
                image=docker_image,
            )