Skip to content

Commit

Permalink
add pytorch test file
Browse files Browse the repository at this point in the history
Signed-off-by: ccchenjiahuan <chenjiahuan163@163.com>
  • Loading branch information
ccchenjiahuan committed Jul 2, 2022
1 parent 3e7aa2c commit baf8bf2
Show file tree
Hide file tree
Showing 5 changed files with 380 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,3 @@ func (pp *pytorchPlugin) GetMasterName() string {
// GetWorkerName returns the worker name stored on the plugin (pp.workerName).
func (pp *pytorchPlugin) GetWorkerName() string {
return pp.workerName
}

// GetMpiArguments returns the plugin's pytorchArguments slice as-is (no copy).
// NOTE(review): the name says "Mpi" but the receiver is the pytorch plugin; the
// surrounding hunk header (185,7 -> 185,3) suggests this commit removes this
// method — confirm before relying on it.
func (pp *pytorchPlugin) GetMpiArguments() []string {
return pp.pytorchArguments
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,367 @@
package pytorch

import (
"fmt"
"reflect"
"testing"

v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"volcano.sh/apis/pkg/apis/batch/v1alpha1"
pluginsinterface "volcano.sh/volcano/pkg/controllers/job/plugins/interface"
)

func TestPytorch(t *testing.T) {
plugins := make(map[string][]string)
plugins[PytorchPluginName] = []string{"--port=5000"}

testcases := []struct {
Name string
Job *v1alpha1.Job
Pod *v1.Pod
port int
envs []v1.EnvVar
}{
{
Name: "test pod without master",
Job: &v1alpha1.Job{
ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"},
Spec: v1alpha1.JobSpec{
Tasks: []v1alpha1.TaskSpec{
{
Name: "worker",
Replicas: 1,
Template: v1.PodTemplateSpec{},
},
},
},
},
Pod: &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "worker",
},
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "worker",
},
},
},
},
port: DefaultPort,
envs: nil,
},
{
Name: "test pod without port",
Job: &v1alpha1.Job{
ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"},
Spec: v1alpha1.JobSpec{
Tasks: []v1alpha1.TaskSpec{
{
Name: "master",
Replicas: 1,
Template: v1.PodTemplateSpec{},
},
{
Name: "worker",
Replicas: 1,
Template: v1.PodTemplateSpec{},
},
},
},
},
Pod: &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "master",
Annotations: map[string]string{
v1alpha1.TaskSpecKey: "master",
},
},
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "master",
},
},
},
},
port: DefaultPort,
envs: []v1.EnvVar{
{
Name: EnvMasterAddr,
Value:"test-pytorch-master-0.test-pytorch",
},
{
Name: EnvMasterPort,
Value: fmt.Sprintf("%v", DefaultPort),
},
{
Name: "WORLD_SIZE",
Value: fmt.Sprintf("%v", 2),
},
{
Name:"RANK",
Value: fmt.Sprintf("%v", 0),
},
},
},
{
Name: "test pod with port",
Job: &v1alpha1.Job{
ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"},
Spec: v1alpha1.JobSpec{
Tasks: []v1alpha1.TaskSpec{
{
Name: "master",
Replicas: 1,
Template: v1.PodTemplateSpec{},
},
{
Name: "worker",
Replicas: 1,
Template: v1.PodTemplateSpec{},
},
},
},
},
Pod: &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "master",
Annotations: map[string]string{
v1alpha1.TaskSpecKey: "master",
},
},
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "master",
Ports: []v1.ContainerPort{
{
Name: "pytorchjob-port",
ContainerPort: 23456,
},
},
},
},
},
},
port: 23456,
envs: []v1.EnvVar{
{
Name: EnvMasterAddr,
Value:"test-pytorch-master-0.test-pytorch",
},
{
Name: EnvMasterPort,
Value: fmt.Sprintf("%v", DefaultPort),
},
{
Name: "WORLD_SIZE",
Value: fmt.Sprintf("%v", 2),
},
{
Name:"RANK",
Value: fmt.Sprintf("%v", 0),
},
},
},
{
Name: "test master pod env",
Job: &v1alpha1.Job{
ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"},
Spec: v1alpha1.JobSpec{
Tasks: []v1alpha1.TaskSpec{
{
Name: "master",
Replicas: 1,
Template: v1.PodTemplateSpec{},
},
{
Name: "worker",
Replicas: 1,
Template: v1.PodTemplateSpec{},
},
},
},
},
Pod: &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "master",
Annotations: map[string]string{
v1alpha1.TaskSpecKey: "master",
},
},
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "master",
Ports: []v1.ContainerPort{
{
Name: "pytorchjob-port",
ContainerPort: 123,
},
},
},
},
},
},
port: 123,
envs: []v1.EnvVar{
{
Name: EnvMasterAddr,
Value:"test-pytorch-master-0.test-pytorch",
},
{
Name: EnvMasterPort,
Value: fmt.Sprintf("%v", DefaultPort),
},
{
Name: "WORLD_SIZE",
Value: fmt.Sprintf("%v", 2),
},
{
Name:"RANK",
Value: fmt.Sprintf("%v", 0),
},
},
},
{
Name: "test worker pod env",
Job: &v1alpha1.Job{
ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"},
Spec: v1alpha1.JobSpec{
Tasks: []v1alpha1.TaskSpec{
{
Name: "master",
Replicas: 1,
Template: v1.PodTemplateSpec{},
},
{
Name: "worker",
Replicas: 1,
Template: v1.PodTemplateSpec{},
},
},
},
},
Pod: &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "worker",
Annotations: map[string]string{
v1alpha1.TaskSpecKey: "worker",
},
},
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "worker",
Ports: []v1.ContainerPort{
{
Name: "pytorchjob-port",
ContainerPort: 123,
},
},
},
},
},
},
port: 123,
envs: []v1.EnvVar{
{
Name: EnvMasterAddr,
Value:"test-pytorch-master-0.test-pytorch",
},
{
Name: EnvMasterPort,
Value: fmt.Sprintf("%v", DefaultPort),
},
{
Name: "WORLD_SIZE",
Value: fmt.Sprintf("%v", 2),
},
{
Name:"RANK",
Value: fmt.Sprintf("%v", 1),
},
},
},
{
Name: "test worker-2 pod env",
Job: &v1alpha1.Job{
ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"},
Spec: v1alpha1.JobSpec{
Tasks: []v1alpha1.TaskSpec{
{
Name: "master",
Replicas: 1,
Template: v1.PodTemplateSpec{},
},
{
Name: "worker",
Replicas: 1,
Template: v1.PodTemplateSpec{},
},
},
},
},
Pod: &v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "worker",
Annotations: map[string]string{
v1alpha1.TaskSpecKey: "worker",
},
},
Spec: v1.PodSpec{
Containers: []v1.Container{
{
Name: "worker",
Ports: []v1.ContainerPort{
{
Name: "pytorchjob-port",
ContainerPort: 123,
},
},
},
},
},
},
port: 123,
envs: []v1.EnvVar{
{
Name: EnvMasterAddr,
Value:"test-pytorch-master-0.test-pytorch",
},
{
Name: EnvMasterPort,
Value: fmt.Sprintf("%v", DefaultPort),
},
{
Name: "WORLD_SIZE",
Value: fmt.Sprintf("%v", 2),
},
{
Name:"RANK",
Value: fmt.Sprintf("%v", 2),
},
},
},
}

