From b121f38e2ecf4071bf1a91e37c57be88db4b13ea Mon Sep 17 00:00:00 2001 From: Alessandro Resta Date: Wed, 13 May 2026 13:32:14 +0300 Subject: [PATCH] Check ungraceful stop before initiating graceful shutdown --- .gitignore | 1 + workerpool.go | 20 +++++++++++------ workerpool_test.go | 55 +++++++++++++++++++++++++++++++++------------- 3 files changed, 54 insertions(+), 22 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6350e98 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.coverage diff --git a/workerpool.go b/workerpool.go index 7ef8f03..769250f 100644 --- a/workerpool.go +++ b/workerpool.go @@ -87,6 +87,8 @@ func (p *Pool[T]) worker() { case entry := <-p.entries: entry.job.Do(entry.ctx) default: + // channel is empty. since p.stop is closed, + // no more tasks can be submitted return } } @@ -115,15 +117,19 @@ func (p *Pool[T]) Submit(ctx context.Context, task T) bool { // and waits for in-flight tasks to complete before returning. // Returns an error if the ctx was cancelled before shutdown completed. func (p *Pool[T]) GracefulShutdown() error { - p.shutdownOnce.Do(func() { - close(p.stop) - }) - - p.wg.Wait() - p.cancel() - if p.ungracefulStop.Load() { return errors.New("pool was forcefully terminated before shutdown") } + + p.shutdownOnce.Do(func() { + close(p.stop) + p.wg.Wait() + p.cancel() + + // only close(p.entries) with a lock here and + // a read lock in Submit otherwise senders will panic =] + // but it's just a good to have, since p.stop is closed + // and submit already checks for that + }) return nil } diff --git a/workerpool_test.go b/workerpool_test.go index acde68b..0a06279 100644 --- a/workerpool_test.go +++ b/workerpool_test.go @@ -88,8 +88,6 @@ func TestPool_Submit(t *testing.T) { pool := New[testTask](context.TODO(), 3) var counter atomic.Int32 - - // submit tasks for i := 1; i <= 10; i++ { task := testTask{ value: i, @@ -318,23 +316,50 @@ func TestPool_GracefulShutdown(t *testing.T) { t.Run("Submit task after graceful shutdown", func(t *testing.T) { t.Parallel() - pool := New[testTask](context.TODO(), 2) + synctest.Test(t, func(t *testing.T) { + pool := New[testTask](context.TODO(), 2) - // immediately shutdown the pool - if err := pool.GracefulShutdown(); err != nil { - t.Errorf("shutdown error: %v", err) - } + var counter atomic.Int32 + blocker := make(chan struct{}) - var counter atomic.Int32 + task1 := testTask{ + value: 1, + counter: &counter, + fn: func(context.Context) { + <-blocker + }, + } - task := testTask{ - value: 1, - counter: &counter, - } + if !pool.Submit(context.TODO(), task1) { + t.Error("expected Submit to return true") + } - if pool.Submit(context.TODO(), task) { - t.Error("expected Submit to return false after shutdown") - } + // wait for worker to pick up the task and block + synctest.Wait() + + // immediately shutdown the pool + // will close(p.stop) but block on p.wg.Wait() + shutdownDone := make(chan error, 1) + go func() { + shutdownDone <- pool.GracefulShutdown() + }() + + synctest.Wait() // wait shutdown to propagate + + task2 := testTask{ + value: 1, + counter: &counter, + } + + if pool.Submit(context.TODO(), task2) { + t.Error("expected Submit to return false after shutdown") + } + close(blocker) + + if err := <-shutdownDone; err != nil { + t.Errorf("shutdown error: %v", err) + } + }) }) t.Run("Graceful shutdown waits for all queued tasks to be complete", func(t *testing.T) {