From 6048689ca61f318ba30ac3e0df8729918314035e Mon Sep 17 00:00:00 2001 From: bstadlbauer <11799671+bstadlbauer@users.noreply.github.com> Date: Mon, 18 Sep 2023 16:32:39 +0200 Subject: [PATCH] feat: Dask add pod template support (#374) * Add failing test Signed-off-by: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> * WIP Signed-off-by: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> * Improve test Signed-off-by: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> * Refactor to use `ToK8sPodSpec` Signed-off-by: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> * Fix linting Signed-off-by: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> * Use `Always` restart policy for workers Signed-off-by: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> * Add test which checks whether labels are propagated Signed-off-by: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> * Replace `removeInterruptibleConfig` with `TaskExectuionMetadata` wrapper Signed-off-by: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> --------- Signed-off-by: Bernhard Stadlbauer <11799671+bstadlbauer@users.noreply.github.com> --- go.mod | 1 + go.sum | 4 +- .../pluginmachinery/flytek8s/pod_helper.go | 2 +- go/tasks/plugins/k8s/dask/dask.go | 280 +++++++++--------- go/tasks/plugins/k8s/dask/dask_test.go | 105 +++++-- 5 files changed, 220 insertions(+), 172 deletions(-) diff --git a/go.mod b/go.mod index aaefc2ca3..af4811c27 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/ray-project/kuberay/ray-operator v0.0.0-20220728052838-eaa75fa6707c github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.1 + golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 golang.org/x/net v0.8.0 golang.org/x/oauth2 v0.0.0-20220411215720-9780585627b5 google.golang.org/api v0.76.0 diff --git a/go.sum b/go.sum index b83d5c84a..fe7a20883 100644 --- a/go.sum +++ b/go.sum @@ -735,6 +735,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 h1:MGwJjxBy0HJshjDNfLsYO8xppfqWlA5ZT9OhtUUhTNw= +golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -760,7 +762,7 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= +golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= golang.org/x/net v0.0.0-20170114055629-f2499483f923/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 661527c1d..c1db4f38b 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -61,7 +61,7 @@ func ApplyInterruptibleNodeSelectorRequirement(interruptible bool, affinity *v1. nst.MatchExpressions = append(nst.MatchExpressions, nodeSelectorRequirement) } } else { - affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms = []v1.NodeSelectorTerm{v1.NodeSelectorTerm{MatchExpressions: []v1.NodeSelectorRequirement{nodeSelectorRequirement}}} + affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms = []v1.NodeSelectorTerm{{MatchExpressions: []v1.NodeSelectorRequirement{nodeSelectorRequirement}}} } } diff --git a/go/tasks/plugins/k8s/dask/dask.go b/go/tasks/plugins/k8s/dask/dask.go index f12423a64..cb0d9ec93 100755 --- a/go/tasks/plugins/k8s/dask/dask.go +++ b/go/tasks/plugins/k8s/dask/dask.go @@ -6,15 +6,12 @@ import ( "time" daskAPI "github.com/dask/dask-kubernetes/v2023/dask_kubernetes/operator/go_client/pkg/apis/kubernetes.dask.org/v1" - "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/logs" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" - "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" @@ -30,68 +27,50 @@ const ( KindDaskJob = "DaskJob" ) -type defaults struct { - Image string - JobRunnerContainer v1.Container - Resources *v1.ResourceRequirements - Env []v1.EnvVar - Annotations map[string]string - IsInterruptible bool +// Wraps a regular TaskExecutionMetadata and overrides the IsInterruptible method to always return false +// This is useful as the runner and the scheduler pods should never be interruptable +type nonInterruptibleTaskExecutionMetadata struct { + pluginsCore.TaskExecutionMetadata } -func getDefaults(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, taskTemplate core.TaskTemplate) (*defaults, error) { - executionMetadata := taskCtx.TaskExecutionMetadata() - - defaultContainerSpec := taskTemplate.GetContainer() - if defaultContainerSpec == nil { - return nil, errors.Errorf(errors.BadTaskSpecification, "task is missing a default container") - } - - defaultImage := defaultContainerSpec.GetImage() - if defaultImage == "" { - return nil, errors.Errorf(errors.BadTaskSpecification, "task is missing a default image") - } +func (n nonInterruptibleTaskExecutionMetadata) IsInterruptible() bool { + return false +} - var defaultEnvVars []v1.EnvVar - if taskTemplate.GetContainer().GetEnv() != nil { - for _, keyValuePair := range taskTemplate.GetContainer().GetEnv() { - defaultEnvVars = append(defaultEnvVars, v1.EnvVar{Name: keyValuePair.Key, Value: keyValuePair.Value}) - } - } +// A wrapper around a regular TaskExecutionContext allowing to inject a custom TaskExecutionMetadata which is +// non-interruptible +type nonInterruptibleTaskExecutionContext struct { + pluginsCore.TaskExecutionContext + metadata nonInterruptibleTaskExecutionMetadata +} - containerResources, err := flytek8s.ToK8sResourceRequirements(defaultContainerSpec.GetResources()) - if err != nil { - return nil, err - } +func (n nonInterruptibleTaskExecutionContext) TaskExecutionMetadata() pluginsCore.TaskExecutionMetadata { + return n.metadata +} - jobRunnerContainer := v1.Container{ - Name: "job-runner", - Image: defaultImage, - Args: defaultContainerSpec.GetArgs(), - Env: defaultEnvVars, - Resources: *containerResources, +func mergeMapInto(src map[string]string, dst map[string]string) { + for key, value := range src { + dst[key] = value } +} - templateParameters := template.Parameters{ - TaskExecMetadata: taskCtx.TaskExecutionMetadata(), - Inputs: taskCtx.InputReader(), - OutputPath: taskCtx.OutputWriter(), - Task: taskCtx.TaskReader(), +func getPrimaryContainer(spec *v1.PodSpec, primaryContainerName string) (*v1.Container, error) { + for _, container := range spec.Containers { + if container.Name == primaryContainerName { + return &container, nil + } } - if err = flytek8s.AddFlyteCustomizationsToContainer(ctx, templateParameters, - flytek8s.ResourceCustomizationModeMergeExistingResources, &jobRunnerContainer); err != nil { + return nil, errors.Errorf(errors.BadTaskSpecification, "primary container [%v] not found in pod spec", primaryContainerName) +} - return nil, err +func replacePrimaryContainer(spec *v1.PodSpec, primaryContainerName string, container v1.Container) error { + for i, c := range spec.Containers { + if c.Name == primaryContainerName { + spec.Containers[i] = container + return nil + } } - - return &defaults{ - Image: defaultImage, - JobRunnerContainer: jobRunnerContainer, - Resources: &jobRunnerContainer.Resources, - Env: defaultEnvVars, - Annotations: executionMetadata.GetAnnotations(), - IsInterruptible: executionMetadata.IsInterruptible(), - }, nil + return errors.Errorf(errors.BadTaskSpecification, "primary container [%v] not found in pod spec", primaryContainerName) } type daskResourceHandler struct { @@ -114,11 +93,6 @@ func (p daskResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC } else if taskTemplate == nil { return nil, errors.Errorf(errors.BadTaskSpecification, "nil task specification") } - defaults, err := getDefaults(ctx, taskCtx, *taskTemplate) - if err != nil { - return nil, err - } - clusterName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() daskJob := plugins.DaskJob{} err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &daskJob) @@ -126,38 +100,63 @@ func (p daskResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC return nil, errors.Wrapf(errors.BadTaskSpecification, err, "invalid TaskSpecification [%v], failed to unmarshal", taskTemplate.GetCustom()) } - workerSpec, err := createWorkerSpec(*daskJob.Workers, *defaults) + podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) if err != nil { return nil, err } - schedulerSpec, err := createSchedulerSpec(*daskJob.Scheduler, clusterName, *defaults) + nonInterruptibleTaskMetadata := nonInterruptibleTaskExecutionMetadata{taskCtx.TaskExecutionMetadata()} + nonInterruptibleTaskCtx := nonInterruptibleTaskExecutionContext{taskCtx, nonInterruptibleTaskMetadata} + nonInterruptiblePodSpec, _, _, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx) + if err != nil { + return nil, err + } + + // Add labels and annotations to objectMeta as they're not added by ToK8sPodSpec + mergeMapInto(taskCtx.TaskExecutionMetadata().GetAnnotations(), objectMeta.Annotations) + mergeMapInto(taskCtx.TaskExecutionMetadata().GetLabels(), objectMeta.Labels) + + workerSpec, err := createWorkerSpec(*daskJob.Workers, podSpec, primaryContainerName) + if err != nil { + return nil, err + } + + clusterName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + schedulerSpec, err := createSchedulerSpec(*daskJob.Scheduler, clusterName, nonInterruptiblePodSpec, primaryContainerName) + if err != nil { + return nil, err + } + + jobSpec, err := createJobSpec(*workerSpec, *schedulerSpec, nonInterruptiblePodSpec, primaryContainerName, objectMeta) if err != nil { return nil, err } - jobSpec := createJobSpec(*workerSpec, *schedulerSpec, *defaults) job := &daskAPI.DaskJob{ TypeMeta: metav1.TypeMeta{ Kind: KindDaskJob, APIVersion: daskAPI.SchemeGroupVersion.String(), }, - ObjectMeta: metav1.ObjectMeta{ - Name: "will-be-overridden", // Will be overridden by Flyte to `clusterName` - Annotations: defaults.Annotations, - }, - Spec: *jobSpec, + ObjectMeta: *objectMeta, + Spec: *jobSpec, } return job, nil } -func createWorkerSpec(cluster plugins.DaskWorkerGroup, defaults defaults) (*daskAPI.WorkerSpec, error) { - image := defaults.Image +func createWorkerSpec(cluster plugins.DaskWorkerGroup, podSpec *v1.PodSpec, primaryContainerName string) (*daskAPI.WorkerSpec, error) { + workerPodSpec := podSpec.DeepCopy() + primaryContainer, err := getPrimaryContainer(workerPodSpec, primaryContainerName) + if err != nil { + return nil, err + } + primaryContainer.Name = "dask-worker" + + // Set custom image if present if cluster.GetImage() != "" { - image = cluster.GetImage() + primaryContainer.Image = cluster.GetImage() } - var err error - resources := defaults.Resources + // Set custom resources + resources := &primaryContainer.Resources clusterResources := cluster.GetResources() if len(clusterResources.Requests) >= 1 || len(clusterResources.Limits) >= 1 { resources, err = flytek8s.ToK8sResourceRequirements(cluster.GetResources()) @@ -168,16 +167,17 @@ func createWorkerSpec(cluster plugins.DaskWorkerGroup, defaults defaults) (*dask if resources == nil { resources = &v1.ResourceRequirements{} } + primaryContainer.Resources = *resources + // Set custom args workerArgs := []string{ "dask-worker", "--name", "$(DASK_WORKER_NAME)", } - // If limits are set, append `--nthreads` and `--memory-limit` as per these docs: // https://kubernetes.dask.org/en/latest/kubecluster.html?#best-practices - if resources != nil && resources.Limits != nil { + if resources.Limits != nil { limits := resources.Limits if limits.Cpu() != nil { cpuCount := fmt.Sprintf("%v", limits.Cpu().Value()) @@ -188,78 +188,73 @@ func createWorkerSpec(cluster plugins.DaskWorkerGroup, defaults defaults) (*dask workerArgs = append(workerArgs, "--memory-limit", memory) } } + primaryContainer.Args = workerArgs - wokerSpec := v1.PodSpec{ - Affinity: &v1.Affinity{}, - Containers: []v1.Container{ - { - Name: "dask-worker", - Image: image, - ImagePullPolicy: v1.PullIfNotPresent, - Args: workerArgs, - Resources: *resources, - Env: defaults.Env, - }, - }, + err = replacePrimaryContainer(workerPodSpec, primaryContainerName, *primaryContainer) + if err != nil { + return nil, err } - if defaults.IsInterruptible { - wokerSpec.Tolerations = append(wokerSpec.Tolerations, config.GetK8sPluginConfig().InterruptibleTolerations...) - wokerSpec.NodeSelector = config.GetK8sPluginConfig().InterruptibleNodeSelector - } - flytek8s.ApplyInterruptibleNodeSelectorRequirement(defaults.IsInterruptible, wokerSpec.Affinity) + // All workers are created as k8s deployment and must have a restart policy of Always + workerPodSpec.RestartPolicy = v1.RestartPolicyAlways return &daskAPI.WorkerSpec{ Replicas: int(cluster.GetNumberOfWorkers()), - Spec: wokerSpec, + Spec: *workerPodSpec, }, nil } -func createSchedulerSpec(cluster plugins.DaskScheduler, clusterName string, defaults defaults) (*daskAPI.SchedulerSpec, error) { - schedulerImage := defaults.Image - if cluster.GetImage() != "" { - schedulerImage = cluster.GetImage() +func createSchedulerSpec(scheduler plugins.DaskScheduler, clusterName string, podSpec *v1.PodSpec, primaryContainerName string) (*daskAPI.SchedulerSpec, error) { + schedulerPodSpec := podSpec.DeepCopy() + primaryContainer, err := getPrimaryContainer(schedulerPodSpec, primaryContainerName) + if err != nil { + return nil, err } + primaryContainer.Name = "scheduler" - var err error - resources := defaults.Resources + // Override image if applicable + if scheduler.GetImage() != "" { + primaryContainer.Image = scheduler.GetImage() + } - clusterResources := cluster.GetResources() - if len(clusterResources.Requests) >= 1 || len(clusterResources.Limits) >= 1 { - resources, err = flytek8s.ToK8sResourceRequirements(cluster.GetResources()) + // Override resources if applicable + resources := &primaryContainer.Resources + schedulerResources := scheduler.GetResources() + if len(schedulerResources.Requests) >= 1 || len(schedulerResources.Limits) >= 1 { + resources, err = flytek8s.ToK8sResourceRequirements(scheduler.GetResources()) if err != nil { return nil, err } } - if resources == nil { - resources = &v1.ResourceRequirements{} + primaryContainer.Resources = *resources + + // Override args + primaryContainer.Args = []string{"dask-scheduler"} + + // Add ports + primaryContainer.Ports = []v1.ContainerPort{ + { + Name: "tcp-comm", + ContainerPort: 8786, + Protocol: "TCP", + }, + { + Name: "dashboard", + ContainerPort: 8787, + Protocol: "TCP", + }, + } + + schedulerPodSpec.RestartPolicy = v1.RestartPolicyAlways + + // Set primary container + err = replacePrimaryContainer(schedulerPodSpec, primaryContainerName, *primaryContainer) + if err != nil { + return nil, err } return &daskAPI.SchedulerSpec{ - Spec: v1.PodSpec{ - RestartPolicy: v1.RestartPolicyAlways, - Containers: []v1.Container{ - { - Name: "scheduler", - Image: schedulerImage, - Args: []string{"dask-scheduler"}, - Resources: *resources, - Env: defaults.Env, - Ports: []v1.ContainerPort{ - { - Name: "tcp-comm", - ContainerPort: 8786, - Protocol: "TCP", - }, - { - Name: "dashboard", - ContainerPort: 8787, - Protocol: "TCP", - }, - }, - }, - }, - }, + Spec: *schedulerPodSpec, Service: v1.ServiceSpec{ Type: v1.ServiceTypeNodePort, Selector: map[string]string{ @@ -284,26 +279,33 @@ func createSchedulerSpec(cluster plugins.DaskScheduler, clusterName string, defa }, nil } -func createJobSpec(workerSpec daskAPI.WorkerSpec, schedulerSpec daskAPI.SchedulerSpec, defaults defaults) *daskAPI.DaskJobSpec { +func createJobSpec(workerSpec daskAPI.WorkerSpec, schedulerSpec daskAPI.SchedulerSpec, podSpec *v1.PodSpec, primaryContainerName string, objectMeta *metav1.ObjectMeta) (*daskAPI.DaskJobSpec, error) { + jobPodSpec := podSpec.DeepCopy() + jobPodSpec.RestartPolicy = v1.RestartPolicyNever + + primaryContainer, err := getPrimaryContainer(jobPodSpec, primaryContainerName) + if err != nil { + return nil, err + } + primaryContainer.Name = "job-runner" + + err = replacePrimaryContainer(jobPodSpec, primaryContainerName, *primaryContainer) + if err != nil { + return nil, err + } + return &daskAPI.DaskJobSpec{ Job: daskAPI.JobSpec{ - Spec: v1.PodSpec{ - RestartPolicy: v1.RestartPolicyNever, - Containers: []v1.Container{ - defaults.JobRunnerContainer, - }, - }, + Spec: *jobPodSpec, }, Cluster: daskAPI.DaskCluster{ - ObjectMeta: metav1.ObjectMeta{ - Annotations: defaults.Annotations, - }, + ObjectMeta: *objectMeta, Spec: daskAPI.DaskClusterSpec{ Worker: workerSpec, Scheduler: schedulerSpec, }, }, - } + }, nil } func (p daskResourceHandler) GetTaskPhase(ctx context.Context, pluginContext k8s.PluginContext, r client.Object) (pluginsCore.PhaseInfo, error) { diff --git a/go/tasks/plugins/k8s/dask/dask_test.go b/go/tasks/plugins/k8s/dask/dask_test.go index 3ea3abba3..966d2293d 100644 --- a/go/tasks/plugins/k8s/dask/dask_test.go +++ b/go/tasks/plugins/k8s/dask/dask_test.go @@ -27,9 +27,13 @@ import ( ) const ( - defaultTestImage = "image://" - testNWorkers = 10 - testTaskID = "some-acceptable-name" + defaultTestImage = "image://" + testNWorkers = 10 + testTaskID = "some-acceptable-name" + podTemplateName = "dask-dummy-pod-template-name" + defaultServiceAccountName = "default-service-account" + defaultNamespace = "default-namespace" + podTempaltePriorityClassName = "pod-template-priority-class-name" ) var ( @@ -40,6 +44,7 @@ var ( "execute-dask-task", } testAnnotations = map[string]string{"annotation-1": "val1"} + testLabels = map[string]string{"label-1": "val1"} testPlatformResources = v1.ResourceRequirements{ Requests: v1.ResourceList{ v1.ResourceCPU: resource.MustParse("4"), @@ -53,13 +58,23 @@ var ( Requests: testPlatformResources.Requests, Limits: testPlatformResources.Requests, } + podTemplate = &v1.PodTemplate{ + ObjectMeta: metav1.ObjectMeta{ + Name: podTemplateName, + }, + Template: v1.PodTemplateSpec{ + Spec: v1.PodSpec{ + PriorityClassName: podTempaltePriorityClassName, + }, + }, + } ) func dummyDaskJob(status daskAPI.JobStatus) *daskAPI.DaskJob { return &daskAPI.DaskJob{ ObjectMeta: metav1.ObjectMeta{ Name: "dask-job-name", - Namespace: "dask-namespace", + Namespace: defaultNamespace, }, Status: daskAPI.DaskJobStatus{ ClusterName: "dask-cluster-name", @@ -90,7 +105,7 @@ func dummpyDaskCustomObj(customImage string, resources *core.Resources) *plugins return &daskJob } -func dummyDaskTaskTemplate(customImage string, resources *core.Resources) *core.TaskTemplate { +func dummyDaskTaskTemplate(customImage string, resources *core.Resources, podTemplateName string) *core.TaskTemplate { // In a real usecase, resources will always be filled, but might be empty if resources == nil { resources = &core.Resources{ @@ -110,13 +125,17 @@ func dummyDaskTaskTemplate(customImage string, resources *core.Resources) *core. if err != nil { panic(err) } - envVars := []*core.KeyValuePair{} + var envVars []*core.KeyValuePair for _, envVar := range testEnvVars { envVars = append(envVars, &core.KeyValuePair{Key: envVar.Name, Value: envVar.Value}) } + metadata := &core.TaskMetadata{ + PodTemplateName: podTemplateName, + } return &core.TaskTemplate{ - Id: &core.Identifier{Name: "test-build-resource"}, - Type: daskTaskType, + Id: &core.Identifier{Name: "test-build-resource"}, + Type: daskTaskType, + Metadata: metadata, Target: &core.TaskTemplate_Container{ Container: &core.Container{ Image: defaultTestImage, @@ -164,11 +183,13 @@ func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.Resourc taskExecutionMetadata := &mocks.TaskExecutionMetadata{} taskExecutionMetadata.OnGetTaskExecutionID().Return(tID) taskExecutionMetadata.OnGetAnnotations().Return(testAnnotations) - taskExecutionMetadata.OnGetLabels().Return(map[string]string{"label-1": "val1"}) + taskExecutionMetadata.OnGetLabels().Return(testLabels) taskExecutionMetadata.OnGetPlatformResources().Return(&testPlatformResources) taskExecutionMetadata.OnGetMaxAttempts().Return(uint32(1)) taskExecutionMetadata.OnIsInterruptible().Return(isInterruptible) taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) + taskExecutionMetadata.OnGetK8sServiceAccount().Return(defaultServiceAccountName) + taskExecutionMetadata.OnGetNamespace().Return(defaultNamespace) overrides := &mocks.TaskOverrides{} overrides.OnGetResources().Return(resources) taskExecutionMetadata.OnGetOverrides().Return(overrides) @@ -179,8 +200,8 @@ func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.Resourc func TestBuildResourceDaskHappyPath(t *testing.T) { daskResourceHandler := daskResourceHandler{} - taskTemplate := dummyDaskTaskTemplate("", nil) - taskContext := dummyDaskTaskContext(taskTemplate, nil, false) + taskTemplate := dummyDaskTaskTemplate("", nil, "") + taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, false) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -188,9 +209,8 @@ func TestBuildResourceDaskHappyPath(t *testing.T) { assert.True(t, ok) var defaultTolerations []v1.Toleration - var defaultNodeSelector map[string]string - var defaultAffinity *v1.Affinity - defaultWorkerAffinity := v1.Affinity{ + defaultNodeSelector := map[string]string{} + defaultAffinity := &v1.Affinity{ NodeAffinity: nil, PodAffinity: nil, PodAntiAffinity: nil, @@ -199,6 +219,7 @@ func TestBuildResourceDaskHappyPath(t *testing.T) { // Job jobSpec := daskJob.Spec.Job.Spec assert.Equal(t, testAnnotations, daskJob.ObjectMeta.GetAnnotations()) + assert.Equal(t, testLabels, daskJob.ObjectMeta.GetLabels()) assert.Equal(t, v1.RestartPolicyNever, jobSpec.RestartPolicy) assert.Equal(t, "job-runner", jobSpec.Containers[0].Name) assert.Equal(t, defaultTestImage, jobSpec.Containers[0].Image) @@ -208,11 +229,12 @@ func TestBuildResourceDaskHappyPath(t *testing.T) { assert.Equal(t, defaultNodeSelector, jobSpec.NodeSelector) assert.Equal(t, defaultAffinity, jobSpec.Affinity) - // Flyte adds more environment variables to the driver + // Flyte adds more environment variables to the runner assert.Contains(t, jobSpec.Containers[0].Env, testEnvVars[0]) // Cluster assert.Equal(t, testAnnotations, daskJob.Spec.Cluster.ObjectMeta.GetAnnotations()) + assert.Equal(t, testLabels, daskJob.Spec.Cluster.ObjectMeta.GetLabels()) // Scheduler schedulerSpec := daskJob.Spec.Cluster.Spec.Scheduler.Spec @@ -233,7 +255,8 @@ func TestBuildResourceDaskHappyPath(t *testing.T) { assert.Equal(t, defaultResources, schedulerSpec.Containers[0].Resources) assert.Equal(t, []string{"dask-scheduler"}, schedulerSpec.Containers[0].Args) assert.Equal(t, expectedPorts, schedulerSpec.Containers[0].Ports) - assert.Equal(t, testEnvVars, schedulerSpec.Containers[0].Env) + // Flyte adds more environment variables to the scheduler + assert.Contains(t, schedulerSpec.Containers[0].Env, testEnvVars[0]) assert.Equal(t, defaultTolerations, schedulerSpec.Tolerations) assert.Equal(t, defaultNodeSelector, schedulerSpec.NodeSelector) assert.Equal(t, defaultAffinity, schedulerSpec.Affinity) @@ -265,13 +288,13 @@ func TestBuildResourceDaskHappyPath(t *testing.T) { workerSpec := daskJob.Spec.Cluster.Spec.Worker.Spec assert.Equal(t, testNWorkers, daskJob.Spec.Cluster.Spec.Worker.Replicas) assert.Equal(t, "dask-worker", workerSpec.Containers[0].Name) - assert.Equal(t, v1.PullIfNotPresent, workerSpec.Containers[0].ImagePullPolicy) assert.Equal(t, defaultTestImage, workerSpec.Containers[0].Image) assert.Equal(t, defaultResources, workerSpec.Containers[0].Resources) - assert.Equal(t, testEnvVars, workerSpec.Containers[0].Env) + // Flyte adds more environment variables to the worker + assert.Contains(t, workerSpec.Containers[0].Env, testEnvVars[0]) assert.Equal(t, defaultTolerations, workerSpec.Tolerations) assert.Equal(t, defaultNodeSelector, workerSpec.NodeSelector) - assert.Equal(t, &defaultWorkerAffinity, workerSpec.Affinity) + assert.Equal(t, defaultAffinity, workerSpec.Affinity) assert.Equal(t, []string{ "dask-worker", "--name", @@ -281,13 +304,14 @@ func TestBuildResourceDaskHappyPath(t *testing.T) { "--memory-limit", "1Gi", }, workerSpec.Containers[0].Args) + assert.Equal(t, workerSpec.RestartPolicy, v1.RestartPolicyAlways) } func TestBuildResourceDaskCustomImages(t *testing.T) { customImage := "customImage" daskResourceHandler := daskResourceHandler{} - taskTemplate := dummyDaskTaskTemplate(customImage, nil) + taskTemplate := dummyDaskTaskTemplate(customImage, nil, "") taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, false) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -320,7 +344,7 @@ func TestBuildResourceDaskDefaultResoureRequirements(t *testing.T) { } daskResourceHandler := daskResourceHandler{} - taskTemplate := dummyDaskTaskTemplate("", nil) + taskTemplate := dummyDaskTaskTemplate("", nil, "") taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, false) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -377,7 +401,7 @@ func TestBuildResourcesDaskCustomResoureRequirements(t *testing.T) { } daskResourceHandler := daskResourceHandler{} - taskTemplate := dummyDaskTaskTemplate("", &protobufResources) + taskTemplate := dummyDaskTaskTemplate("", &protobufResources, "") taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, false) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) @@ -403,8 +427,8 @@ func TestBuildResourcesDaskCustomResoureRequirements(t *testing.T) { } func TestBuildResourceDaskInterruptible(t *testing.T) { - var defaultNodeSelector map[string]string - var defaultAffinity *v1.Affinity + defaultNodeSelector := map[string]string{} + var defaultAffinity v1.Affinity var defaultTolerations []v1.Toleration interruptibleNodeSelector := map[string]string{ @@ -432,8 +456,8 @@ func TestBuildResourceDaskInterruptible(t *testing.T) { daskResourceHandler := daskResourceHandler{} - taskTemplate := dummyDaskTaskTemplate("", nil) - taskContext := dummyDaskTaskContext(taskTemplate, nil, true) + taskTemplate := dummyDaskTaskTemplate("", nil, "") + taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, true) r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) assert.Nil(t, err) assert.NotNil(t, r) @@ -444,13 +468,13 @@ func TestBuildResourceDaskInterruptible(t *testing.T) { jobSpec := daskJob.Spec.Job.Spec assert.Equal(t, defaultTolerations, jobSpec.Tolerations) assert.Equal(t, defaultNodeSelector, jobSpec.NodeSelector) - assert.Equal(t, defaultAffinity, jobSpec.Affinity) + assert.Equal(t, &defaultAffinity, jobSpec.Affinity) - // Scheduler - should not bt interruptible + // Scheduler - should not be interruptible schedulerSpec := daskJob.Spec.Cluster.Spec.Scheduler.Spec assert.Equal(t, defaultTolerations, schedulerSpec.Tolerations) assert.Equal(t, defaultNodeSelector, schedulerSpec.NodeSelector) - assert.Equal(t, defaultAffinity, schedulerSpec.Affinity) + assert.Equal(t, &defaultAffinity, schedulerSpec.Affinity) // Default Workers - Should be interruptible workerSpec := daskJob.Spec.Cluster.Spec.Worker.Spec @@ -463,6 +487,25 @@ func TestBuildResourceDaskInterruptible(t *testing.T) { ) } +func TestBuildResouceDaskUsePodTemplate(t *testing.T) { + flytek8s.DefaultPodTemplateStore.Store(podTemplate) + daskResourceHandler := daskResourceHandler{} + taskTemplate := dummyDaskTaskTemplate("", nil, podTemplateName) + taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, false) + r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext) + assert.Nil(t, err) + assert.NotNil(t, r) + daskJob, ok := r.(*daskAPI.DaskJob) + assert.True(t, ok) + + assert.Equal(t, podTempaltePriorityClassName, daskJob.Spec.Job.Spec.PriorityClassName) + assert.Equal(t, podTempaltePriorityClassName, daskJob.Spec.Cluster.Spec.Scheduler.Spec.PriorityClassName) + assert.Equal(t, podTempaltePriorityClassName, daskJob.Spec.Cluster.Spec.Worker.Spec.PriorityClassName) + + // Cleanup + flytek8s.DefaultPodTemplateStore.Delete(podTemplate) +} + func TestGetPropertiesDask(t *testing.T) { daskResourceHandler := daskResourceHandler{} expected := k8s.PluginProperties{} @@ -478,7 +521,7 @@ func TestBuildIdentityResourceDask(t *testing.T) { }, } - taskTemplate := dummyDaskTaskTemplate("", nil) + taskTemplate := dummyDaskTaskTemplate("", nil, "") taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, false) identityResources, err := daskResourceHandler.BuildIdentityResource(context.TODO(), taskContext.TaskExecutionMetadata()) if err != nil { @@ -491,7 +534,7 @@ func TestGetTaskPhaseDask(t *testing.T) { daskResourceHandler := daskResourceHandler{} ctx := context.TODO() - taskTemplate := dummyDaskTaskTemplate("", nil) + taskTemplate := dummyDaskTaskTemplate("", nil, "") taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, false) taskPhase, err := daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(""))