Commit

fix comments
ccchenjiahuan committed Jul 14, 2022
1 parent d141084 commit dccfb48
Showing 4 changed files with 26 additions and 95 deletions.
docs/design/distributed-framework-plugins.md (0 additions, 71 deletions)
@@ -215,16 +215,6 @@ type pytorchPlugin struct {
    workerName string
    port int
}
- // parse all arguments
- func (pp *pytorchPlugin) addFlags() {
-     flagSet := flag.NewFlagSet(pp.Name(), flag.ContinueOnError)
-     flagSet.StringVar(&pp.masterName, "master", DefaultMaster, "name of master role task")
-     flagSet.StringVar(&pp.workerName, "worker", DefaultWorker, "name of worker role task")
-     flagSet.IntVar(&pp.port, "port", DefaultPort, "open port for containers")
-     if err := flagSet.Parse(pp.pytorchArguments); err != nil {
-         klog.Errorf("plugin %s flagset parse failed, err: %v", pp.Name(), err)
-     }
- }
```
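
For context, the plugin arguments parsed above (`pp.pytorchArguments`) arrive as a string slice, typically taken from the `plugins` field of the Volcano Job spec. A minimal standalone sketch of the same flag-parsing pattern follows; the argument values and defaults are illustrative assumptions, not the plugin's actual configuration:

```go
package main

import (
    "flag"
    "fmt"
)

func main() {
    // Hypothetical plugin arguments, e.g. pytorch: ["--master=master", "--worker=worker", "--port=23456"].
    args := []string{"--master=master", "--worker=worker", "--port=23456"}

    var masterName, workerName string
    var port int

    flagSet := flag.NewFlagSet("pytorch", flag.ContinueOnError)
    flagSet.StringVar(&masterName, "master", "master", "name of master role task")
    flagSet.StringVar(&workerName, "worker", "worker", "name of worker role task")
    flagSet.IntVar(&port, "port", 23456, "open port for containers")

    if err := flagSet.Parse(args); err != nil {
        fmt.Printf("flagset parse failed, err: %v\n", err)
        return
    }
    fmt.Println(masterName, workerName, port) // master worker 23456
}
```

Because the plugin parses with `flag.ContinueOnError` and only logs parse failures, a malformed argument is logged rather than surfaced as an error to the caller.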

Then we patch the PyTorch distributed-training environment variables into the container envs in the `OnPodCreate` method.
@@ -234,67 +224,6 @@ The main environment variables are:
* `WORLD_SIZE`: total node number
* `RANK`: current node index

- ```go
- func (pp *pytorchPlugin) OnPodCreate(pod *v1.Pod, job *batch.Job) error {
-     taskType := helpers.GetTaskKey(pod)
-     masterIndex := helpers.GetTaskIndexUnderJob(pp.masterName, job)
-     if masterIndex == -1 {
-         klog.Errorf("job %v doesn't have task %v", job.Name, pp.masterName)
-         for i, c := range pod.Spec.Containers {
-             pp.openContainerPort(&c, i, pod)
-         }
-         return nil
-     }
-
-     masterEnvVars := []v1.EnvVar{}
-     masterAddr := pp.generateMasterAddr(job.Spec.Tasks[masterIndex], job.Name)
-     masterEnvVars = append(masterEnvVars, v1.EnvVar{
-         Name:  EnvMasterAddr,
-         Value: masterAddr,
-     }, v1.EnvVar{
-         Name:  EnvMasterPort,
-         Value: fmt.Sprintf("%v", pp.port),
-     })
-
-     masterRank := 0
-     workerRank := 0
-     if taskType == pp.workerName {
-         index, err := strconv.Atoi(helpers.GetPodIndexUnderTask(pod))
-         if err != nil {
-             return err
-         }
-         workerRank = index + 1
-     }
-
-     totalReplicas := pp.getTotalReplicas(job)
-     for i, c := range pod.Spec.Containers {
-         pp.openContainerPort(&c, i, pod)
-         pod.Spec.Containers[i].Env = append(pod.Spec.Containers[i].Env, masterEnvVars...)
-         pod.Spec.Containers[i].Env = append(pod.Spec.Containers[i].Env, v1.EnvVar{
-             Name:  EnvWorldSize,
-             Value: strconv.Itoa(int(totalReplicas)),
-         })
-         if taskType == pp.workerName {
-             pod.Spec.Containers[i].Env = append(pod.Spec.Containers[i].Env, v1.EnvVar{
-                 Name:  EnvRank,
-                 Value: strconv.Itoa(workerRank),
-             })
-         } else if taskType == pp.masterName {
-             pod.Spec.Containers[i].Env = append(pod.Spec.Containers[i].Env, v1.EnvVar{
-                 Name:  EnvRank,
-                 Value: strconv.Itoa(masterRank),
-             })
-         }
-     }
-     return nil
- }
- ```
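
To make the rank assignment concrete, here is a standalone sketch of the rule implemented above, assuming a hypothetical job with one master replica and two worker replicas (the same shape the updated tests below use):

```go
package main

import "fmt"

func main() {
    // Assumed job shape: task "master" with 1 replica, task "worker" with 2 replicas.
    masterReplicas, workerReplicas := 1, 2

    // WORLD_SIZE is the total replica count across all tasks.
    worldSize := masterReplicas + workerReplicas

    // The master pod always gets RANK 0.
    fmt.Printf("master-0: WORLD_SIZE=%d RANK=%d\n", worldSize, 0)

    // Worker pod i gets RANK i+1: its index under the task, offset past the master.
    for i := 0; i < workerReplicas; i++ {
        fmt.Printf("worker-%d: WORLD_SIZE=%d RANK=%d\n", i, worldSize, i+1)
    }
}
```

With this shape every pod sees `WORLD_SIZE=3`, the master gets `RANK=0`, and the workers get `RANK=1` and `RANK=2`, which is exactly what the updated test expectations below assert.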

#### Other Frameworks

Most other frameworks are similar to TensorFlow, but the MPI framework is special: in most cases it needs a `hostfile`.
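
A minimal, illustrative hostfile sketch (the host names are hypothetical; `slots` is the OpenMPI-style per-host process count):

```
mpi-job-master-0.mpi-job slots=2
mpi-job-worker-0.mpi-job slots=2
mpi-job-worker-1.mpi-job slots=2
```
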
@@ -67,10 +67,6 @@ func (pp *pytorchPlugin) OnPodCreate(pod *v1.Pod, job *batch.Job) error {
masterIndex := helpers.GetTaskIndexUnderJob(pp.masterName, job)
if masterIndex == -1 {
    klog.Errorf("job %v doesn't have task %v", job.Name, pp.masterName)
-     for i, c := range pod.Spec.Containers {
-         pp.openContainerPort(&c, i, pod)
-     }
-
    return nil
}

@@ -39,7 +39,7 @@ func TestPytorch(t *testing.T) {
},
Pod: &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "worker",
Name: "test-pytorch-worker-0",
},
Spec: v1.PodSpec{
Containers: []v1.Container{
@@ -49,11 +49,11 @@
},
},
},
- port: DefaultPort,
+ port: -1,
envs: nil,
},
{
Name: "test pod without port",
Name: "test master pod without port",
Job: &v1alpha1.Job{
ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"},
Spec: v1alpha1.JobSpec{
@@ -73,7 +73,7 @@
},
Pod: &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "master",
Name: "test-pytorch-master-0",
Annotations: map[string]string{
v1alpha1.TaskSpecKey: "master",
},
@@ -107,7 +107,7 @@ func TestPytorch(t *testing.T) {
},
},
{
Name: "test pod with port",
Name: "test master pod with port",
Job: &v1alpha1.Job{
ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"},
Spec: v1alpha1.JobSpec{
@@ -127,7 +127,7 @@
},
Pod: &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "master",
Name: "test-pytorch-master-0",
Annotations: map[string]string{
v1alpha1.TaskSpecKey: "master",
},
@@ -146,7 +146,7 @@
},
},
},
- port: 23456,
+ port: DefaultPort,
envs: []v1.EnvVar{
{
Name: EnvMasterAddr,
@@ -179,15 +179,15 @@
},
{
Name: "worker",
- Replicas: 1,
+ Replicas: 2,
Template: v1.PodTemplateSpec{},
},
},
},
},
Pod: &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "master",
Name: "test-pytorch-master-0",
Annotations: map[string]string{
v1alpha1.TaskSpecKey: "master",
},
@@ -218,7 +218,7 @@
},
{
Name: "WORLD_SIZE",
- Value: fmt.Sprintf("%v", 2),
+ Value: fmt.Sprintf("%v", 3),
},
{
Name: "RANK",
@@ -227,7 +227,7 @@
},
},
{
Name: "test worker pod env",
Name: "test worker-1 pod env",
Job: &v1alpha1.Job{
ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"},
Spec: v1alpha1.JobSpec{
@@ -239,15 +239,15 @@
},
{
Name: "worker",
- Replicas: 1,
+ Replicas: 2,
Template: v1.PodTemplateSpec{},
},
},
},
},
Pod: &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "worker",
Name: "test-pytorch-worker-0",
Annotations: map[string]string{
v1alpha1.TaskSpecKey: "worker",
},
@@ -278,7 +278,7 @@
},
{
Name: "WORLD_SIZE",
- Value: fmt.Sprintf("%v", 2),
+ Value: fmt.Sprintf("%v", 3),
},
{
Name: "RANK",
@@ -299,15 +299,15 @@
},
{
Name: "worker",
- Replicas: 1,
+ Replicas: 2,
Template: v1.PodTemplateSpec{},
},
},
},
},
Pod: &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "worker",
Name: "test-pytorch-worker-1",
Annotations: map[string]string{
v1alpha1.TaskSpecKey: "worker",
},
@@ -338,7 +338,7 @@
},
{
Name: "WORLD_SIZE",
- Value: fmt.Sprintf("%v", 2),
+ Value: fmt.Sprintf("%v", 3),
},
{
Name: "RANK",
@@ -355,8 +355,14 @@
t.Errorf("Case %d (%s): expect no error, but got error %v", index, testcase.Name, err)
}

- if testcase.Pod.Spec.Containers[0].Ports == nil || testcase.Pod.Spec.Containers[0].Ports[0].ContainerPort != int32(testcase.port) {
-     t.Errorf("Case %d (%s): wrong port, got %d, expected %v", index, testcase.Name, testcase.Pod.Spec.Containers[0].Ports[0].ContainerPort, testcase.port)
+ if testcase.port != -1 {
+     if testcase.Pod.Spec.Containers[0].Ports == nil || testcase.Pod.Spec.Containers[0].Ports[0].ContainerPort != int32(testcase.port) {
+         t.Errorf("Case %d (%s): wrong port, got %d, expected %v", index, testcase.Name, testcase.Pod.Spec.Containers[0].Ports[0].ContainerPort, testcase.port)
+     }
+ } else {
+     if testcase.Pod.Spec.Containers[0].Ports != nil {
+         t.Errorf("Case %d (%s): wrong port, got %d, expected empty", index, testcase.Name, testcase.Pod.Spec.Containers[0].Ports[0].ContainerPort)
+     }
}

if !reflect.DeepEqual(testcase.Pod.Spec.Containers[0].Env, testcase.envs) {
test/e2e/jobseq/pytorch_plugin.go (1 addition, 1 deletion)
@@ -52,7 +52,7 @@ var _ = Describe("Pytorch Plugin E2E Test", func() {

job := e2eutil.CreateJob(context, spec)
err := e2eutil.WaitJobPhases(context, job, []vcbatch.JobPhase{
- vcbatch.Pending, vcbatch.Running, vcbatch.Completing, vcbatch.Completed})
+ vcbatch.Pending, vcbatch.Running, vcbatch.Completed})
Expect(err).NotTo(HaveOccurred())
})
})
