Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GSoC] Add New Parameter in tune #2369

Merged
merged 13 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hack/gen-python-sdk/post_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def _rewrite_helper(input_file, output_file, rewrite_rules):
if (output_file == "sdk/python/v1beta1/kubeflow/katib/__init__.py"):
lines.append("# Import Katib API client.\n")
lines.append("from kubeflow.katib.api.katib_client import KatibClient\n")
lines.append("# Import Katib report metrics functions")
lines.append("from kubeflow.katib.api.report_metrics import report_metrics")
lines.append("# Import Katib helper functions.\n")
lines.append("import kubeflow.katib.api.search as search\n")
lines.append("# Import Katib helper constants.\n")
Expand Down
4 changes: 2 additions & 2 deletions pkg/apis/controller/common/v1beta1/common_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ const (
CustomCollector CollectorKind = "Custom"

// When model training source code persists metrics into persistent layer
// directly, metricsCollector isn't in need, and its kind is "noneCollector"
NoneCollector CollectorKind = "None"
// directly, a sidecar container isn't needed, and its kind is "Push"
PushCollector CollectorKind = "Push"
Electronic-Waste marked this conversation as resolved.
Show resolved Hide resolved

MetricsVolume = "metrics-volume"
)
Expand Down
3 changes: 3 additions & 0 deletions pkg/controller.v1beta1/consts/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ const (
// resources list which can be used as trial template
ConfigTrialResources = "trial-resources"

// EnvTrialName is the env variable of Trial name
EnvTrialName = "KATIB_TRIAL_NAME"

// LabelExperimentName is the label of experiment name.
LabelExperimentName = "katib.kubeflow.org/experiment"
// LabelSuggestionName is the label of suggestion name.
Expand Down
2 changes: 1 addition & 1 deletion pkg/webhook/v1beta1/experiment/validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ func (g *DefaultValidator) validateMetricsCollector(inst *experimentsv1beta1.Exp
}
// TODO(hougangliu): log warning message if some field will not be used for the metricsCollector kind
switch mcKind {
case commonapiv1beta1.NoneCollector, commonapiv1beta1.StdOutCollector:
case commonapiv1beta1.PushCollector, commonapiv1beta1.StdOutCollector:
return allErrs
case commonapiv1beta1.FileCollector:
if mcSpec.Source == nil || mcSpec.Source.FileSystemPath == nil ||
Expand Down
11 changes: 9 additions & 2 deletions pkg/webhook/v1beta1/pod/inject_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,22 @@ func (s *SidecarInjector) Mutate(pod *v1.Pod, namespace string) (*v1.Pod, error)
// Add Katib Trial labels to the Pod metadata.
mutatePodMetadata(mutatedPod, trial)

// Add env variables to the Pod's primary container.
// We add this function because of push-based metrics collection function `report_metrics` in Python SDK.
// Currently, we only pass the Trial name as env variable `KATIB_TRIAL_NAME` to the training container.
if err := mutatePodEnv(mutatedPod, trial); err != nil {
return nil, err
}

// Do the following mutation only for the Primary pod.
// If PrimaryPodLabel is not set we mutate all pods which are related to Trial job.
// Otherwise, mutate pod only with the appropriate labels.
if trial.Spec.PrimaryPodLabels != nil && !isPrimaryPod(pod.Labels, trial.Spec.PrimaryPodLabels) {
return mutatedPod, nil
}

// If Metrics Collector in None, skip the mutation.
if trial.Spec.MetricsCollector.Collector.Kind == common.NoneCollector {
// If Metrics Collector is Push, skip the mutation.
if trial.Spec.MetricsCollector.Collector.Kind == common.PushCollector {
return mutatedPod, nil
}

Expand Down
102 changes: 102 additions & 0 deletions pkg/webhook/v1beta1/pod/inject_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/onsi/gomega"
appsv1 "k8s.io/api/apps/v1"
batchv1 "k8s.io/api/batch/v1"
Expand Down Expand Up @@ -1067,3 +1069,103 @@ func TestMutatePodMetadata(t *testing.T) {
}
}
}

// TestMutatePodEnv exercises mutatePodEnv against a single-container Pod.
//
// Two scenarios are covered:
//   - the Trial's PrimaryContainerName matches the container, so the
//     consts.EnvTrialName variable (sourced from the consts.LabelTrialName
//     Pod label through a fieldRef) is expected on that container;
//   - the Trial's PrimaryContainerName matches nothing, so an error is
//     expected and the Pod must come back unchanged.
func TestMutatePodEnv(t *testing.T) {
	// podWithEnv builds a Pod holding one container named "training-container"
	// with the given env list (nil means "no env set").
	podWithEnv := func(env []v1.EnvVar) *v1.Pod {
		return &v1.Pod{
			Spec: v1.PodSpec{
				Containers: []v1.Container{
					{
						Name: "training-container",
						Env:  env,
					},
				},
			},
		}
	}

	// trialWithPrimary builds a Trial whose only populated field is the
	// primary container name.
	trialWithPrimary := func(primary string) *trialsv1beta1.Trial {
		return &trialsv1beta1.Trial{
			Spec: trialsv1beta1.TrialSpec{
				PrimaryContainerName: primary,
			},
		}
	}

	type mutateEnvCase struct {
		pod        *v1.Pod
		trial      *trialsv1beta1.Trial
		mutatedPod *v1.Pod
		wantError  error
	}

	cases := map[string]mutateEnvCase{
		"Valid case for mutating Pod's env variable": {
			pod:   podWithEnv(nil),
			trial: trialWithPrimary("training-container"),
			mutatedPod: podWithEnv([]v1.EnvVar{
				{
					Name: consts.EnvTrialName,
					ValueFrom: &v1.EnvVarSource{
						FieldRef: &v1.ObjectFieldSelector{
							FieldPath: fmt.Sprintf("metadata.labels['%s']", consts.LabelTrialName),
						},
					},
				},
			}),
		},
		"Mismatch for Pod name and primaryContainerName in Trial": {
			pod:        podWithEnv(nil),
			trial:      trialWithPrimary("training-containers"),
			mutatedPod: podWithEnv(nil),
			wantError: fmt.Errorf(
				"Unable to find primary container %v in mutated pod containers %v",
				"training-containers",
				[]v1.Container{
					{
						Name: "training-container",
					},
				},
			),
		},
	}

	for caseName, tc := range cases {
		t.Run(caseName, func(t *testing.T) {
			gotErr := mutatePodEnv(tc.pod, tc.trial)

			// Compare the returned error with the expected one.
			switch {
			case tc.wantError != nil && gotErr != nil:
				// Both present: the messages have to match exactly.
				if diff := cmp.Diff(tc.wantError.Error(), gotErr.Error()); diff != "" {
					t.Errorf("Unexpected error (-want,+got):\n%s", diff)
				}
			case tc.wantError != nil || gotErr != nil:
				// Exactly one of them is nil: always a failure.
				t.Errorf(
					"Unexpected error (-want,+got):\n%s",
					cmp.Diff(tc.wantError, gotErr, cmpopts.EquateErrors()),
				)
			}

			// mutatePodEnv works in place, so tc.pod now holds the result.
			if diff := cmp.Diff(tc.mutatedPod, tc.pod); diff != "" {
				t.Errorf("Unexpected mutated result (-want,+got):\n%s", diff)
			}
		})
	}
}
27 changes: 27 additions & 0 deletions pkg/webhook/v1beta1/pod/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,33 @@ func mutatePodMetadata(pod *v1.Pod, trial *trialsv1beta1.Trial) {
pod.Labels = podLabels
}

// mutatePodEnv appends the consts.EnvTrialName env variable to the Trial's
// primary container of pod, modifying pod in place.
//
// The variable carries no literal value: it uses a fieldRef pointing at the
// Pod label consts.LabelTrialName.
//
// The primary container is looked up with getPrimaryContainerIndex using
// trial.Spec.PrimaryContainerName. When that lookup yields a negative index,
// an error is returned and this function makes no change to pod.
func mutatePodEnv(pod *v1.Pod, trial *trialsv1beta1.Trial) error {
	// Search for the primary container.
	index := getPrimaryContainerIndex(pod.Spec.Containers, trial.Spec.PrimaryContainerName)
	if index < 0 {
		// The message text is compared verbatim by TestMutatePodEnv, so it is kept as is.
		return fmt.Errorf("Unable to find primary container %v in mutated pod containers %v",
			trial.Spec.PrimaryContainerName, pod.Spec.Containers)
	}

	// Take the address of the slice element so the append is written back
	// to the Pod rather than to a copy of the container.
	primary := &pod.Spec.Containers[index]

	// Pass env variable KATIB_TRIAL_NAME to the primary container using fieldPath.
	// append works on a nil slice, so Env needs no prior initialisation.
	primary.Env = append(primary.Env, v1.EnvVar{
		Name: consts.EnvTrialName,
		ValueFrom: &v1.EnvVarSource{
			FieldRef: &v1.ObjectFieldSelector{
				FieldPath: fmt.Sprintf("metadata.labels['%s']", consts.LabelTrialName),
			},
		},
	})
	return nil
}

func getSidecarContainerName(cKind common.CollectorKind) string {
if cKind == common.StdOutCollector || cKind == common.FileCollector {
return mccommon.MetricLoggerCollectorContainerName
Expand Down
10 changes: 10 additions & 0 deletions sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def tune(
retain_trials: bool = False,
packages_to_install: List[str] = None,
pip_index_url: str = "https://pypi.org/simple",
metrics_collector_config: Dict[str, Any] = {"kind": "StdOut"},
):
"""Create HyperParameter Tuning Katib Experiment from the objective function.

Expand Down Expand Up @@ -248,6 +249,9 @@ def tune(
to the base image packages. These packages are installed before
executing the objective function.
pip_index_url: The PyPI url from which to install Python packages.
metrics_collector_config: Specify the config of metrics collector,
for example, `metrics_collector_config = {"kind": "Push"}`.
Currently, we only support `StdOut` and `Push` metrics collector.

Raises:
ValueError: Function arguments have incorrect type or value.
Expand Down Expand Up @@ -380,6 +384,12 @@ def tune(
f"Incorrect value for env_per_trial: {env_per_trial}"
)

# Add metrics collector to the Katib Experiment.
# For now, only the `kind` parameter (default `StdOut`) is supported; it specifies the kind of metrics collector.
experiment.spec.metrics_collector = models.V1beta1MetricsCollectorSpec(
collector=models.V1beta1CollectorSpec(kind=metrics_collector_config["kind"])
)

# Create Trial specification.
trial_spec = client.V1Job(
api_version="batch/v1",
Expand Down