diff --git a/changelog/@unreleased/pr-142.v2.yml b/changelog/@unreleased/pr-142.v2.yml new file mode 100644 index 00000000..66bb1ff0 --- /dev/null +++ b/changelog/@unreleased/pr-142.v2.yml @@ -0,0 +1,5 @@ +type: improvement +improvement: + description: Add Drain() method to tcpjson.AsyncWriter + links: + - https://github.com/palantir/witchcraft-go-logging/pull/142 diff --git a/wlog/tcpjson/async_writer.go b/wlog/tcpjson/async_writer.go index c3dc6f13..5db7213e 100644 --- a/wlog/tcpjson/async_writer.go +++ b/wlog/tcpjson/async_writer.go @@ -15,6 +15,7 @@ package tcpjson import ( + "context" "io" "log" @@ -34,33 +35,54 @@ type asyncWriter struct { buffer chan []byte output io.Writer dropped gometrics.Counter + queued gometrics.Gauge stop chan struct{} } +type AsyncWriter interface { + io.WriteCloser + // Drain tries to gracefully drain the remaining buffered messages, + // blocking until the buffer is empty or the provided context is cancelled. + Drain(ctx context.Context) +} + // StartAsyncWriter creates a Writer whose Write method puts the submitted byte slice onto a channel. // In a separate goroutine, slices are pulled from the queue and written to the output writer. // The Close method stops the consumer goroutine and will cause future writes to fail. -func StartAsyncWriter(output io.Writer, registry metrics.Registry) io.WriteCloser { +func StartAsyncWriter(output io.Writer, registry metrics.Registry) AsyncWriter { droppedCounter := registry.Counter(asyncWriterDroppedCounter) buffer := make(chan []byte, asyncWriterBufferCapacity) stop := make(chan struct{}) + queued := registry.Gauge(asyncWriterBufferLenGauge) + w := &asyncWriter{buffer: buffer, output: output, dropped: droppedCounter, queued: queued, stop: stop} go func() { - gauge := registry.Gauge(asyncWriterBufferLenGauge) for { + // Ensure we stop when requested. Without the additional select, + // the loop could continue to run as long as there are items in the buffer. + select { + case <-stop: + return + default: + } + select { case item := <-buffer: - gauge.Update(int64(len(buffer))) - if _, err := output.Write(item); err != nil { - // TODO(bmoylan): consider re-enqueuing message so it can be attempted again, which risks a thundering herd without careful handling. - log.Printf("write failed: %s", werror.GenerateErrorString(err, false)) - droppedCounter.Inc(1) - } + w.write(item) case <-stop: return } } }() - return &asyncWriter{buffer: buffer, output: output, dropped: droppedCounter, stop: stop} + return w +} + +func (w *asyncWriter) write(item []byte) { + w.queued.Update(int64(len(w.buffer))) + if _, err := w.output.Write(item); err != nil { + // TODO(bmoylan): consider re-enqueuing message so it can be attempted again, which risks a thundering herd without careful handling. + log.Printf("write failed: %s", werror.GenerateErrorString(err, false)) + w.dropped.Inc(1) + } } func (w *asyncWriter) Write(b []byte) (int, error) { @@ -87,3 +109,17 @@ func (w *asyncWriter) Close() (err error) { close(w.stop) return nil } + +func (w *asyncWriter) Drain(ctx context.Context) { + for { + select { + case item := <-w.buffer: + w.write(item) + case <-ctx.Done(): + return + default: + // Nothing left in the buffer, time to return + return + } + } +} diff --git a/wlog/tcpjson/async_writer_test.go b/wlog/tcpjson/async_writer_test.go index b7144a1c..09f7f710 100644 --- a/wlog/tcpjson/async_writer_test.go +++ b/wlog/tcpjson/async_writer_test.go @@ -20,6 +20,7 @@ import ( "encoding/json" "errors" "strconv" + "strings" "sync/atomic" "testing" "time" @@ -123,9 +124,9 @@ func TestAsyncWriteWithSvc1log(t *testing.T) { } func TestDropsLogs(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - w := &blockingWriter{ctx: ctx} + writerCtx, unblock := context.WithCancel(context.Background()) + defer unblock() + w := &blockingWriter{ctx: writerCtx} registry := metrics.NewRootMetricsRegistry() asyncTCPWriter := StartAsyncWriter(w, registry) defer func() { @@ -137,7 +138,7 @@ func TestDropsLogs(t *testing.T) { } assert.Equal(t, asyncWriterBufferCapacity, len(asyncTCPWriter.(*asyncWriter).buffer), "expected buffer to be full") assert.Equal(t, int64(100), registry.Counter(asyncWriterDroppedCounter).Count(), "expected dropped counter to increment") - cancel() + unblock() time.Sleep(time.Second) assert.Equal(t, 0, len(asyncTCPWriter.(*asyncWriter).buffer), "expected buffer to empty") } @@ -157,13 +158,47 @@ func TestDropsLogsOnError(t *testing.T) { assert.Equal(t, int64(5), registry.Counter(asyncWriterDroppedCounter).Count(), "expected dropped counter to increment") } +func TestShutdownDrainsBuffer(t *testing.T) { + writerCtx, unblock := context.WithCancel(context.Background()) + defer unblock() + w := &blockingWriter{ctx: writerCtx} + registry := metrics.NewRootMetricsRegistry() + asyncTCPWriter := StartAsyncWriter(w, registry) + logger := svc1log.NewFromCreator(asyncTCPWriter, wlog.DebugLevel, wlog.NewJSONMarshalLoggerProvider().NewLeveledLogger) + for i := 0; i < 5; i++ { + logger.Info(strconv.Itoa(i)) + } + // Close the writer for new entries + _ = asyncTCPWriter.Close() + _, err := asyncTCPWriter.Write([]byte("too late")) + assert.EqualError(t, err, "write to closed asyncWriter") + + // At this point, we have 5 messages queued. Next we start Drain(), which drains the writer. + assert.Empty(t, w.buf.String()) + + go func() { + time.Sleep(10 * time.Millisecond) + unblock() + }() + shutdownStart := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // should be instant, this is just to catch bugs + defer cancel() + asyncTCPWriter.Drain(ctx) + // We set a 5s timeout but this _should_ run very fast, so make sure it was less than 1s. + assert.Less(t, time.Since(shutdownStart), time.Second, "expected shutdown to be fast") + + writtenLines := strings.Split(strings.TrimSpace(w.buf.String()), "\n") + assert.Len(t, writtenLines, 5) +} + type blockingWriter struct { ctx context.Context + buf bytes.Buffer } func (b *blockingWriter) Write(p []byte) (int, error) { <-b.ctx.Done() - return len(p), nil + return b.buf.Write(p) } type alwaysErrorWriter struct{}