diff --git a/go.mod b/go.mod index ac6ea312c..9cd9e86fb 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/agiledragon/gomonkey v2.0.2+incompatible github.com/apache/dubbo-getty v1.4.8 github.com/dubbogo/gost v1.12.5 + github.com/golang/mock v1.6.0 github.com/go-sql-driver/mysql v1.6.0 github.com/google/uuid v1.3.0 github.com/natefinch/lumberjack v2.0.0+incompatible diff --git a/go.sum b/go.sum index a67c1d561..8fc35615a 100644 --- a/go.sum +++ b/go.sum @@ -1839,4 +1839,4 @@ sourcegraph.com/sourcegraph/appdash-data v0.0.0-20151005221446-73f23eafcf67/go.m vimagination.zapto.org/byteio v0.0.0-20200222190125-d27cba0f0b10 h1:pxt6fVJP67Hxo1qk8JalUghLlk3abYByl+3e0JYfUlE= vimagination.zapto.org/byteio v0.0.0-20200222190125-d27cba0f0b10/go.mod h1:fl9OF22g6MTKgvHA1hqMXe/L7+ULWofVTwbC9loGu7A= vimagination.zapto.org/memio v0.0.0-20200222190306-588ebc67b97d h1:Mp6WiHHuiwHaknxTdxJ8pvC9/B4pOgW1PamKGexG7Fs= -vimagination.zapto.org/memio v0.0.0-20200222190306-588ebc67b97d/go.mod h1:zHGDKp2tyvF4IAfLti4pKYqCJucXYmmKMb3UMrCHK/4= +vimagination.zapto.org/memio v0.0.0-20200222190306-588ebc67b97d/go.mod h1:zHGDKp2tyvF4IAfLti4pKYqCJucXYmmKMb3UMrCHK/4= \ No newline at end of file diff --git a/pkg/rm/rm_api_test.go b/pkg/rm/rm_api_test.go new file mode 100644 index 000000000..fa1d8419b --- /dev/null +++ b/pkg/rm/rm_api_test.go @@ -0,0 +1,384 @@ +package rm + +import ( + "context" + "github.com/golang/mock/gomock" + "github.com/seata/seata-go/pkg/protocol/branch" + "reflect" + "sync" +) + +// MockResource is a mock of Resource interface. +type MockResource struct { + ctrl *gomock.Controller + recorder *MockResourceMockRecorder +} + +// MockResourceMockRecorder is the mock recorder for MockResource. +type MockResourceMockRecorder struct { + mock *MockResource +} + +// NewMockResource creates a new mock instance. +func NewMockResource(ctrl *gomock.Controller) *MockResource { + mock := &MockResource{ctrl: ctrl} + mock.recorder = &MockResourceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockResource) EXPECT() *MockResourceMockRecorder { + return m.recorder +} + +// GetBranchType mocks base method. +func (m *MockResource) GetBranchType() branch.BranchType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBranchType") + ret0, _ := ret[0].(branch.BranchType) + return ret0 +} + +// GetBranchType indicates an expected call of GetBranchType. +func (mr *MockResourceMockRecorder) GetBranchType() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBranchType", reflect.TypeOf((*MockResource)(nil).GetBranchType)) +} + +// GetResourceGroupId mocks base method. +func (m *MockResource) GetResourceGroupId() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetResourceGroupId") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetResourceGroupId indicates an expected call of GetResourceGroupId. +func (mr *MockResourceMockRecorder) GetResourceGroupId() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResourceGroupId", reflect.TypeOf((*MockResource)(nil).GetResourceGroupId)) +} + +// GetResourceId mocks base method. +func (m *MockResource) GetResourceId() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetResourceId") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetResourceId indicates an expected call of GetResourceId. +func (mr *MockResourceMockRecorder) GetResourceId() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResourceId", reflect.TypeOf((*MockResource)(nil).GetResourceId)) +} + +// MockResourceManagerInbound is a mock of ResourceManagerInbound interface. +type MockResourceManagerInbound struct { + ctrl *gomock.Controller + recorder *MockResourceManagerInboundMockRecorder +} + +// MockResourceManagerInboundMockRecorder is the mock recorder for MockResourceManagerInbound. +type MockResourceManagerInboundMockRecorder struct { + mock *MockResourceManagerInbound +} + +// NewMockResourceManagerInbound creates a new mock instance. +func NewMockResourceManagerInbound(ctrl *gomock.Controller) *MockResourceManagerInbound { + mock := &MockResourceManagerInbound{ctrl: ctrl} + mock.recorder = &MockResourceManagerInboundMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockResourceManagerInbound) EXPECT() *MockResourceManagerInboundMockRecorder { + return m.recorder +} + +// BranchCommit mocks base method. +func (m *MockResourceManagerInbound) BranchCommit(ctx context.Context, resource BranchResource) (branch.BranchStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BranchCommit", ctx, resource) + ret0, _ := ret[0].(branch.BranchStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BranchCommit indicates an expected call of BranchCommit. +func (mr *MockResourceManagerInboundMockRecorder) BranchCommit(ctx, resource interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchCommit", reflect.TypeOf((*MockResourceManagerInbound)(nil).BranchCommit), ctx, resource) +} + +// BranchRollback mocks base method. +func (m *MockResourceManagerInbound) BranchRollback(ctx context.Context, resource BranchResource) (branch.BranchStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BranchRollback", ctx, resource) + ret0, _ := ret[0].(branch.BranchStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BranchRollback indicates an expected call of BranchRollback. +func (mr *MockResourceManagerInboundMockRecorder) BranchRollback(ctx, resource interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchRollback", reflect.TypeOf((*MockResourceManagerInbound)(nil).BranchRollback), ctx, resource) +} + +// MockResourceManagerOutbound is a mock of ResourceManagerOutbound interface. +type MockResourceManagerOutbound struct { + ctrl *gomock.Controller + recorder *MockResourceManagerOutboundMockRecorder +} + +// MockResourceManagerOutboundMockRecorder is the mock recorder for MockResourceManagerOutbound. +type MockResourceManagerOutboundMockRecorder struct { + mock *MockResourceManagerOutbound +} + +// NewMockResourceManagerOutbound creates a new mock instance. +func NewMockResourceManagerOutbound(ctrl *gomock.Controller) *MockResourceManagerOutbound { + mock := &MockResourceManagerOutbound{ctrl: ctrl} + mock.recorder = &MockResourceManagerOutboundMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockResourceManagerOutbound) EXPECT() *MockResourceManagerOutboundMockRecorder { + return m.recorder +} + +// BranchRegister mocks base method. +func (m *MockResourceManagerOutbound) BranchRegister(ctx context.Context, param BranchRegisterParam) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BranchRegister", ctx, param) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BranchRegister indicates an expected call of BranchRegister. +func (mr *MockResourceManagerOutboundMockRecorder) BranchRegister(ctx, param interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchRegister", reflect.TypeOf((*MockResourceManagerOutbound)(nil).BranchRegister), ctx, param) +} + +// BranchReport mocks base method. +func (m *MockResourceManagerOutbound) BranchReport(ctx context.Context, param BranchReportParam) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BranchReport", ctx, param) + ret0, _ := ret[0].(error) + return ret0 +} + +// BranchReport indicates an expected call of BranchReport. +func (mr *MockResourceManagerOutboundMockRecorder) BranchReport(ctx, param interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchReport", reflect.TypeOf((*MockResourceManagerOutbound)(nil).BranchReport), ctx, param) +} + +// LockQuery mocks base method. +func (m *MockResourceManagerOutbound) LockQuery(ctx context.Context, param LockQueryParam) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LockQuery", ctx, param) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LockQuery indicates an expected call of LockQuery. +func (mr *MockResourceManagerOutboundMockRecorder) LockQuery(ctx, param interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LockQuery", reflect.TypeOf((*MockResourceManagerOutbound)(nil).LockQuery), ctx, param) +} + +// MockResourceManager is a mock of ResourceManager interface. +type MockResourceManager struct { + ctrl *gomock.Controller + recorder *MockResourceManagerMockRecorder +} + +// MockResourceManagerMockRecorder is the mock recorder for MockResourceManager. +type MockResourceManagerMockRecorder struct { + mock *MockResourceManager +} + +// NewMockResourceManager creates a new mock instance. +func NewMockResourceManager(ctrl *gomock.Controller) *MockResourceManager { + mock := &MockResourceManager{ctrl: ctrl} + mock.recorder = &MockResourceManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockResourceManager) EXPECT() *MockResourceManagerMockRecorder { + return m.recorder +} + +// BranchCommit mocks base method. +func (m *MockResourceManager) BranchCommit(ctx context.Context, resource BranchResource) (branch.BranchStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BranchCommit", ctx, resource) + ret0, _ := ret[0].(branch.BranchStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BranchCommit indicates an expected call of BranchCommit. +func (mr *MockResourceManagerMockRecorder) BranchCommit(ctx, resource interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchCommit", reflect.TypeOf((*MockResourceManager)(nil).BranchCommit), ctx, resource) +} + +// BranchRegister mocks base method. +func (m *MockResourceManager) BranchRegister(ctx context.Context, param BranchRegisterParam) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BranchRegister", ctx, param) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BranchRegister indicates an expected call of BranchRegister. +func (mr *MockResourceManagerMockRecorder) BranchRegister(ctx, param interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchRegister", reflect.TypeOf((*MockResourceManager)(nil).BranchRegister), ctx, param) +} + +// BranchReport mocks base method. +func (m *MockResourceManager) BranchReport(ctx context.Context, param BranchReportParam) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BranchReport", ctx, param) + ret0, _ := ret[0].(error) + return ret0 +} + +// BranchReport indicates an expected call of BranchReport. +func (mr *MockResourceManagerMockRecorder) BranchReport(ctx, param interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchReport", reflect.TypeOf((*MockResourceManager)(nil).BranchReport), ctx, param) +} + +// BranchRollback mocks base method. +func (m *MockResourceManager) BranchRollback(ctx context.Context, resource BranchResource) (branch.BranchStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BranchRollback", ctx, resource) + ret0, _ := ret[0].(branch.BranchStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BranchRollback indicates an expected call of BranchRollback. +func (mr *MockResourceManagerMockRecorder) BranchRollback(ctx, resource interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BranchRollback", reflect.TypeOf((*MockResourceManager)(nil).BranchRollback), ctx, resource) +} + +// GetBranchType mocks base method. +func (m *MockResourceManager) GetBranchType() branch.BranchType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBranchType") + ret0, _ := ret[0].(branch.BranchType) + return ret0 +} + +// GetBranchType indicates an expected call of GetBranchType. +func (mr *MockResourceManagerMockRecorder) GetBranchType() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBranchType", reflect.TypeOf((*MockResourceManager)(nil).GetBranchType)) +} + +// GetCachedResources mocks base method. +func (m *MockResourceManager) GetCachedResources() *sync.Map { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCachedResources") + ret0, _ := ret[0].(*sync.Map) + return ret0 +} + +// GetCachedResources indicates an expected call of GetCachedResources. +func (mr *MockResourceManagerMockRecorder) GetCachedResources() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCachedResources", reflect.TypeOf((*MockResourceManager)(nil).GetCachedResources)) +} + +// LockQuery mocks base method. +func (m *MockResourceManager) LockQuery(ctx context.Context, param LockQueryParam) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LockQuery", ctx, param) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LockQuery indicates an expected call of LockQuery. +func (mr *MockResourceManagerMockRecorder) LockQuery(ctx, param interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LockQuery", reflect.TypeOf((*MockResourceManager)(nil).LockQuery), ctx, param) +} + +// RegisterResource mocks base method. +func (m *MockResourceManager) RegisterResource(resource Resource) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterResource", resource) + ret0, _ := ret[0].(error) + return ret0 +} + +// RegisterResource indicates an expected call of RegisterResource. +func (mr *MockResourceManagerMockRecorder) RegisterResource(resource interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterResource", reflect.TypeOf((*MockResourceManager)(nil).RegisterResource), resource) +} + +// UnregisterResource mocks base method. +func (m *MockResourceManager) UnregisterResource(resource Resource) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnregisterResource", resource) + ret0, _ := ret[0].(error) + return ret0 +} + +// UnregisterResource indicates an expected call of UnregisterResource. +func (mr *MockResourceManagerMockRecorder) UnregisterResource(resource interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterResource", reflect.TypeOf((*MockResourceManager)(nil).UnregisterResource), resource) +} + +// MockResourceManagerGetter is a mock of ResourceManagerGetter interface. +type MockResourceManagerGetter struct { + ctrl *gomock.Controller + recorder *MockResourceManagerGetterMockRecorder +} + +// MockResourceManagerGetterMockRecorder is the mock recorder for MockResourceManagerGetter. +type MockResourceManagerGetterMockRecorder struct { + mock *MockResourceManagerGetter +} + +// NewMockResourceManagerGetter creates a new mock instance. +func NewMockResourceManagerGetter(ctrl *gomock.Controller) *MockResourceManagerGetter { + mock := &MockResourceManagerGetter{ctrl: ctrl} + mock.recorder = &MockResourceManagerGetterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockResourceManagerGetter) EXPECT() *MockResourceManagerGetterMockRecorder { + return m.recorder +} + +// GetResourceManager mocks base method. +func (m *MockResourceManagerGetter) GetResourceManager(branchType branch.BranchType) ResourceManager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetResourceManager", branchType) + ret0, _ := ret[0].(ResourceManager) + return ret0 +} + +// GetResourceManager indicates an expected call of GetResourceManager. +func (mr *MockResourceManagerGetterMockRecorder) GetResourceManager(branchType interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResourceManager", reflect.TypeOf((*MockResourceManagerGetter)(nil).GetResourceManager), branchType) +} diff --git a/pkg/rm/rm_cache_test.go b/pkg/rm/rm_cache_test.go new file mode 100644 index 000000000..2560a6018 --- /dev/null +++ b/pkg/rm/rm_cache_test.go @@ -0,0 +1,30 @@ +package rm + +import ( + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/seata/seata-go/pkg/protocol/branch" +) + +func TestGetRmCacheInstance(t *testing.T) { + + ctl := gomock.NewController(t) + + mockResourceManager := NewMockResourceManager(ctl) + mockResourceManager.EXPECT().GetBranchType().Return(branch.BranchTypeTCC) + + tests := struct { + name string + want *ResourceManagerCache + }{"test1", &ResourceManagerCache{}} + + t.Run(tests.name, func(t *testing.T) { + GetRmCacheInstance().RegisterResourceManager(mockResourceManager) + actual := GetRmCacheInstance().GetResourceManager(branch.BranchTypeTCC) + assert.Equalf(t, mockResourceManager, actual, "GetRmCacheInstance()") + }) + +} diff --git a/pkg/rm/rm_remoting_test.go b/pkg/rm/rm_remoting_test.go new file mode 100644 index 000000000..5f69f68b9 --- /dev/null +++ b/pkg/rm/rm_remoting_test.go @@ -0,0 +1,18 @@ +package rm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetRMRemotingInstance(t *testing.T) { + tests := struct { + name string + want *RMRemoting + }{"test1", &RMRemoting{}} + + t.Run(tests.name, func(t *testing.T) { + assert.Equalf(t, tests.want, GetRMRemotingInstance(), "GetRMRemotingInstance()") + }) +} diff --git a/pkg/rm/tcc/tcc_service_test.go b/pkg/rm/tcc/tcc_service_test.go index be778631c..d4eaea15a 100644 --- a/pkg/rm/tcc/tcc_service_test.go +++ b/pkg/rm/tcc/tcc_service_test.go @@ -19,25 +19,30 @@ package tcc import ( "context" + "fmt" "os" "reflect" + "sync" "testing" "time" - "github.com/seata/seata-go/pkg/common/log" - "github.com/agiledragon/gomonkey" + "github.com/stretchr/testify/assert" + "github.com/seata/seata-go/pkg/common" + "github.com/seata/seata-go/pkg/common/log" "github.com/seata/seata-go/pkg/common/net" "github.com/seata/seata-go/pkg/rm" "github.com/seata/seata-go/pkg/tm" + "github.com/seata/seata-go/sample/tcc/dubbo/client/service" testdata2 "github.com/seata/seata-go/testdata" - "github.com/stretchr/testify/assert" ) var ( testTccServiceProxy *TCCServiceProxy testBranchID = int64(121324345353) + names []interface{} + values = make([]reflect.Value, 0, 2) ) func InitMock() { @@ -56,7 +61,7 @@ func InitMock() { gomonkey.ApplyMethod(reflect.TypeOf(testTccServiceProxy), "RegisterResource", registerResource) gomonkey.ApplyMethod(reflect.TypeOf(testTccServiceProxy), "Prepare", prepare) gomonkey.ApplyMethod(reflect.TypeOf(rm.GetRMRemotingInstance()), "BranchRegister", branchRegister) - testTccServiceProxy, _ = NewTCCServiceProxy(testdata2.GetTestTwoPhaseService()) + testTccServiceProxy, _ = NewTCCServiceProxy(GetTestTwoPhaseService()) } func TestMain(m *testing.M) { @@ -225,3 +230,88 @@ func TestRegisteBranch(t *testing.T) { bizContext := tm.GetBusinessActionContext(ctx) assert.Equal(t, testBranchID, bizContext.BranchId) } + +func TestNewTCCServiceProxy(t *testing.T) { + type args struct { + service interface{} + } + + userProvider := &service.UserProvider{} + args1 := args{service: userProvider} + args2 := args{service: userProvider} + + twoPhaseAction1, _ := rm.ParseTwoPhaseAction(userProvider) + twoPhaseAction2, _ := rm.ParseTwoPhaseAction(userProvider) + + tests := []struct { + name string + args args + want *TCCServiceProxy + wantErr assert.ErrorAssertionFunc + }{ + {"test1", args1, &TCCServiceProxy{ + TCCResource: &TCCResource{ + ResourceGroupId: `default:"DEFAULT"`, + AppName: "seata-go-mock-app-name", + TwoPhaseAction: twoPhaseAction1}}, assert.NoError, + }, + {"test2", args2, &TCCServiceProxy{ + TCCResource: &TCCResource{ + ResourceGroupId: `default:"DEFAULT"`, + AppName: "seata-go-mock-app-name", + TwoPhaseAction: twoPhaseAction2}}, assert.NoError, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewTCCServiceProxy(tt.args.service) + if !tt.wantErr(t, err, fmt.Sprintf("NewTCCServiceProxy(%v)", tt.args.service)) { + return + } + assert.Equalf(t, tt.want, got, "NewTCCServiceProxy(%v)", tt.args.service) + }) + } +} + +func TestTCCGetTransactionInfo(t1 *testing.T) { + type fields struct { + referenceName string + registerResourceOnce sync.Once + TCCResource *TCCResource + } + + userProvider := &service.UserProvider{} + twoPhaseAction1, _ := rm.ParseTwoPhaseAction(userProvider) + + tests := struct { + name string + fields fields + want tm.TransactionInfo + }{ + "test1", fields{ + referenceName: "test1", + registerResourceOnce: sync.Once{}, + TCCResource: &TCCResource{ + ResourceGroupId: "default1", + AppName: "app1", + TwoPhaseAction: twoPhaseAction1, + }, + }, + tm.TransactionInfo{Name: "TwoPhaseDemoService", TimeOut: 10000, Propagation: 0, LockRetryInternal: 0, LockRetryTimes: 0}, + } + + t1.Run(tests.name, func(t1 *testing.T) { + t := &TCCServiceProxy{ + referenceName: tests.fields.referenceName, + registerResourceOnce: sync.Once{}, + TCCResource: tests.fields.TCCResource, + } + assert.Equalf(t1, tests.want, t.GetTransactionInfo(), "GetTransactionInfo()") + }) + +} + +func GetTestTwoPhaseService() rm.TwoPhaseInterface { + return &testdata2.TestTwoPhaseService{} +} diff --git a/pkg/rm/two_phase.go b/pkg/rm/two_phase.go index ff1330477..fa57ab63b 100644 --- a/pkg/rm/two_phase.go +++ b/pkg/rm/two_phase.go @@ -132,7 +132,7 @@ func (t *TwoPhaseAction) GetActionName() string { func IsTwoPhaseAction(v interface{}) bool { m, err := ParseTwoPhaseAction(v) - return m != nil && err != nil + return m != nil && err == nil } func ParseTwoPhaseAction(v interface{}) (*TwoPhaseAction, error) { diff --git a/pkg/rm/two_phase_test.go b/pkg/rm/two_phase_test.go index 7617ec8b1..8e0bb936f 100644 --- a/pkg/rm/two_phase_test.go +++ b/pkg/rm/two_phase_test.go @@ -22,8 +22,11 @@ import ( "fmt" "testing" - "github.com/seata/seata-go/pkg/tm" "github.com/stretchr/testify/assert" + + "github.com/seata/seata-go/pkg/tm" + "github.com/seata/seata-go/sample/tcc/dubbo/client/service" + testdata2 "github.com/seata/seata-go/testdata" ) func TestParseTwoPhaseActionGetMethodName(t *testing.T) { @@ -124,7 +127,7 @@ func TestParseTwoPhaseActionGetMethodName(t *testing.T) { } type TwoPhaseDemoService1 struct { - TwoPhasePrepare func(ctx context.Context, params interface{}) (bool, error) `seataTwoPhaseAction:"prepare" seataTwoPhaseServiceName:"TwoPhaseDemoService"` + TwoPhasePrepare func(ctx context.Context, params interface{}) (bool, error) `seataTwoPhaseAction:"prepare" seataTwoPhaseServiceName:"twoPhaseDemoService"` TwoPhaseCommit func(ctx context.Context, businessActionContext *tm.BusinessActionContext) (bool, error) `seataTwoPhaseAction:"commit"` TwoPhaseRollback func(ctx context.Context, businessActionContext *tm.BusinessActionContext) (bool, error) `seataTwoPhaseAction:"rollback"` TwoPhaseDemoService func() string @@ -142,7 +145,7 @@ func NewTwoPhaseDemoService1() *TwoPhaseDemoService1 { return true, nil }, TwoPhaseDemoService: func() string { - return "TwoPhaseDemoService" + return "twoPhaseDemoService" }, } } @@ -154,7 +157,7 @@ func TestParseTwoPhaseActionExecuteMethod1(t *testing.T) { assert.Equal(t, "TwoPhasePrepare", twoPhaseService.prepareMethodName) assert.Equal(t, "TwoPhaseCommit", twoPhaseService.commitMethodName) assert.Equal(t, "TwoPhaseRollback", twoPhaseService.rollbackMethodName) - assert.Equal(t, "TwoPhaseDemoService", twoPhaseService.actionName) + assert.Equal(t, "twoPhaseDemoService", twoPhaseService.actionName) resp, err := twoPhaseService.Prepare(ctx, 11) assert.Equal(t, false, resp) @@ -168,7 +171,7 @@ func TestParseTwoPhaseActionExecuteMethod1(t *testing.T) { assert.Equal(t, true, resp) assert.Nil(t, err) - assert.Equal(t, "TwoPhaseDemoService", twoPhaseService.GetActionName()) + assert.Equal(t, "twoPhaseDemoService", twoPhaseService.GetActionName()) } type TwoPhaseDemoService2 struct { @@ -187,7 +190,7 @@ func (t *TwoPhaseDemoService2) Rollback(ctx context.Context, businessActionConte } func (t *TwoPhaseDemoService2) GetActionName() string { - return "TwoPhaseDemoService2" + return "TestTwoPhaseDemoService" } func TestParseTwoPhaseActionExecuteMethod2(t *testing.T) { @@ -206,3 +209,80 @@ func TestParseTwoPhaseActionExecuteMethod2(t *testing.T) { assert.Equal(t, false, resp) assert.Equal(t, "execute two phase rollback method, xid 1234", err.Error()) } + +func TestIsTwoPhaseAction(t *testing.T) { + + userProvider := &testdata2.TestTwoPhaseService{} + userProvider1 := service.UserProviderInstance + type args struct { + v interface{} + } + + tests := []struct { + name string + args args + want bool + }{ + {"test1", args{v: userProvider}, true}, + {"test2", args{v: userProvider1}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, IsTwoPhaseAction(tt.args.v), "IsTwoPhaseAction(%v)", tt.args.v) + }) + } + +} + +func TestParseTwoPhaseAction(t *testing.T) { + + type args struct { + v interface{} + } + + userProvider := service.UserProviderInstance + twoPhaseAction, _ := ParseTwoPhaseAction(userProvider) + args1 := args{v: userProvider} + + tests := struct { + name string + args args + want *TwoPhaseAction + wantErr assert.ErrorAssertionFunc + }{"test1", args1, twoPhaseAction, assert.NoError} + + t.Run(tests.name, func(t *testing.T) { + got, err := ParseTwoPhaseAction(tests.args.v) + if !tests.wantErr(t, err, fmt.Sprintf("ParseTwoPhaseAction(%v)", tests.args.v)) { + return + } + assert.Equalf(t, tests.want.GetTwoPhaseService(), got.GetTwoPhaseService(), "ParseTwoPhaseAction(%v)", tests.args.v) + }) + +} + +func TestParseTwoPhaseActionByInterface(t *testing.T) { + type args struct { + v interface{} + } + + userProvider := &service.UserProvider{} + twoPhaseAction, _ := ParseTwoPhaseAction(userProvider) + args1 := args{v: userProvider} + + tests := struct { + name string + args args + want *TwoPhaseAction + wantErr assert.ErrorAssertionFunc + }{"test1", args1, twoPhaseAction, assert.NoError} + + t.Run(tests.name, func(t *testing.T) { + got, err := ParseTwoPhaseActionByInterface(tests.args.v) + if !tests.wantErr(t, err, fmt.Sprintf("ParseTwoPhaseActionByInterface(%v)", tests.args.v)) { + return + } + assert.Equalf(t, tests.want, got, "ParseTwoPhaseActionByInterface(%v)", tests.args.v) + }) +} diff --git a/sample/tcc/dubbo/client/service/user_provider.go b/sample/tcc/dubbo/client/service/user_provider.go index 943a40d02..4b16e3bfd 100644 --- a/sample/tcc/dubbo/client/service/user_provider.go +++ b/sample/tcc/dubbo/client/service/user_provider.go @@ -19,12 +19,13 @@ package service import ( "context" + "fmt" "github.com/seata/seata-go/pkg/tm" ) var ( - UserProviderInstance = &UserProvider{} + UserProviderInstance = NewTwoPhaseDemoService() ) type UserProvider struct { @@ -33,3 +34,20 @@ type UserProvider struct { Rollback func(ctx context.Context, businessActionContext *tm.BusinessActionContext) (bool, error) `seataTwoPhaseAction:"rollback"` GetActionName func() string } + +func NewTwoPhaseDemoService() *UserProvider { + return &UserProvider{ + Prepare: func(ctx context.Context, params ...interface{}) (bool, error) { + return false, fmt.Errorf("execute two phase prepare method, param %v", params) + }, + Commit: func(ctx context.Context, businessActionContext *tm.BusinessActionContext) (bool, error) { + return false, fmt.Errorf("execute two phase commit method, xid %v", businessActionContext.Xid) + }, + Rollback: func(ctx context.Context, businessActionContext *tm.BusinessActionContext) (bool, error) { + return true, nil + }, + GetActionName: func() string { + return "TwoPhaseDemoService" + }, + } +} diff --git a/testdata/mock_tcc.go b/testdata/mock_tcc.go index 7aac1a18a..683d19dee 100644 --- a/testdata/mock_tcc.go +++ b/testdata/mock_tcc.go @@ -2,8 +2,6 @@ package testdata import ( "context" - - "github.com/seata/seata-go/pkg/rm" "github.com/seata/seata-go/pkg/tm" ) @@ -29,7 +27,3 @@ func (*TestTwoPhaseService) Rollback(ctx context.Context, businessActionContext func (*TestTwoPhaseService) GetActionName() string { return ActionName } - -func GetTestTwoPhaseService() rm.TwoPhaseInterface { - return &TestTwoPhaseService{} -}