From 7711df2cebaaa6a2dc8d7de2149859eed5ba0cc2 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 31 Jan 2024 20:03:58 +0800 Subject: [PATCH] Agent ClientSet (#4718) * v1 Signed-off-by: Future-Outlier * kevin wip Signed-off-by: Kevin Su * add mockery AsyncAgentClient Signed-off-by: Future-Outlier * improve error message Signed-off-by: Future-Outlier * improve error message Signed-off-by: Future-Outlier * improve error message Signed-off-by: Future-Outlier * need to use mockery AsyncAgentClient FIrst Signed-off-by: Future-Outlier * set config TestInitializeAgentRegistry Signed-off-by: Future-Outlier * push change Signed-off-by: Future-Outlier * make generate Signed-off-by: Kevin Su * update tests Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su --------- Signed-off-by: Future-Outlier Signed-off-by: Kevin Su Signed-off-by: Future-Outlier Co-authored-by: Future-Outlier Co-authored-by: Kevin Su --- .../go/tasks/plugins/webapi/agent/client.go | 159 +++++++++++ .../tasks/plugins/webapi/agent/client_test.go | 18 ++ .../plugins/webapi/agent/integration_test.go | 168 +++--------- .../agent/mocks/AsyncAgentServiceClient.go | 258 ++++++++++++++++++ .../go/tasks/plugins/webapi/agent/plugin.go | 199 ++------------ .../tasks/plugins/webapi/agent/plugin_test.go | 103 ++----- 6 files changed, 522 insertions(+), 383 deletions(-) create mode 100644 flyteplugins/go/tasks/plugins/webapi/agent/client.go create mode 100644 flyteplugins/go/tasks/plugins/webapi/agent/client_test.go create mode 100644 flyteplugins/go/tasks/plugins/webapi/agent/mocks/AsyncAgentServiceClient.go diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go new file mode 100644 index 0000000000..b118f64596 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -0,0 +1,159 @@ +package agent + +import ( + "context" + "crypto/x509" + "fmt" + + "golang.org/x/exp/maps" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/status" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flyte/flytestdlib/config" + "github.com/flyteorg/flyte/flytestdlib/logger" +) + +// ClientSet contains the clients exposed to communicate with various agent services. +type ClientSet struct { + agentClients map[string]service.AsyncAgentServiceClient // map[endpoint] => client + agentMetadataClients map[string]service.AgentMetadataServiceClient // map[endpoint] => client +} + +func getGrpcConnection(ctx context.Context, agent *Agent) (*grpc.ClientConn, error) { + var opts []grpc.DialOption + + if agent.Insecure { + opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } else { + pool, err := x509.SystemCertPool() + if err != nil { + return nil, err + } + + creds := credentials.NewClientTLSFromCert(pool, "") + opts = append(opts, grpc.WithTransportCredentials(creds)) + } + + if len(agent.DefaultServiceConfig) != 0 { + opts = append(opts, grpc.WithDefaultServiceConfig(agent.DefaultServiceConfig)) + } + + var err error + conn, err := grpc.Dial(agent.Endpoint, opts...) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + if cerr := conn.Close(); cerr != nil { + grpclog.Infof("Failed to close conn to %s: %v", agent, cerr) + } + return + } + go func() { + <-ctx.Done() + if cerr := conn.Close(); cerr != nil { + grpclog.Infof("Failed to close conn to %s: %v", agent, cerr) + } + }() + }() + + return conn, nil +} + +func getFinalTimeout(operation string, agent *Agent) config.Duration { + if t, exists := agent.Timeouts[operation]; exists { + return t + } + + return agent.DefaultTimeout +} + +func getFinalContext(ctx context.Context, operation string, agent *Agent) (context.Context, context.CancelFunc) { + timeout := getFinalTimeout(operation, agent).Duration + if timeout == 0 { + return ctx, func() {} + } + + return context.WithTimeout(ctx, timeout) +} + +func initializeAgentRegistry(cs *ClientSet) (map[string]*Agent, error) { + agentRegistry := make(map[string]*Agent) + cfg := GetConfig() + var agentDeployments []*Agent + + // Ensure that the old configuration is backward compatible + for taskType, agentID := range cfg.AgentForTaskTypes { + agentRegistry[taskType] = cfg.Agents[agentID] + } + + if len(cfg.DefaultAgent.Endpoint) != 0 { + agentDeployments = append(agentDeployments, &cfg.DefaultAgent) + } + agentDeployments = append(agentDeployments, maps.Values(cfg.Agents)...) + for _, agentDeployment := range agentDeployments { + client := cs.agentMetadataClients[agentDeployment.Endpoint] + + finalCtx, cancel := getFinalContext(context.Background(), "ListAgents", agentDeployment) + defer cancel() + + res, err := client.ListAgents(finalCtx, &admin.ListAgentsRequest{}) + if err != nil { + grpcStatus, ok := status.FromError(err) + if grpcStatus.Code() == codes.Unimplemented { + // we should not panic here, as we want to continue to support old agent settings + logger.Infof(context.Background(), "list agent method not implemented for agent: [%v]", agentDeployment) + continue + } + + if !ok { + return nil, fmt.Errorf("failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment, err) + } + + return nil, fmt.Errorf("failed to list agent: [%v] with error: [%v]", agentDeployment, err) + } + + agents := res.GetAgents() + for _, agent := range agents { + supportedTaskTypes := agent.SupportedTaskTypes + for _, supportedTaskType := range supportedTaskTypes { + agentRegistry[supportedTaskType] = agentDeployment + } + } + } + + return agentRegistry, nil +} + +func initializeClients(ctx context.Context) (*ClientSet, error) { + agentClients := make(map[string]service.AsyncAgentServiceClient) + agentMetadataClients := make(map[string]service.AgentMetadataServiceClient) + + var agentDeployments []*Agent + cfg := GetConfig() + + if len(cfg.DefaultAgent.Endpoint) != 0 { + agentDeployments = append(agentDeployments, &cfg.DefaultAgent) + } + agentDeployments = append(agentDeployments, maps.Values(cfg.Agents)...) + for _, agentDeployment := range agentDeployments { + conn, err := getGrpcConnection(ctx, agentDeployment) + if err != nil { + return nil, err + } + agentClients[agentDeployment.Endpoint] = service.NewAsyncAgentServiceClient(conn) + agentMetadataClients[agentDeployment.Endpoint] = service.NewAgentMetadataServiceClient(conn) + } + + return &ClientSet{ + agentClients: agentClients, + agentMetadataClients: agentMetadataClients, + }, nil +} diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go new file mode 100644 index 0000000000..d68811d037 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -0,0 +1,18 @@ +package agent + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestInitializeClients(t *testing.T) { + cfg := defaultConfig + ctx := context.Background() + err := SetConfig(&cfg) + assert.NoError(t, err) + cs, err := initializeClients(ctx) + assert.NoError(t, err) + assert.NotNil(t, cs) +} diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index a2b2135591..fe3b45b881 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -10,7 +10,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "google.golang.org/grpc" "k8s.io/apimachinery/pkg/util/rand" "k8s.io/utils/strings/slices" @@ -25,6 +24,7 @@ import ( pluginCoreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" ioMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/webapi" + agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks" "github.com/flyteorg/flyte/flyteplugins/tests" "github.com/flyteorg/flyte/flytestdlib/contextutils" "github.com/flyteorg/flyte/flytestdlib/promutils" @@ -33,101 +33,6 @@ import ( "github.com/flyteorg/flyte/flytestdlib/utils" ) -type MockPlugin struct { - Plugin -} - -type MockAsyncTask struct { -} - -func (m *MockAsyncTask) GetTaskMetrics(ctx context.Context, in *admin.GetTaskMetricsRequest, opts ...grpc.CallOption) (*admin.GetTaskMetricsResponse, error) { - panic("not implemented") -} - -func (m *MockAsyncTask) GetTaskLogs(ctx context.Context, in *admin.GetTaskLogsRequest, opts ...grpc.CallOption) (*admin.GetTaskLogsResponse, error) { - panic("not implemented") -} - -type MockSyncTask struct { -} - -func (m *MockSyncTask) GetTaskMetrics(ctx context.Context, in *admin.GetTaskMetricsRequest, opts ...grpc.CallOption) (*admin.GetTaskMetricsResponse, error) { - panic("not implemented") -} - -func (m *MockSyncTask) GetTaskLogs(ctx context.Context, in *admin.GetTaskLogsRequest, opts ...grpc.CallOption) (*admin.GetTaskLogsResponse, error) { - panic("not implemented") -} - -func (m *MockAsyncTask) CreateTask(_ context.Context, createTaskRequest *admin.CreateTaskRequest, _ ...grpc.CallOption) (*admin.CreateTaskResponse, error) { - expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "fake://bucket/prefix/nhv"} - if slices.Equal(createTaskRequest.Template.GetContainer().Args, expectedArgs) { - return nil, fmt.Errorf("args not as expected") - } - return &admin.CreateTaskResponse{ - Res: &admin.CreateTaskResponse_ResourceMeta{ - ResourceMeta: []byte{1, 2, 3, 4}, - }}, nil -} - -func (m *MockAsyncTask) GetTask(_ context.Context, req *admin.GetTaskRequest, _ ...grpc.CallOption) (*admin.GetTaskResponse, error) { - if req.GetTaskType() == "bigquery_query_job_task" { - return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED, Outputs: &flyteIdlCore.LiteralMap{ - Literals: map[string]*flyteIdlCore.Literal{ - "arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}), - }, - }}}, nil - } - return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED}}, nil -} - -func (m *MockAsyncTask) DeleteTask(_ context.Context, _ *admin.DeleteTaskRequest, _ ...grpc.CallOption) (*admin.DeleteTaskResponse, error) { - return &admin.DeleteTaskResponse{}, nil -} - -func (m *MockSyncTask) CreateTask(_ context.Context, createTaskRequest *admin.CreateTaskRequest, _ ...grpc.CallOption) (*admin.CreateTaskResponse, error) { - return &admin.CreateTaskResponse{ - Res: &admin.CreateTaskResponse_Resource{ - Resource: &admin.Resource{ - State: admin.State_SUCCEEDED, - Outputs: &flyteIdlCore.LiteralMap{ - Literals: map[string]*flyteIdlCore.Literal{}, - }, - Message: "Sync task finished", - LogLinks: []*flyteIdlCore.TaskLog{{Uri: "http://localhost:3000/log", Name: "Log Link"}}, - }, - }, - }, nil - -} - -func (m *MockSyncTask) GetTask(_ context.Context, req *admin.GetTaskRequest, _ ...grpc.CallOption) (*admin.GetTaskResponse, error) { - if req.GetTaskType() == "fake_task" { - return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED, Outputs: &flyteIdlCore.LiteralMap{ - Literals: map[string]*flyteIdlCore.Literal{ - "arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}), - }, - }}}, nil - } - return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED}}, nil -} - -func (m *MockSyncTask) DeleteTask(_ context.Context, _ *admin.DeleteTaskRequest, _ ...grpc.CallOption) (*admin.DeleteTaskResponse, error) { - return &admin.DeleteTaskResponse{}, nil -} - -func mockAsyncTaskClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { - return &MockAsyncTask{}, nil -} - -func mockSyncTaskClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { - return &MockSyncTask{}, nil -} - -func mockGetBadAsyncClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { - return nil, fmt.Errorf("error") -} - func TestEndToEnd(t *testing.T) { iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error { return nil @@ -137,6 +42,7 @@ func TestEndToEnd(t *testing.T) { cfg.WebAPI.ResourceQuotas = map[core.ResourceNamespace]int{} cfg.WebAPI.Caching.Workers = 1 cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second + cfg.DefaultAgent.Endpoint = "localhost:8000" err := SetConfig(&cfg) assert.NoError(t, err) @@ -158,10 +64,10 @@ func TestEndToEnd(t *testing.T) { inputs, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) template := flyteIdlCore.TaskTemplate{ - Type: "bigquery_query_job_task", + Type: "spark", Custom: st, Target: &flyteIdlCore.TaskTemplate_Container{ - Container: &flyteIdlCore.Container{Args: []string{"pyflyte-fast-execute", "--output-prefix", "{{.outputPrefix}}"}}, + Container: &flyteIdlCore.Container{Args: []string{"pyflyte-fast-execute", "--output-prefix", "/tmp/123"}}, }, } basePrefix := storage.DataReference("fake://bucket/prefix/") @@ -174,20 +80,20 @@ func TestEndToEnd(t *testing.T) { phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter) assert.Equal(t, true, phase.Phase().IsSuccess()) - template.Type = "spark_job" + template.Type = "spark" phase = tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter) assert.Equal(t, true, phase.Phase().IsSuccess()) - }) t.Run("failed to create a job", func(t *testing.T) { agentPlugin := newMockAgentPlugin() agentPlugin.PluginLoader = func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { - return &MockPlugin{ - Plugin{ - metricScope: iCtx.MetricsScope(), - cfg: GetConfig(), - getClient: mockGetBadAsyncClientFunc, + return Plugin{ + metricScope: iCtx.MetricsScope(), + cfg: GetConfig(), + cs: &ClientSet{ + agentClients: map[string]service.AsyncAgentServiceClient{}, + agentMetadataClients: map[string]service.AgentMetadataServiceClient{}, }, }, nil } @@ -319,31 +225,41 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext { } func newMockAgentPlugin() webapi.PluginEntry { - return webapi.PluginEntry{ - ID: "agent-service", - SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "spark_job", "api_task"}, - PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { - return &MockPlugin{ - Plugin{ - metricScope: iCtx.MetricsScope(), - cfg: GetConfig(), - getClient: mockAsyncTaskClientFunc, - }, - }, nil - }, - } -} -func newMockSyncAgentPlugin() webapi.PluginEntry { + agentClient := new(agentMocks.AsyncAgentServiceClient) + + mockCreateRequestMatcher := mock.MatchedBy(func(request *admin.CreateTaskRequest) bool { + expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "/tmp/123"} + return slices.Equal(request.Template.GetContainer().Args, expectedArgs) + }) + agentClient.On("CreateTask", mock.Anything, mockCreateRequestMatcher).Return(&admin.CreateTaskResponse{ + Res: &admin.CreateTaskResponse_ResourceMeta{ + ResourceMeta: []byte{1, 2, 3, 4}, + }}, nil) + + mockGetRequestMatcher := mock.MatchedBy(func(request *admin.GetTaskRequest) bool { + return request.GetTaskType() == "spark" + }) + agentClient.On("GetTask", mock.Anything, mockGetRequestMatcher).Return( + &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED}}, nil) + + agentClient.On("DeleteTask", mock.Anything, mock.Anything).Return( + &admin.DeleteTaskResponse{}, nil) + + cfg := defaultConfig + cfg.DefaultAgent.Endpoint = "localhost:8000" + return webapi.PluginEntry{ ID: "agent-service", - SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "spark_job", "api_task"}, + SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "spark", "api_task"}, PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { - return &MockPlugin{ - Plugin{ - metricScope: iCtx.MetricsScope(), - cfg: GetConfig(), - getClient: mockSyncTaskClientFunc, + return Plugin{ + metricScope: iCtx.MetricsScope(), + cfg: &cfg, + cs: &ClientSet{ + agentClients: map[string]service.AsyncAgentServiceClient{ + "localhost:8000": agentClient, + }, }, }, nil }, diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/mocks/AsyncAgentServiceClient.go b/flyteplugins/go/tasks/plugins/webapi/agent/mocks/AsyncAgentServiceClient.go new file mode 100644 index 0000000000..f11ef1adfe --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/agent/mocks/AsyncAgentServiceClient.go @@ -0,0 +1,258 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + admin "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + + grpc "google.golang.org/grpc" + + mock "github.com/stretchr/testify/mock" +) + +// AsyncAgentServiceClient is an autogenerated mock type for the AsyncAgentServiceClient type +type AsyncAgentServiceClient struct { + mock.Mock +} + +type AsyncAgentServiceClient_CreateTask struct { + *mock.Call +} + +func (_m AsyncAgentServiceClient_CreateTask) Return(_a0 *admin.CreateTaskResponse, _a1 error) *AsyncAgentServiceClient_CreateTask { + return &AsyncAgentServiceClient_CreateTask{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *AsyncAgentServiceClient) OnCreateTask(ctx context.Context, in *admin.CreateTaskRequest, opts ...grpc.CallOption) *AsyncAgentServiceClient_CreateTask { + c_call := _m.On("CreateTask", ctx, in, opts) + return &AsyncAgentServiceClient_CreateTask{Call: c_call} +} + +func (_m *AsyncAgentServiceClient) OnCreateTaskMatch(matchers ...interface{}) *AsyncAgentServiceClient_CreateTask { + c_call := _m.On("CreateTask", matchers...) + return &AsyncAgentServiceClient_CreateTask{Call: c_call} +} + +// CreateTask provides a mock function with given fields: ctx, in, opts +func (_m *AsyncAgentServiceClient) CreateTask(ctx context.Context, in *admin.CreateTaskRequest, opts ...grpc.CallOption) (*admin.CreateTaskResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *admin.CreateTaskResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.CreateTaskRequest, ...grpc.CallOption) *admin.CreateTaskResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.CreateTaskResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.CreateTaskRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type AsyncAgentServiceClient_DeleteTask struct { + *mock.Call +} + +func (_m AsyncAgentServiceClient_DeleteTask) Return(_a0 *admin.DeleteTaskResponse, _a1 error) *AsyncAgentServiceClient_DeleteTask { + return &AsyncAgentServiceClient_DeleteTask{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *AsyncAgentServiceClient) OnDeleteTask(ctx context.Context, in *admin.DeleteTaskRequest, opts ...grpc.CallOption) *AsyncAgentServiceClient_DeleteTask { + c_call := _m.On("DeleteTask", ctx, in, opts) + return &AsyncAgentServiceClient_DeleteTask{Call: c_call} +} + +func (_m *AsyncAgentServiceClient) OnDeleteTaskMatch(matchers ...interface{}) *AsyncAgentServiceClient_DeleteTask { + c_call := _m.On("DeleteTask", matchers...) + return &AsyncAgentServiceClient_DeleteTask{Call: c_call} +} + +// DeleteTask provides a mock function with given fields: ctx, in, opts +func (_m *AsyncAgentServiceClient) DeleteTask(ctx context.Context, in *admin.DeleteTaskRequest, opts ...grpc.CallOption) (*admin.DeleteTaskResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *admin.DeleteTaskResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.DeleteTaskRequest, ...grpc.CallOption) *admin.DeleteTaskResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.DeleteTaskResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.DeleteTaskRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type AsyncAgentServiceClient_GetTask struct { + *mock.Call +} + +func (_m AsyncAgentServiceClient_GetTask) Return(_a0 *admin.GetTaskResponse, _a1 error) *AsyncAgentServiceClient_GetTask { + return &AsyncAgentServiceClient_GetTask{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *AsyncAgentServiceClient) OnGetTask(ctx context.Context, in *admin.GetTaskRequest, opts ...grpc.CallOption) *AsyncAgentServiceClient_GetTask { + c_call := _m.On("GetTask", ctx, in, opts) + return &AsyncAgentServiceClient_GetTask{Call: c_call} +} + +func (_m *AsyncAgentServiceClient) OnGetTaskMatch(matchers ...interface{}) *AsyncAgentServiceClient_GetTask { + c_call := _m.On("GetTask", matchers...) + return &AsyncAgentServiceClient_GetTask{Call: c_call} +} + +// GetTask provides a mock function with given fields: ctx, in, opts +func (_m *AsyncAgentServiceClient) GetTask(ctx context.Context, in *admin.GetTaskRequest, opts ...grpc.CallOption) (*admin.GetTaskResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *admin.GetTaskResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.GetTaskRequest, ...grpc.CallOption) *admin.GetTaskResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.GetTaskResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.GetTaskRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type AsyncAgentServiceClient_GetTaskLogs struct { + *mock.Call +} + +func (_m AsyncAgentServiceClient_GetTaskLogs) Return(_a0 *admin.GetTaskLogsResponse, _a1 error) *AsyncAgentServiceClient_GetTaskLogs { + return &AsyncAgentServiceClient_GetTaskLogs{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *AsyncAgentServiceClient) OnGetTaskLogs(ctx context.Context, in *admin.GetTaskLogsRequest, opts ...grpc.CallOption) *AsyncAgentServiceClient_GetTaskLogs { + c_call := _m.On("GetTaskLogs", ctx, in, opts) + return &AsyncAgentServiceClient_GetTaskLogs{Call: c_call} +} + +func (_m *AsyncAgentServiceClient) OnGetTaskLogsMatch(matchers ...interface{}) *AsyncAgentServiceClient_GetTaskLogs { + c_call := _m.On("GetTaskLogs", matchers...) + return &AsyncAgentServiceClient_GetTaskLogs{Call: c_call} +} + +// GetTaskLogs provides a mock function with given fields: ctx, in, opts +func (_m *AsyncAgentServiceClient) GetTaskLogs(ctx context.Context, in *admin.GetTaskLogsRequest, opts ...grpc.CallOption) (*admin.GetTaskLogsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *admin.GetTaskLogsResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.GetTaskLogsRequest, ...grpc.CallOption) *admin.GetTaskLogsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.GetTaskLogsResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.GetTaskLogsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type AsyncAgentServiceClient_GetTaskMetrics struct { + *mock.Call +} + +func (_m AsyncAgentServiceClient_GetTaskMetrics) Return(_a0 *admin.GetTaskMetricsResponse, _a1 error) *AsyncAgentServiceClient_GetTaskMetrics { + return &AsyncAgentServiceClient_GetTaskMetrics{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *AsyncAgentServiceClient) OnGetTaskMetrics(ctx context.Context, in *admin.GetTaskMetricsRequest, opts ...grpc.CallOption) *AsyncAgentServiceClient_GetTaskMetrics { + c_call := _m.On("GetTaskMetrics", ctx, in, opts) + return &AsyncAgentServiceClient_GetTaskMetrics{Call: c_call} +} + +func (_m *AsyncAgentServiceClient) OnGetTaskMetricsMatch(matchers ...interface{}) *AsyncAgentServiceClient_GetTaskMetrics { + c_call := _m.On("GetTaskMetrics", matchers...) + return &AsyncAgentServiceClient_GetTaskMetrics{Call: c_call} +} + +// GetTaskMetrics provides a mock function with given fields: ctx, in, opts +func (_m *AsyncAgentServiceClient) GetTaskMetrics(ctx context.Context, in *admin.GetTaskMetricsRequest, opts ...grpc.CallOption) (*admin.GetTaskMetricsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *admin.GetTaskMetricsResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.GetTaskMetricsRequest, ...grpc.CallOption) *admin.GetTaskMetricsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.GetTaskMetricsResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.GetTaskMetricsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index 7a9f92caaa..13115f89b4 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -2,22 +2,14 @@ package agent import ( "context" - "crypto/x509" "encoding/gob" "fmt" "time" "golang.org/x/exp/maps" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/status" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" flyteIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" pluginErrors "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" @@ -25,20 +17,15 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/webapi" - "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/logger" "github.com/flyteorg/flyte/flytestdlib/promutils" ) -type GetClientFunc func(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) -type GetAgentMetadataClientFunc func(ctx context.Context, agent *Agent, connCache map[*Agent]*grpc.ClientConn) (service.AgentMetadataServiceClient, error) - type Plugin struct { - metricScope promutils.Scope - cfg *Config - getClient GetClientFunc - connectionCache map[*Agent]*grpc.ClientConn - agentRegistry map[string]*Agent // map[taskType] => Agent + metricScope promutils.Scope + cfg *Config + cs *ClientSet + agentRegistry map[string]*Agent // map[taskType] => Agent } type ResourceWrapper struct { @@ -97,9 +84,9 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR agent := getFinalAgent(taskTemplate.Type, p.cfg, p.agentRegistry) - client, err := p.getClient(ctx, agent, p.connectionCache) - if err != nil { - return nil, nil, fmt.Errorf("failed to connect to agent with error: %v", err) + client := p.cs.agentClients[agent.Endpoint] + if client == nil { + return nil, nil, fmt.Errorf("default agent is not connected, please check if endpoint:[%v] is up and running", agent.Endpoint) } finalCtx, cancel := getFinalContext(ctx, "CreateTask", agent) @@ -141,17 +128,9 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - agent := getFinalAgent(metadata.TaskType, p.cfg, p.agentRegistry) - if err != nil { - return nil, fmt.Errorf("failed to find agent with error: %v", err) - } - - client, err := p.getClient(ctx, agent, p.connectionCache) - if err != nil { - return nil, fmt.Errorf("failed to connect to agent with error: %v", err) - } + client := p.cs.agentClients[agent.Endpoint] finalCtx, cancel := getFinalContext(ctx, "GetTask", agent) defer cancel() @@ -174,18 +153,13 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error return nil } metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - agent := getFinalAgent(metadata.TaskType, p.cfg, p.agentRegistry) - client, err := p.getClient(ctx, agent, p.connectionCache) - if err != nil { - return fmt.Errorf("failed to connect to agent with error: %v", err) - } - + client := p.cs.agentClients[agent.Endpoint] finalCtx, cancel := getFinalContext(ctx, "DeleteTask", agent) defer cancel() - _, err = client.DeleteTask(finalCtx, &admin.DeleteTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta}) + _, err := client.DeleteTask(finalCtx, &admin.DeleteTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta}) return err } @@ -271,71 +245,6 @@ func getFinalAgent(taskType string, cfg *Config, agentRegistry map[string]*Agent return &cfg.DefaultAgent } -func getGrpcConnection(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (*grpc.ClientConn, error) { - conn, ok := connectionCache[agent] - if ok { - return conn, nil - } - var opts []grpc.DialOption - - if agent.Insecure { - opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) - } else { - pool, err := x509.SystemCertPool() - if err != nil { - return nil, err - } - - creds := credentials.NewClientTLSFromCert(pool, "") - opts = append(opts, grpc.WithTransportCredentials(creds)) - } - - if len(agent.DefaultServiceConfig) != 0 { - opts = append(opts, grpc.WithDefaultServiceConfig(agent.DefaultServiceConfig)) - } - - var err error - conn, err = grpc.Dial(agent.Endpoint, opts...) - if err != nil { - return nil, err - } - connectionCache[agent] = conn - defer func() { - if err != nil { - if cerr := conn.Close(); cerr != nil { - grpclog.Infof("Failed to close conn to %s: %v", agent, cerr) - } - return - } - go func() { - <-ctx.Done() - if cerr := conn.Close(); cerr != nil { - grpclog.Infof("Failed to close conn to %s: %v", agent, cerr) - } - }() - }() - - return conn, nil -} - -func getClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { - conn, err := getGrpcConnection(ctx, agent, connectionCache) - if err != nil { - return nil, err - } - - return service.NewAsyncAgentServiceClient(conn), nil -} - -func getAgentMetadataClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AgentMetadataServiceClient, error) { - conn, err := getGrpcConnection(ctx, agent, connectionCache) - if err != nil { - return nil, err - } - - return service.NewAgentMetadataServiceClient(conn), nil -} - func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata) admin.TaskExecutionMetadata { taskExecutionID := taskExecutionMetadata.GetTaskExecutionID().GetID() return admin.TaskExecutionMetadata{ @@ -348,82 +257,19 @@ func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata } } -func getFinalTimeout(operation string, agent *Agent) config.Duration { - if t, exists := agent.Timeouts[operation]; exists { - return t - } - - return agent.DefaultTimeout -} - -func getFinalContext(ctx context.Context, operation string, agent *Agent) (context.Context, context.CancelFunc) { - timeout := getFinalTimeout(operation, agent).Duration - if timeout == 0 { - return ctx, func() {} - } - - return context.WithTimeout(ctx, timeout) -} - -func initializeAgentRegistry(cfg *Config, connectionCache map[*Agent]*grpc.ClientConn, getAgentMetadataClientFunc GetAgentMetadataClientFunc) (map[string]*Agent, error) { - agentRegistry := make(map[string]*Agent) - var agentDeployments []*Agent - - // Ensure that the old configuration is backward compatible - for taskType, agentID := range cfg.AgentForTaskTypes { - agentRegistry[taskType] = cfg.Agents[agentID] - } - - if len(cfg.DefaultAgent.Endpoint) != 0 { - agentDeployments = append(agentDeployments, &cfg.DefaultAgent) - } - agentDeployments = append(agentDeployments, maps.Values(cfg.Agents)...) - for _, agentDeployment := range agentDeployments { - client, err := getAgentMetadataClientFunc(context.Background(), agentDeployment, connectionCache) - if err != nil { - return nil, fmt.Errorf("failed to connect to agent [%v] with error: [%v]", agentDeployment, err) - } - - finalCtx, cancel := getFinalContext(context.Background(), "ListAgents", agentDeployment) - defer cancel() - - res, err := client.ListAgents(finalCtx, &admin.ListAgentsRequest{}) - if err != nil { - grpcStatus, ok := status.FromError(err) - if grpcStatus.Code() == codes.Unimplemented { - // we should not panic here, as we want to continue to support old agent settings - logger.Infof(context.Background(), "list agent method not implemented for agent: [%v]", agentDeployment) - continue - } - - if !ok { - return nil, fmt.Errorf("failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment, err) - } - - return nil, fmt.Errorf("failed to list agent: [%v] with error: [%v]", agentDeployment, err) - } - - agents := res.GetAgents() - for _, agent := range agents { - supportedTaskTypes := agent.SupportedTaskTypes - for _, supportedTaskType := range supportedTaskTypes { - agentRegistry[supportedTaskType] = agentDeployment - } - } - } - - return agentRegistry, nil -} - func newAgentPlugin() webapi.PluginEntry { - cfg := GetConfig() - connectionCache := make(map[*Agent]*grpc.ClientConn) - agentRegistry, err := initializeAgentRegistry(cfg, connectionCache, getAgentMetadataClientFunc) + cs, err := initializeClients(context.Background()) if err != nil { // We should wait for all agents to be up and running before starting the server - panic(err) + panic(fmt.Sprintf("failed to initialize clients with error: %v", err)) } + agentRegistry, err := initializeAgentRegistry(cs) + if err != nil { + panic(fmt.Sprintf("failed to initialize agent registry with error: %v", err)) + } + + cfg := GetConfig() supportedTaskTypes := append(maps.Keys(agentRegistry), cfg.SupportedTaskTypes...) logger.Infof(context.Background(), "Agent supports task types: %v", supportedTaskTypes) @@ -432,11 +278,10 @@ func newAgentPlugin() webapi.PluginEntry { SupportedTaskTypes: supportedTaskTypes, PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { return &Plugin{ - metricScope: iCtx.MetricsScope(), - cfg: cfg, - getClient: getClientFunc, - connectionCache: connectionCache, - agentRegistry: agentRegistry, + metricScope: iCtx.MetricsScope(), + cfg: cfg, + cs: cs, + agentRegistry: agentRegistry, }, nil }, } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index d5b5fae5fe..e66f46f1bc 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -6,57 +6,22 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "golang.org/x/exp/maps" - "google.golang.org/grpc" - - "github.com/flyteorg/flyte/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" flyteIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" flyteIdlCore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" pluginCoreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" - ioMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" webapiPlugin "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/webapi/mocks" agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks" - "github.com/flyteorg/flyte/flyteplugins/tests" "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/promutils" - "github.com/flyteorg/flyte/flytestdlib/storage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/exp/maps" ) -func TestSyncTask(t *testing.T) { - tCtx := getTaskContext(t) - taskReader := new(pluginCoreMocks.TaskReader) - - template := flyteIdlCore.TaskTemplate{ - Type: "api_task", - } - - taskReader.On("Read", mock.Anything).Return(&template, nil) - - tCtx.OnTaskReader().Return(taskReader) - - agentPlugin := newMockSyncAgentPlugin() - pluginEntry := pluginmachinery.CreateRemotePlugin(agentPlugin) - plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext("create_task_sync_test")) - assert.NoError(t, err) - - inputs, err := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) - assert.NoError(t, err) - basePrefix := storage.DataReference("fake://bucket/prefix/") - inputReader := &ioMocks.InputReader{} - inputReader.OnGetInputPrefixPath().Return(basePrefix) - inputReader.OnGetInputPath().Return(basePrefix + "/inputs.pb") - inputReader.OnGetMatch(mock.Anything).Return(inputs, nil) - tCtx.OnInputReader().Return(inputReader) - - phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, nil) - assert.Equal(t, true, phase.Phase().IsSuccess()) -} +const defaultAgentEndpoint = "localhost:8000" func TestPlugin(t *testing.T) { fakeSetupContext := pluginCoreMocks.SetupContext{} @@ -102,39 +67,6 @@ func TestPlugin(t *testing.T) { assert.Equal(t, agent.Endpoint, cfg.DefaultAgent.Endpoint) }) - t.Run("test getAgentMetadataClientFunc", func(t *testing.T) { - client, err := getAgentMetadataClientFunc(context.Background(), &Agent{Endpoint: "localhost:80"}, map[*Agent]*grpc.ClientConn{}) - assert.NoError(t, err) - assert.NotNil(t, client) - }) - - t.Run("test getClientFunc", func(t *testing.T) { - client, err := getClientFunc(context.Background(), &Agent{Endpoint: "localhost:80"}, map[*Agent]*grpc.ClientConn{}) - assert.NoError(t, err) - assert.NotNil(t, client) - }) - - t.Run("test getClientFunc more config", func(t *testing.T) { - client, err := getClientFunc(context.Background(), &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}, map[*Agent]*grpc.ClientConn{}) - assert.NoError(t, err) - assert.NotNil(t, client) - }) - - t.Run("test getClientFunc cache hit", func(t *testing.T) { - connectionCache := make(map[*Agent]*grpc.ClientConn) - agent := &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"} - - client, err := getClientFunc(context.Background(), agent, connectionCache) - assert.NoError(t, err) - assert.NotNil(t, client) - assert.NotNil(t, client, connectionCache[agent]) - - cachedClient, err := getClientFunc(context.Background(), agent, connectionCache) - assert.NoError(t, err) - assert.NotNil(t, cachedClient) - assert.Equal(t, client, cachedClient) - }) - t.Run("test getFinalTimeout", func(t *testing.T) { timeout := getFinalTimeout("CreateTask", &Agent{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}}) assert.Equal(t, 1*time.Millisecond, timeout.Duration) @@ -334,8 +266,8 @@ func TestPlugin(t *testing.T) { }) } -func TestInitializeAgentRegistry(t *testing.T) { - mockClient := new(agentMocks.AgentMetadataServiceClient) +func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient { + mockMetadataServiceClient := new(agentMocks.AgentMetadataServiceClient) mockRequest := &admin.ListAgentsRequest{} mockResponse := &admin.ListAgentsResponse{ Agents: []*admin.Agent{ @@ -346,16 +278,27 @@ func TestInitializeAgentRegistry(t *testing.T) { }, } - mockClient.On("ListAgents", mock.Anything, mockRequest).Return(mockResponse, nil) - getAgentMetadataClientFunc := func(ctx context.Context, agent *Agent, connCache map[*Agent]*grpc.ClientConn) (service.AgentMetadataServiceClient, error) { - return mockClient, nil + mockMetadataServiceClient.On("ListAgents", mock.Anything, mockRequest).Return(mockResponse, nil) + return mockMetadataServiceClient +} + +func TestInitializeAgentRegistry(t *testing.T) { + agentClients := make(map[string]service.AsyncAgentServiceClient) + agentMetadataClients := make(map[string]service.AgentMetadataServiceClient) + agentClients[defaultAgentEndpoint] = &agentMocks.AsyncAgentServiceClient{} + agentMetadataClients[defaultAgentEndpoint] = getMockMetadataServiceClient() + + cs := &ClientSet{ + agentClients: agentClients, + agentMetadataClients: agentMetadataClients, } cfg := defaultConfig - cfg.Agents = map[string]*Agent{"custom_agent": {Endpoint: "localhost:80"}} + cfg.Agents = map[string]*Agent{"custom_agent": {Endpoint: defaultAgentEndpoint}} cfg.AgentForTaskTypes = map[string]string{"task1": "agent-deployment-1", "task2": "agent-deployment-2"} - connectionCache := make(map[*Agent]*grpc.ClientConn) - agentRegistry, err := initializeAgentRegistry(&cfg, connectionCache, getAgentMetadataClientFunc) + err := SetConfig(&cfg) + assert.NoError(t, err) + agentRegistry, err := initializeAgentRegistry(cs) assert.NoError(t, err) // In golang, the order of keys in a map is random. So, we sort the keys before asserting.