for index, testcase := range testcases {
t.Run(testcase.Name, func(t *testing.T) {
mp := New(pluginsinterface.PluginClientset{}, testcase.Job.Spec.Plugins[PytorchPluginName])
if err := mp.OnPodCreate(testcase.Pod, testcase.Job); err != nil {
t.Errorf("Case %d (%s): expect no error, but got error %v", index, testcase.Name, err)
}

if testcase.Pod.Spec.Containers[0].Ports == nil || testcase.Pod.Spec.Containers[0].Ports[0].ContainerPort != int32(testcase.port) {
t.Errorf("Case %d (%s): wrong port, got %d, expected %v", index, testcase.Name, testcase.Pod.Spec.Containers[0].Ports[0].ContainerPort, testcase.port)
}

if !reflect.DeepEqual(testcase.Pod.Spec.Containers[0].Env, testcase.envs) {
t.Errorf("Case %d (%s): wrong envs, got %v, expected %v", index, testcase.Name, testcase.Pod.Spec.Containers[0].Env, testcase.envs)
}
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import (
)

const (
// TFPluginName is the name of the plugin
TFPluginName = "tensorflow"
// DefaultPort defines default port for service
DefaultPort = 2222
// TFConfig defines environment variables for TF
Expand Down Expand Up @@ -67,7 +69,7 @@ func (tp *tensorflowPlugin) addFlags() {
}

// Name returns the plugin name, TFPluginName ("tensorflow").
func (tp *tensorflowPlugin) Name() string {
	return TFPluginName
}

func (tp *tensorflowPlugin) OnPodCreate(pod *v1.Pod, job *batch.Job) error {
Expand Down
Loading

0 comments on commit baf8bf2

Please sign in to comment.