Skip to content

Commit 120a95b

Browse files
author
zhongzichao
authored
fix mpi jobmode to master worker (#937) (#939)
* fix mpi jobmode to master worker
1 parent 18937e7 commit 120a95b

File tree

2 files changed

+174
-12
lines changed

2 files changed

+174
-12
lines changed

pkg/apiserver/controller/job/create.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ func checkMemberRole(framework schema.Framework, roles map[schema.MemberRole]int
428428
var err error
429429
var jobMode string
430430
switch framework {
431-
case schema.FrameworkPaddle, schema.FrameworkTF, schema.FrameworkPytorch, schema.FrameworkMXNet, schema.FrameworkMPI:
431+
case schema.FrameworkPaddle, schema.FrameworkTF, schema.FrameworkPytorch, schema.FrameworkMXNet:
432432
if roles[schema.RolePServer] > 0 {
433433
// parameter server mode
434434
jobMode = schema.EnvJobModePS
@@ -447,9 +447,9 @@ func checkMemberRole(framework schema.Framework, roles map[schema.MemberRole]int
447447
if roles[schema.RoleDriver] < 1 {
448448
err = fmt.Errorf("spark application must be set role driver")
449449
}
450-
case schema.FrameworkRay:
450+
case schema.FrameworkRay, schema.FrameworkMPI:
451451
if roles[schema.RoleMaster] < 1 || roles[schema.RoleWorker] < 1 {
452-
err = fmt.Errorf("ray job must be set a master role and a worker role")
452+
err = fmt.Errorf("%s job must be set a master role and a worker role", framework)
453453
}
454454
case schema.FrameworkStandalone:
455455
if roles[schema.RoleWorker] != 1 {

pkg/apiserver/controller/job/create_test.go

+171-9
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,68 @@ import (
77

88
"github.com/PaddlePaddle/PaddleFlow/pkg/common/config"
99
"github.com/PaddlePaddle/PaddleFlow/pkg/common/logger"
10+
"github.com/PaddlePaddle/PaddleFlow/pkg/common/resources"
1011
"github.com/PaddlePaddle/PaddleFlow/pkg/common/schema"
12+
"github.com/PaddlePaddle/PaddleFlow/pkg/common/uuid"
13+
"github.com/PaddlePaddle/PaddleFlow/pkg/model"
14+
"github.com/PaddlePaddle/PaddleFlow/pkg/storage"
1115
"github.com/PaddlePaddle/PaddleFlow/pkg/storage/driver"
1216
)
1317

1418
const (
15-
mockRootUser = "root"
16-
mockCreatedJobName = "job-xxxx1"
17-
MockQueueName = "default-queue"
18-
MockQueueID = "default-queue"
19+
mockRootUser = "root"
20+
MockQueueName = "default-queue"
21+
MockQueueID = "default-queue"
22+
MockClusterName = "default-cluster"
1923
)
2024

25+
var clusterInfo = model.ClusterInfo{
26+
Name: MockClusterName,
27+
Description: "Description",
28+
Endpoint: "Endpoint",
29+
Source: "Source",
30+
ClusterType: schema.KubernetesType,
31+
Version: "1.16",
32+
Status: model.ClusterStatusOnLine,
33+
Credential: "credential",
34+
Setting: "Setting",
35+
}
36+
2137
func TestCreatePFJob(t *testing.T) {
2238
driver.InitMockDB()
2339
config.GlobalServerConfig = &config.ServerConfig{}
2440
config.GlobalServerConfig.Job.IsSingleCluster = true
2541

42+
err := storage.Cluster.CreateCluster(&model.ClusterInfo{
43+
Model: model.Model{
44+
ID: MockClusterName,
45+
},
46+
Name: MockClusterName,
47+
ClusterType: schema.KubernetesType,
48+
})
49+
assert.Equal(t, nil, err)
50+
maxRes, err := resources.NewResourceFromMap(map[string]string{
51+
resources.ResCPU: "10",
52+
resources.ResMemory: "20Gi",
53+
"nvidia.com/gpu": "500",
54+
})
55+
assert.Equal(t, nil, err)
56+
queueInfo := model.Queue{
57+
Model: model.Model{
58+
ID: MockQueueID,
59+
},
60+
Name: MockQueueName,
61+
Namespace: "default",
62+
MaxResources: maxRes,
63+
MinResources: maxRes,
64+
QuotaType: schema.TypeVolcanoCapabilityQuota,
65+
ClusterId: MockClusterName,
66+
ClusterName: MockClusterName,
67+
Status: "open",
68+
}
69+
err = storage.Queue.CreateQueue(&queueInfo)
70+
assert.NoError(t, err)
71+
2672
type args struct {
2773
ctx *logger.RequestContext
2874
req *CreateJobInfo
@@ -52,7 +98,7 @@ func TestCreatePFJob(t *testing.T) {
5298
},
5399
req: &CreateJobInfo{
54100
CommonJobInfo: CommonJobInfo{
55-
ID: mockCreatedJobName,
101+
ID: uuid.GenerateIDWithLength("job", 5),
56102
Name: "normal",
57103
Labels: map[string]string{},
58104
Annotations: map[string]string{},
@@ -67,21 +113,22 @@ func TestCreatePFJob(t *testing.T) {
67113
responseCode: 400,
68114
},
69115
{
70-
name: "create success request",
116+
name: "create mpijob success request",
71117
args: args{
72118
ctx: &logger.RequestContext{
73119
UserName: mockRootUser,
74120
},
75121
req: &CreateJobInfo{
76122
CommonJobInfo: CommonJobInfo{
77-
ID: mockCreatedJobName,
123+
ID: uuid.GenerateIDWithLength("job", 5),
78124
Name: "normal",
79125
Labels: map[string]string{},
80126
Annotations: map[string]string{},
81127
SchedulingPolicy: SchedulingPolicy{
82128
Queue: MockQueueName,
83129
},
84130
},
131+
Type: schema.TypeDistributed,
85132
Framework: schema.FrameworkMPI,
86133
Members: []MemberSpec{
87134
{
@@ -122,17 +169,132 @@ func TestCreatePFJob(t *testing.T) {
122169
wantErr: false,
123170
responseCode: 400,
124171
},
172+
{
173+
name: "the role[pserver] for framework mpi is not supported",
174+
args: args{
175+
ctx: &logger.RequestContext{
176+
UserName: mockRootUser,
177+
},
178+
req: &CreateJobInfo{
179+
CommonJobInfo: CommonJobInfo{
180+
ID: uuid.GenerateIDWithLength("job", 5),
181+
Name: "normal",
182+
Labels: map[string]string{},
183+
Annotations: map[string]string{},
184+
SchedulingPolicy: SchedulingPolicy{
185+
Queue: MockQueueName,
186+
},
187+
},
188+
Type: schema.TypeDistributed,
189+
Framework: schema.FrameworkMPI,
190+
Members: []MemberSpec{
191+
{
192+
Replicas: 1,
193+
Role: string(schema.RolePServer),
194+
CommonJobInfo: CommonJobInfo{
195+
Name: "normal",
196+
Labels: map[string]string{},
197+
Annotations: map[string]string{},
198+
SchedulingPolicy: SchedulingPolicy{
199+
Queue: MockQueueName,
200+
},
201+
},
202+
JobSpec: JobSpec{
203+
Image: "iregistry.baidu-int.com/bmlc/trainingjob:0.20.0-tf2.3.0-torch1.6.0-mxnet1.5.0-py3.7-cpu",
204+
Command: "sleep 20",
205+
},
206+
},
207+
{
208+
Replicas: 1,
209+
Role: string(schema.RoleWorker),
210+
CommonJobInfo: CommonJobInfo{
211+
Name: "normal",
212+
Labels: map[string]string{},
213+
Annotations: map[string]string{},
214+
SchedulingPolicy: SchedulingPolicy{
215+
Queue: MockQueueName,
216+
},
217+
},
218+
JobSpec: JobSpec{
219+
Image: "iregistry.baidu-int.com/bmlc/trainingjob:0.20.0-tf2.3.0-torch1.6.0-mxnet1.5.0-py3.7-cpu",
220+
Command: "sleep 20",
221+
},
222+
},
223+
},
224+
},
225+
},
226+
wantErr: true,
227+
responseCode: 400,
228+
},
229+
{
230+
name: "mpi job must be set a master role and a worker role",
231+
args: args{
232+
ctx: &logger.RequestContext{
233+
UserName: mockRootUser,
234+
},
235+
req: &CreateJobInfo{
236+
CommonJobInfo: CommonJobInfo{
237+
ID: uuid.GenerateIDWithLength("job", 5),
238+
Name: "normal",
239+
Labels: map[string]string{},
240+
Annotations: map[string]string{},
241+
SchedulingPolicy: SchedulingPolicy{
242+
Queue: MockQueueName,
243+
},
244+
},
245+
Type: schema.TypeDistributed,
246+
Framework: schema.FrameworkMPI,
247+
Members: []MemberSpec{
248+
{
249+
Replicas: 1,
250+
Role: string(schema.RoleWorker),
251+
CommonJobInfo: CommonJobInfo{
252+
Name: "normal",
253+
Labels: map[string]string{},
254+
Annotations: map[string]string{},
255+
SchedulingPolicy: SchedulingPolicy{
256+
Queue: MockQueueName,
257+
},
258+
},
259+
JobSpec: JobSpec{
260+
Image: "iregistry.baidu-int.com/bmlc/trainingjob:0.20.0-tf2.3.0-torch1.6.0-mxnet1.5.0-py3.7-cpu",
261+
Command: "sleep 20",
262+
},
263+
},
264+
{
265+
Replicas: 1,
266+
Role: string(schema.RoleWorker),
267+
CommonJobInfo: CommonJobInfo{
268+
Name: "normal",
269+
Labels: map[string]string{},
270+
Annotations: map[string]string{},
271+
SchedulingPolicy: SchedulingPolicy{
272+
Queue: MockQueueName,
273+
},
274+
},
275+
JobSpec: JobSpec{
276+
Image: "iregistry.baidu-int.com/bmlc/trainingjob:0.20.0-tf2.3.0-torch1.6.0-mxnet1.5.0-py3.7-cpu",
277+
Command: "sleep 20",
278+
},
279+
},
280+
},
281+
},
282+
},
283+
wantErr: true,
284+
responseCode: 400,
285+
},
125286
}
126287

127288
for _, tt := range tests {
128289
t.Run(tt.name, func(t *testing.T) {
129290
t.Logf("name=%s args=[%#v], wantError=%v", tt.name, tt.args, tt.wantErr)
130291
res, err := CreatePFJob(tt.args.ctx, tt.args.req)
131-
t.Logf("case[%s] create single job, response=%+v", tt.name, res)
292+
t.Logf("case[%s] create job, response=%+v", tt.name, res)
132293
if tt.wantErr {
133294
assert.Error(t, err)
295+
t.Logf("name=%s err: %v", tt.name, err)
134296
} else {
135-
assert.Contains(t, err.Error(), "record not found")
297+
t.Logf("response: %+v", res)
136298
}
137299
})
138300
}

0 commit comments

Comments
 (0)