From 06f860b3c92cfef6b17e1a0fd81f785004c6f1d9 Mon Sep 17 00:00:00 2001 From: Aniruddha Maru Date: Tue, 3 Aug 2021 21:45:16 -0700 Subject: [PATCH] Run sub-tests in a group so teardownsuite is called in the right order Define new interface `Copy` to create copies of suite object for parallel subtests --- suite/doc.go | 9 ++- suite/interfaces.go | 11 ++- suite/suite.go | 173 +++++++++++++++++++++++--------------------- suite/suite_test.go | 133 ++++++++++++++++++++++++++++++++-- 4 files changed, 234 insertions(+), 92 deletions(-) diff --git a/suite/doc.go b/suite/doc.go index f91a245d3..433c3adc7 100644 --- a/suite/doc.go +++ b/suite/doc.go @@ -6,10 +6,7 @@ // implement). // // A testing suite is usually built by first extending the built-in -// suite functionality from suite.Suite in testify. Alternatively, -// you could reproduce that logic on your own if you wanted (you -// just need to implement the TestingSuite interface from -// suite/interfaces.go). +// suite functionality from suite.Suite in testify. // // After that, you can implement any of the interfaces in // suite/interfaces.go to add setup/teardown functionality to your @@ -23,6 +20,10 @@ // identity that "go test" is already looking for (i.e. // func(*testing.T)). // +// To be able to run parallel sub-tests, your testing suite should +// implement "CopySuite". This may or may not be a deepcopy depending +// on the fields in the struct. +// // Regular expression to select test suites specified command-line // argument "-run". Regular expression to select the methods // of test suites specified command-line argument "-m". diff --git a/suite/interfaces.go b/suite/interfaces.go index 8b98a8af2..bec7a47cf 100644 --- a/suite/interfaces.go +++ b/suite/interfaces.go @@ -6,7 +6,16 @@ import "testing" // generated by 'go test'. type TestingSuite interface { T() *testing.T - SetT(*testing.T) + setT(*testing.T) + clearT() +} + +// CopySuite indicates a copyable struct, deepcopy vs shallow is +// implementation detail of the application. +type CopySuite interface { + // Copy creates a copy of the calling suite object. The returned + // object must be the same concrete type as caller + Copy() TestingSuite } // SetupAllSuite has a SetupSuite method, which will run before the diff --git a/suite/suite.go b/suite/suite.go index b9b5d1c56..b3a61ae6e 100644 --- a/suite/suite.go +++ b/suite/suite.go @@ -30,13 +30,22 @@ func (suite *Suite) T() *testing.T { return suite.t } -// SetT sets the current *testing.T context. -func (suite *Suite) SetT(t *testing.T) { +// setT sets the current *testing.T context. +func (suite *Suite) setT(t *testing.T) { + if suite.t != nil { + panic("suite.t already set, can't overwrite") + } suite.t = t suite.Assertions = assert.New(t) suite.require = require.New(t) } +func (suite *Suite) clearT() { + suite.t = nil + suite.Assertions = nil + suite.require = nil +} + // Require returns a require context for suite. func (suite *Suite) Require() *require.Assertions { if suite.require == nil { @@ -69,11 +78,16 @@ func failOnPanic(t *testing.T) { // called in place of t.Run(name, func(t *testing.T)) in test suite code. // The passed-in func will be executed as a subtest with a fresh instance of t. // Provides compatibility with go test pkg -run TestSuite/TestName/SubTestName. +// Deprecated: This method doesn't handle parallel sub-tests and will be removed in v2. func (suite *Suite) Run(name string, subtest func()) bool { oldT := suite.T() - defer suite.SetT(oldT) + defer func() { + suite.clearT() + suite.setT(oldT) + }() return oldT.Run(name, func(t *testing.T) { - suite.SetT(t) + suite.clearT() + suite.setT(t) subtest() }) } @@ -83,8 +97,6 @@ func (suite *Suite) Run(name string, subtest func()) bool { func Run(t *testing.T, suite TestingSuite) { defer failOnPanic(t) - suite.SetT(t) - var suiteSetupDone bool var stats *SuiteInformation @@ -96,84 +108,95 @@ func Run(t *testing.T, suite TestingSuite) { methodFinder := reflect.TypeOf(suite) suiteName := methodFinder.Elem().Name() - for i := 0; i < methodFinder.NumMethod(); i++ { - method := methodFinder.Method(i) + t.Run("All", func(t *testing.T) { + defer failOnPanic(t) - ok, err := methodFilter(method.Name) - if err != nil { - fmt.Fprintf(os.Stderr, "testify: invalid regexp for -m: %s\n", err) - os.Exit(1) - } + suite.setT(t) - if !ok { - continue - } + for i := 0; i < methodFinder.NumMethod(); i++ { + method := methodFinder.Method(i) - if !suiteSetupDone { - if stats != nil { - stats.Start = time.Now() + ok, err := methodFilter(method.Name) + if err != nil { + fmt.Fprintf(os.Stderr, "testify: invalid regexp for -m: %s\n", err) + os.Exit(1) } - if setupAllSuite, ok := suite.(SetupAllSuite); ok { - setupAllSuite.SetupSuite() + if !ok { + continue } - suiteSetupDone = true - } + if !suiteSetupDone { + if stats != nil { + stats.Start = time.Now() + } - test := testing.InternalTest{ - Name: method.Name, - F: func(t *testing.T) { - parentT := suite.T() - suite.SetT(t) - defer failOnPanic(t) - defer func() { - if stats != nil { - passed := !t.Failed() - stats.end(method.Name, passed) - } + if setupAllSuite, ok := suite.(SetupAllSuite); ok { + setupAllSuite.SetupSuite() + } - if afterTestSuite, ok := suite.(AfterTest); ok { - afterTestSuite.AfterTest(suiteName, method.Name) - } + suiteSetupDone = true + } - if tearDownTestSuite, ok := suite.(TearDownTestSuite); ok { - tearDownTestSuite.TearDownTest() + test := testing.InternalTest{ + Name: method.Name, + F: func(t *testing.T) { + defer failOnPanic(t) + childSuite := suite + if c, ok := suite.(CopySuite); ok { + childSuite = c.Copy() + childSuite.clearT() + } + childSuite.setT(t) + defer func() { + if _, ok := suite.(CopySuite); !ok { + defer suite.clearT() + } + if stats != nil { + passed := !t.Failed() + stats.end(method.Name, passed) + } + + if tearDownTestSuite, ok := childSuite.(TearDownTestSuite); ok { + tearDownTestSuite.TearDownTest() + } + + if afterTestSuite, ok := childSuite.(AfterTest); ok { + afterTestSuite.AfterTest(suiteName, method.Name) + } + }() + + if beforeTestSuite, ok := childSuite.(BeforeTest); ok { + beforeTestSuite.BeforeTest(methodFinder.Elem().Name(), method.Name) + } + if setupTestSuite, ok := childSuite.(SetupTestSuite); ok { + setupTestSuite.SetupTest() } - suite.SetT(parentT) - }() - - if setupTestSuite, ok := suite.(SetupTestSuite); ok { - setupTestSuite.SetupTest() - } - if beforeTestSuite, ok := suite.(BeforeTest); ok { - beforeTestSuite.BeforeTest(methodFinder.Elem().Name(), method.Name) - } - - if stats != nil { - stats.start(method.Name) - } + if stats != nil { + stats.start(method.Name) + } - method.Func.Call([]reflect.Value{reflect.ValueOf(suite)}) - }, + method.Func.Call([]reflect.Value{reflect.ValueOf(childSuite)}) + }, + } + tests = append(tests, test) } - tests = append(tests, test) - } + + suite.clearT() + runTests(t, tests) + }) if suiteSetupDone { - defer func() { - if tearDownAllSuite, ok := suite.(TearDownAllSuite); ok { - tearDownAllSuite.TearDownSuite() - } + suite.setT(t) + if tearDownAllSuite, ok := suite.(TearDownAllSuite); ok { + tearDownAllSuite.TearDownSuite() + } - if suiteWithStats, measureStats := suite.(WithStats); measureStats { - stats.End = time.Now() - suiteWithStats.HandleStats(suiteName, stats) - } - }() + if suiteWithStats, measureStats := suite.(WithStats); measureStats { + stats.End = time.Now() + suiteWithStats.HandleStats(suiteName, stats) + } } - - runTests(t, tests) } // Filtering method according to set regular expression @@ -185,25 +208,13 @@ func methodFilter(name string) (bool, error) { return regexp.MatchString(*matchMethod, name) } -func runTests(t testing.TB, tests []testing.InternalTest) { +func runTests(t *testing.T, tests []testing.InternalTest) { if len(tests) == 0 { t.Log("warning: no tests to run") return } - r, ok := t.(runner) - if !ok { // backwards compatibility with Go 1.6 and below - if !testing.RunTests(allTestsFilter, tests) { - t.Fail() - } - return - } - for _, test := range tests { - r.Run(test.Name, test.F) + t.Run(test.Name, test.F) } } - -type runner interface { - Run(name string, f func(t *testing.T)) bool -} diff --git a/suite/suite_test.go b/suite/suite_test.go index 963a25258..8f584de4c 100644 --- a/suite/suite_test.go +++ b/suite/suite_test.go @@ -4,11 +4,13 @@ import ( "bytes" "errors" "flag" + "fmt" "io/ioutil" "math/rand" "os" "os/exec" "strings" + "sync" "testing" "time" @@ -168,40 +170,64 @@ type SuiteTester struct { TimeBefore []time.Time TimeAfter []time.Time + + SuiteT *testing.T + SuiteGroupT *testing.T + TestT map[string]*testing.T } // The SetupSuite method will be run by testify once, at the very // start of the testing suite, before any tests are run. func (suite *SuiteTester) SetupSuite() { suite.SetupSuiteRunCount++ + suite.SuiteGroupT = suite.T() } func (suite *SuiteTester) BeforeTest(suiteName, testName string) { suite.SuiteNameBefore = append(suite.SuiteNameBefore, suiteName) suite.TestNameBefore = append(suite.TestNameBefore, testName) suite.TimeBefore = append(suite.TimeBefore, time.Now()) + suite.TestT[testName] = suite.T() } func (suite *SuiteTester) AfterTest(suiteName, testName string) { suite.SuiteNameAfter = append(suite.SuiteNameAfter, suiteName) suite.TestNameAfter = append(suite.TestNameAfter, testName) suite.TimeAfter = append(suite.TimeAfter, time.Now()) + // T should be from the sub-test + assert.True(suite.T(), suite.TestT[testName] == suite.T()) } // The TearDownSuite method will be run by testify once, at the very // end of the testing suite, after all tests have been run. func (suite *SuiteTester) TearDownSuite() { suite.TearDownSuiteRunCount++ + // T should be from the suite + assert.True(suite.T(), suite.SuiteT == suite.T()) } // The SetupTest method will be run before every test in the suite. func (suite *SuiteTester) SetupTest() { suite.SetupTestRunCount++ + var subSuites []*testing.T + for _, value := range suite.TestT { + subSuites = append(subSuites, value) + } + // T should be from one of the tests, not suite + assert.Contains(suite.T(), subSuites, suite.T()) + assert.False(suite.T(), suite.SuiteT == suite.T() || suite.SuiteGroupT == suite.T()) } // The TearDownTest method will be run after every test in the suite. func (suite *SuiteTester) TearDownTest() { suite.TearDownTestRunCount++ + var subSuites []*testing.T + for _, value := range suite.TestT { + subSuites = append(subSuites, value) + } + // T should be from one of the tests, not suite + assert.Contains(suite.T(), subSuites, suite.T()) + assert.False(suite.T(), suite.SuiteT == suite.T() || suite.SuiteGroupT == suite.T()) } // Every method in a testing suite that begins with "Test" will be run @@ -213,6 +239,8 @@ func (suite *SuiteTester) TestOne() { suite.TestOneRunCount++ assert.Equal(suite.T(), suite.TestOneRunCount, beforeCount+1) suite.Equal(suite.TestOneRunCount, beforeCount+1) + // T should be from the right test + assert.True(suite.T(), suite.T() == suite.TestT["TestOne"]) } // TestTwo is another example of a test. @@ -221,10 +249,14 @@ func (suite *SuiteTester) TestTwo() { suite.TestTwoRunCount++ assert.NotEqual(suite.T(), suite.TestTwoRunCount, beforeCount) suite.NotEqual(suite.TestTwoRunCount, beforeCount) + // T should be from the right test + assert.True(suite.T(), suite.T() == suite.TestT["TestTwo"]) } func (suite *SuiteTester) TestSkip() { suite.T().Skip() + // T should be from the right test + assert.True(suite.T(), suite.T() == suite.TestT["TestSkip"]) } // NonTestMethod does not begin with "Test", so it will not be run by @@ -237,6 +269,9 @@ func (suite *SuiteTester) NonTestMethod() { func (suite *SuiteTester) TestSubtest() { suite.TestSubtestRunCount++ + // T should be from the right test + assert.True(suite.T(), suite.T() == suite.TestT["TestSubtest"]) + for _, t := range []struct { testName string }{ @@ -282,7 +317,10 @@ func (suite *SuiteSkipTester) TearDownSuite() { // TestRunSuite will be run by the 'go test' command, so within it, we // can run our suite using the Run(*testing.T, TestingSuite) function. func TestRunSuite(t *testing.T) { - suiteTester := new(SuiteTester) + suiteTester := &SuiteTester{ + TestT: make(map[string]*testing.T), + SuiteT: t, + } Run(t, suiteTester) // Normally, the test would end here. The following are simply @@ -343,11 +381,10 @@ func TestRunSuite(t *testing.T) { suiteSkipTester := new(SuiteSkipTester) Run(t, suiteSkipTester) - // The suite was only run once, so the SetupSuite and TearDownSuite - // methods should have each been run only once, even though SetupSuite - // called Skip() + // The suite was only run once, so SetupSuite method should have been + // run only once. Since SetupSuite called Skip(), Teardown isn't called. assert.Equal(t, suiteSkipTester.SetupSuiteRunCount, 1) - assert.Equal(t, suiteSkipTester.TearDownSuiteRunCount, 1) + assert.Equal(t, suiteSkipTester.TearDownSuiteRunCount, 0) } @@ -380,7 +417,7 @@ func TestSkippingSuiteSetup(t *testing.T) { func TestSuiteGetters(t *testing.T) { suite := new(SuiteTester) - suite.SetT(t) + suite.setT(t) assert.NotNil(t, suite.Assert()) assert.Equal(t, suite.Assertions, suite.Assert()) assert.NotNil(t, suite.Require()) @@ -470,6 +507,7 @@ func (s *CallOrderSuite) TearDownSuite() { s.call("TearDownSuite") assert.Equal(s.T(), "SetupSuite;SetupTest;Test A;TearDownTest;SetupTest;Test B;TearDownTest;TearDownSuite", strings.Join(s.callOrder, ";")) } + func (s *CallOrderSuite) SetupTest() { s.call("SetupTest") } @@ -587,3 +625,86 @@ func (s *FailfastSuite) Test_B_Passes() { s.call("Test B Passes") s.Require().True(true) } + +type parallelSuiteData struct { + calls []string + callsIndex map[string]int + parallelSuiteT map[string]*parallelSuite +} + +type parallelSuite struct { + Suite + mutex *sync.Mutex + data *parallelSuiteData +} + +func (s *parallelSuite) call(method string) { + time.Sleep(time.Duration(rand.Intn(300)) * time.Millisecond) + s.mutex.Lock() + defer s.mutex.Unlock() + s.data.calls = append(s.data.calls, method) + s.data.callsIndex[method] = len(s.data.calls) - 1 +} + +func (s *parallelSuite) Copy() TestingSuite { + // shallow copy since data is protected by mutex anyway + c := *s + return &c +} + +func TestSuiteParallel(t *testing.T) { + data := parallelSuiteData{ + calls: []string{}, + callsIndex: make(map[string]int, 8), + parallelSuiteT: make(map[string]*parallelSuite, 2), + } + s := ¶llelSuite{mutex: &sync.Mutex{}, data: &data} + Run(t, s) +} + +func (s *parallelSuite) SetupSuite() { + s.call("SetupSuite") +} + +func (s *parallelSuite) TearDownSuite() { + s.call("TearDownSuite") + s.mutex.Lock() + defer s.mutex.Unlock() + // first 3 calls and last call is known ordering + assert.Equal(s.T(), []string{"SetupSuite", "BeforeTest Test_A", "BeforeTest Test_B"}, s.data.calls[:3]) + assert.Equal(s.T(), "TearDownSuite", s.data.calls[len(s.data.calls)-1]) + // should have these calls + assert.Subset(s.T(), s.data.calls, []string{"Test_A", "AfterTest Test_A", "Test_B", "AfterTest Test_B"}) + // there won't be any other ordering guarantees between tests A and B since they are run in parallel, + // but verify that AfterTest is run after the test + assert.Greater(s.T(), s.data.callsIndex["AfterTest Test_A"], s.data.callsIndex["Test_A"]) + assert.Greater(s.T(), s.data.callsIndex["AfterTest Test_B"], s.data.callsIndex["Test_B"]) + // verify that copies of s are created correctly + assert.NotEqual(s.T(), s, s.data.parallelSuiteT["Test_A"]) + assert.NotEqual(s.T(), s, s.data.parallelSuiteT["Test_B"]) + assert.NotEqual(s.T(), s.data.parallelSuiteT["Test_A"], s.data.parallelSuiteT["Test_B"]) +} + +func (s *parallelSuite) BeforeTest(suiteName, testName string) { + s.call(fmt.Sprintf("BeforeTest %s", testName)) +} + +func (s *parallelSuite) AfterTest(suiteName, testName string) { + s.call(fmt.Sprintf("AfterTest %s", testName)) +} + +func (s *parallelSuite) Test_A() { + s.T().Parallel() + s.call("Test_A") + s.mutex.Lock() + defer s.mutex.Unlock() + s.data.parallelSuiteT["Test_A"] = s +} + +func (s *parallelSuite) Test_B() { + s.T().Parallel() + s.call("Test_B") + s.mutex.Lock() + defer s.mutex.Unlock() + s.data.parallelSuiteT["Test_B"] = s +}