diff --git a/internal/examples/supervisor/supervisor/supervisor_test.go b/internal/examples/supervisor/supervisor/supervisor_test.go index 4568720..fbadeea 100644 --- a/internal/examples/supervisor/supervisor/supervisor_test.go +++ b/internal/examples/supervisor/supervisor/supervisor_test.go @@ -4,7 +4,8 @@ import ( "fmt" "os" "testing" - + "time" + "github.com/stretchr/testify/assert" "github.com/open-telemetry/opamp-go/internal" @@ -62,3 +63,33 @@ agent: supervisor.Shutdown() } + +func TestShutdownRaceCondition(t *testing.T) { + tmpDir := changeCurrentDir(t) + os.WriteFile("supervisor.yaml", []byte(fmt.Sprintf(` +server: + endpoint: ws://127.0.0.1:4320/v1/opamp +agent: + executable: %s/dummy_agent.sh`, tmpDir)), 0644) + + os.WriteFile("dummy_agent.sh", []byte("#!/bin/sh\nsleep 9999\n"), 0755) + + startOpampServer(t) + + // There's no great way to ensure Shutdown gets called before Start. + // The DelayLogger ensures some delay before the goroutine gets started. + var supervisor *Supervisor + var err error + supervisor, err = NewSupervisor(&internal.DelayLogger{}) + supervisor.Shutdown() + supervisor.hasNewConfig <- struct{}{} + + assert.NoError(t, err) + + // The Shutdown method has been called before the runAgentProcess goroutine + // gets started and has a chance to load a new process. Make sure no PID + // has been launched. + assert.Never(t, func() bool { + return supervisor.commander.Pid() != 0 + }, 2*time.Second, 10*time.Millisecond) +} diff --git a/internal/noplogger.go b/internal/noplogger.go index a2b2ea2..9807e6e 100644 --- a/internal/noplogger.go +++ b/internal/noplogger.go @@ -2,6 +2,7 @@ package internal import ( "context" + "time" "github.com/open-telemetry/opamp-go/client/types" ) @@ -12,3 +13,12 @@ type NopLogger struct{} func (l *NopLogger) Debugf(ctx context.Context, format string, v ...interface{}) {} func (l *NopLogger) Errorf(ctx context.Context, format string, v ...interface{}) {} + +type DelayLogger struct{} + +func (l *DelayLogger) Debugf(ctx context.Context, format string, v ...interface{}) { + time.Sleep(10 * time.Millisecond) +} +func (l *DelayLogger) Errorf(ctx context.Context, format string, v ...interface{}) { + time.Sleep(10 * time.Millisecond) +}