From 56f3a0d02e0d9c6957fa795b3b37d9c87cdbe172 Mon Sep 17 00:00:00 2001 From: Alessandro Resta Date: Wed, 13 May 2026 01:17:36 +0300 Subject: [PATCH 1/3] Propagate task context --- workerpool.go | 43 +++++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/workerpool.go b/workerpool.go index c97dd57..7ef8f03 100644 --- a/workerpool.go +++ b/workerpool.go @@ -12,6 +12,11 @@ type Task interface { Do(context.Context) } +type entry[T Task] struct { + ctx context.Context + job T +} + // Option configures a Pool. type Option[T Task] func(*Pool[T]) @@ -24,9 +29,9 @@ func WithBuffer[T Task](size int) Option[T] { // Pool maintains fixed worker goroutines processing tasks from a channel. type Pool[T Task] struct { - tasks chan T // channel for tasks waiting to be processed - buffer int // size of the task channel - wg sync.WaitGroup // wait group for worker goroutines + entries chan entry[T] // channel for jobs waiting to be processed + buffer int // size of the task channel + wg sync.WaitGroup // wait group for worker goroutines // immediate termination ctx context.Context @@ -38,18 +43,18 @@ type Pool[T Task] struct { shutdownOnce sync.Once } -// New creates a pool with numOfWorkers workers. +// New creates a pool with number of available workers. // The context can be used to stop the pool immediately, skipping any buffered // tasks. In-flight tasks will still run to completion. -func New[T Task](ctx context.Context, numOfWorkers int, opts ...Option[T]) *Pool[T] { - if numOfWorkers <= 0 { - numOfWorkers = 1 +func New[T Task](ctx context.Context, workers int, opts ...Option[T]) *Pool[T] { + if workers <= 0 { + workers = 1 } - ctx, cancel := context.WithCancel(ctx) + poolCtx, cancel := context.WithCancel(ctx) p := &Pool[T]{ - ctx: ctx, + ctx: poolCtx, cancel: cancel, stop: make(chan struct{}), } @@ -58,10 +63,10 @@ func New[T Task](ctx context.Context, numOfWorkers int, opts ...Option[T]) *Pool opt(p) } - p.tasks = make(chan T, p.buffer) + p.entries = make(chan entry[T], p.buffer) - p.wg.Add(numOfWorkers) - for range numOfWorkers { + p.wg.Add(workers) + for range workers { go p.worker() } return p @@ -79,27 +84,29 @@ func (p *Pool[T]) worker() { // drain remaining buffered tasks before exiting for { select { - case task := <-p.tasks: - task.Do(p.ctx) + case entry := <-p.entries: + entry.job.Do(entry.ctx) default: return } } - case task := <-p.tasks: - task.Do(p.ctx) + case entry := <-p.entries: + entry.job.Do(entry.ctx) } } } // Submit sends a task to the pool. Blocks if the task channel is full. // Returns false if the pool is shutting down or the context was cancelled. -func (p *Pool[T]) Submit(task T) bool { +func (p *Pool[T]) Submit(ctx context.Context, task T) bool { select { + case <-ctx.Done(): + return false case <-p.ctx.Done(): // forcefully terminate via ctx return false case <-p.stop: // terminated via graceful shutdown return false - case p.tasks <- task: + case p.entries <- entry[T]{ctx: ctx, job: task}: return true } } From 9fac374a80556eca9825b56e43817aaf4cf88bfd Mon Sep 17 00:00:00 2001 From: Alessandro Resta Date: Wed, 13 May 2026 01:18:41 +0300 Subject: [PATCH 2/3] Unit test task ctx propagation and organize tests --- workerpool_test.go | 534 +++++++++++++++++++++++++-------------------- 1 file changed, 294 insertions(+), 240 deletions(-) diff --git a/workerpool_test.go b/workerpool_test.go index 87ca323..acde68b 100644 --- a/workerpool_test.go +++ b/workerpool_test.go @@ -11,371 +11,425 @@ import ( type testTask struct { value int counter *atomic.Int32 - fn func() + fn func(context.Context) } func (t testTask) Do(ctx context.Context) { if t.fn != nil { - t.fn() + t.fn(ctx) } t.counter.Add(int32(t.value)) } -func TestPool_Submit(t *testing.T) { +func TestNew(t *testing.T) { t.Parallel() - pool := New[testTask](context.TODO(), 3) + t.Run("defaults to a single worker when size is zero", func(t *testing.T) { + t.Parallel() + + pool := New[testTask](context.TODO(), 0) // should default to 1 worker + + var counter atomic.Int32 + + task := testTask{ + value: 42, + counter: &counter, + } + + if !pool.Submit(context.TODO(), task) { + t.Error("failed to submit task") + } + + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("unexpected shutdown error: %v", err) + } + + if got := counter.Load(); got != 42 { + t.Errorf("expected counter = 42, got %d", got) + } + }) +} + +func TestPool_WithBuffer(t *testing.T) { + t.Parallel() + + pool := New(context.TODO(), 3, WithBuffer[testTask](10)) var counter atomic.Int32 - // submit tasks - for i := 1; i <= 10; i++ { + // submit 20 tasks + for i := 1; i <= 20; i++ { task := testTask{ value: i, counter: &counter, } - if !pool.Submit(task) { + if !pool.Submit(context.TODO(), task) { t.Error("failed to submit task") } } if err := pool.GracefulShutdown(); err != nil { - t.Errorf("shutdown error: %v", err) + t.Errorf("unexpected shutdown error: %v", err) } - // sum of 1..10 = 55 - if got := counter.Load(); got != 55 { - t.Errorf("expected counter = 55, got %d", got) + // sum of 1..20 = 210 + if got := counter.Load(); got != 210 { + t.Errorf("expected counter = 210, got %d", got) } } -func TestPool_GracefulShutdown(t *testing.T) { +func TestPool_Submit(t *testing.T) { t.Parallel() - t.Run("Submit task after graceful shutdown", func(t *testing.T) { + t.Run("queues and executes tasks", func(t *testing.T) { t.Parallel() - pool := New[testTask](context.TODO(), 2) - - // immediately shutdown the pool - if err := pool.GracefulShutdown(); err != nil { - t.Errorf("shutdown error: %v", err) - } + pool := New[testTask](context.TODO(), 3) var counter atomic.Int32 - task := testTask{ - value: 1, - counter: &counter, + // submit tasks + for i := 1; i <= 10; i++ { + task := testTask{ + value: i, + counter: &counter, + } + + if !pool.Submit(context.TODO(), task) { + t.Error("failed to submit task") + } } - if pool.Submit(task) { - t.Error("expected Submit to return false after shutdown") + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("shutdown error: %v", err) + } + + // sum of 1..10 = 55 + if got := counter.Load(); got != 55 { + t.Errorf("expected counter = 55, got %d", got) } }) - t.Run("Graceful shutdown waits for all queued tasks to be complete", func(t *testing.T) { + t.Run("second submit blocks when the worker is busy”", func(t *testing.T) { t.Parallel() - // use a single worker and a buffered channel - // - // block the worker with a task that won't complete until - // we say so, then fill the buffer with additional tasks, - // call GracefulShutdown while the worker is still blocked, - // then unblock it and verify that GracefulShutdown only returns - // after all buffered tasks have been processed (not just the in-flight one) + synctest.Test(t, func(t *testing.T) { + pool := New[testTask](t.Context(), 1) - const buffered, workers = 5, 1 + counter := &atomic.Int32{} + blocker := make(chan struct{}) - pool := New(context.TODO(), workers, WithBuffer[testTask](buffered)) + go func() { + pool.Submit(context.TODO(), testTask{ + fn: func(_ context.Context) { + <-blocker + }, + value: 1, + counter: counter, + }) + }() - var counter atomic.Int32 - blocker := make(chan struct{}) - started := make(chan struct{}) - - // submit a task that signals when it starts and then blocks, - // keeping the worker occupied while we fill the buffer - pool.Submit(testTask{ - fn: func() { - close(started) - <-blocker - }, - value: 1, - counter: &counter, - }) + synctest.Wait() - <-started // worker is now blocked + // second Submit must block since channel is unbuffered and worker is busy + task2Submitted := make(chan struct{}) + go func() { + pool.Submit(context.TODO(), testTask{ + value: 2, + counter: counter, + }) + close(task2Submitted) + }() + + // task2 goroutine is durably blocked on the channel send + synctest.Wait() - // fill the buffer while the worker is blocked - for range buffered { - pool.Submit(testTask{ - value: 1, counter: &counter, - }) - } + select { + case <-task2Submitted: + t.Error("submit should be blocked on unbuffered channel") + default: + // task2Submitted is not yet closed. Submit is still blocked + } - // call GracefulShutdown before unblocking - // it must not return until the buffer is fully drained - shutdownDone := make(chan error, 1) - go func() { - shutdownDone <- pool.GracefulShutdown() - }() + close(blocker) // release the worker - close(blocker) // release the worker to start draining + synctest.Wait() // worker processes both tasks and task2 goroutine exits - if err := <-shutdownDone; err != nil { - t.Errorf("unexpected shutdown error: %v", err) - } + select { + case <-task2Submitted: + // task2Submitted is closed and Submit completed after worker was unblocked, as expected + default: + t.Error("Submit should have completed after worker was unblocked") + } - // buffered+1 (started+buffered tasks) - if got := counter.Load(); got != buffered+1 { - t.Errorf("expected counter = %d, got %d", buffered+1, got) - } + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("unexpected shutdown error: %v", err) + } + + if got := counter.Load(); got != 3 { + t.Errorf("expected counter = 3, got %d", got) + } + }) }) - t.Run("Context cancellation terminates pool without draining buffered tasks", func(t *testing.T) { + t.Run("task observes its own context cancellation via Do", func(t *testing.T) { t.Parallel() synctest.Test(t, func(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) + taskCtx, cancelTask := context.WithCancel(t.Context()) - pool := New[testTask](ctx, 1) + pool := New[testTask](t.Context(), 1) var counter atomic.Int32 - blocker := make(chan struct{}) - started := make(chan struct{}) + taskDone := make(chan struct{}) - pool.Submit(testTask{ - fn: func() { - close(started) - <-blocker + pool.Submit(taskCtx, testTask{ + fn: func(ctx context.Context) { + <-ctx.Done() // block + close(taskDone) }, value: 1, counter: &counter, }) - <-started // worker is blocked - cancel() // cancel the caller's context - close(blocker) // unblock the worker + synctest.Wait() // worker is now blocked + cancelTask() + // wait until the <-ctx.Done() is complete (so it can close taskDone) synctest.Wait() - if err := pool.GracefulShutdown(); err == nil { - t.Error("expected GracefulShutdown to return an error after context cancellation") + select { + case <-taskDone: + // task observed its ctx cancellation + default: + t.Error("task should have observed context cancellation") + } + + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("unexpected shutdown error: %v", err) } }) }) -} - -func TestPool_Concurrency(t *testing.T) { - const workers, tasks = 5, 100 - pool := New[testTask](context.TODO(), workers) + t.Run("handles concurrent submissions", func(t *testing.T) { + t.Parallel() - var ( - counter atomic.Int32 - wg sync.WaitGroup - ) + const workers, tasks = 5, 100 - wg.Add(tasks) + pool := New[testTask](context.TODO(), workers) - for i := range tasks { - go func(val int) { - defer wg.Done() + var ( + counter atomic.Int32 + wg sync.WaitGroup + ) - task := testTask{ - value: 1, - counter: &counter, - } - pool.Submit(task) - }(i) - } + wg.Add(tasks) - wg.Wait() + for i := range tasks { + go func(val int) { + defer wg.Done() - if err := pool.GracefulShutdown(); err != nil { - t.Errorf("unexpected shutdown error: %v", err) - } - - if got := counter.Load(); got != tasks { - t.Errorf("expected counter = %d, got %d", tasks, got) - } -} + task := testTask{ + value: 1, + counter: &counter, + } + pool.Submit(context.TODO(), task) + }(i) + } -func TestPool_ZeroSize(t *testing.T) { - pool := New[testTask](context.TODO(), 0) // should default to 1 worker + wg.Wait() - var counter atomic.Int32 + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("unexpected shutdown error: %v", err) + } - task := testTask{ - value: 42, - counter: &counter, - } + if got := counter.Load(); got != tasks { + t.Errorf("expected counter = %d, got %d", tasks, got) + } + }) - if !pool.Submit(task) { - t.Error("failed to submit task") - } + t.Run("blocks caller until capacity is available", func(t *testing.T) { + t.Parallel() - if err := pool.GracefulShutdown(); err != nil { - t.Errorf("unexpected shutdown error: %v", err) - } + synctest.Test(t, func(t *testing.T) { + pool := New[testTask](context.TODO(), 1) - if got := counter.Load(); got != 42 { - t.Errorf("expected counter = 42, got %d", got) - } -} + var counter atomic.Int32 + blocker := make(chan struct{}) -func TestPool_Backpressure(t *testing.T) { - t.Parallel() + // submit blocking task + task1 := testTask{ + fn: func(_ context.Context) { + <-blocker // block here + }, + value: 1, + counter: &counter, + } - synctest.Test(t, func(t *testing.T) { - pool := New[testTask](context.TODO(), 1) + go pool.Submit(context.TODO(), task1) - var counter atomic.Int32 - blocker := make(chan struct{}) + synctest.Wait() // let worker pick up task1 - // submit blocking task - task1 := testTask{ - fn: func() { - <-blocker // block here - }, - value: 1, - counter: &counter, - } + // try to submit second task should block since channel is unbuffered + task2Submitted := make(chan struct{}) - go pool.Submit(task1) + go func() { + pool.Submit(context.TODO(), testTask{ + value: 2, + counter: &counter, + }) + close(task2Submitted) + }() - synctest.Wait() // let worker pick up task1 + synctest.Wait() // worker still blocked - // try to submit second task should block since channel is unbuffered - task2Submitted := make(chan struct{}) + select { + case <-task2Submitted: + t.Error("Submit should be blocked while worker is busy") + default: + // expected + } - go func() { - pool.Submit(testTask{ - value: 2, - counter: &counter, - }) - close(task2Submitted) - }() - - // worker still blocked on <-blocker - // task2 goroutine durably blocked on channel send - synctest.Wait() - select { - case <-task2Submitted: - t.Error("Submit should be blocked while worker is busy") - default: - // expected - } + close(blocker) // release the worker + synctest.Wait() // worker processes both tasks; task2 goroutine exits - close(blocker) // release the worker - synctest.Wait() // worker processes both tasks; task2 goroutine exits - select { - case <-task2Submitted: - // expected: Submit completed - default: - t.Error("Submit should have completed after worker was unblocked") - } + select { + case <-task2Submitted: + // expected: Submit completed + default: + t.Error("Submit should have completed after worker was unblocked") + } - if err := pool.GracefulShutdown(); err != nil { - t.Errorf("unexpected shutdown error: %v", err) - } + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("unexpected shutdown error: %v", err) + } - if got := counter.Load(); got != 3 { - t.Errorf("expected counter = 3, got %d", got) - } + if got := counter.Load(); got != 3 { + t.Errorf("expected counter = 3, got %d", got) + } + }) }) } -func TestPool_WithBuffer(t *testing.T) { +func TestPool_GracefulShutdown(t *testing.T) { t.Parallel() - pool := New(context.TODO(), 3, WithBuffer[testTask](10)) + t.Run("Submit task after graceful shutdown", func(t *testing.T) { + t.Parallel() - var counter atomic.Int32 + pool := New[testTask](context.TODO(), 2) + + // immediately shutdown the pool + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("shutdown error: %v", err) + } + + var counter atomic.Int32 - // submit 20 tasks - for i := 1; i <= 20; i++ { task := testTask{ - value: i, + value: 1, counter: &counter, } - if !pool.Submit(task) { - t.Error("failed to submit task") + if pool.Submit(context.TODO(), task) { + t.Error("expected Submit to return false after shutdown") } - } + }) - if err := pool.GracefulShutdown(); err != nil { - t.Errorf("unexpected shutdown error: %v", err) - } + t.Run("Graceful shutdown waits for all queued tasks to be complete", func(t *testing.T) { + t.Parallel() - // sum of 1..20 = 210 - if got := counter.Load(); got != 210 { - t.Errorf("expected counter = 210, got %d", got) - } -} + // use a single worker and a buffered channel + // + // block the worker with a task that won't complete until + // we say so, then fill the buffer with additional tasks, + // call GracefulShutdown while the worker is still blocked, + // then unblock it and verify that GracefulShutdown only returns + // after all buffered tasks have been processed (not just the in-flight one) -func TestPool_UnbufferedDefault(t *testing.T) { - t.Parallel() + synctest.Test(t, func(t *testing.T) { + const buffered, workers = 5, 1 - synctest.Test(t, func(t *testing.T) { - pool := New[testTask](t.Context(), 1) + pool := New(context.TODO(), workers, WithBuffer[testTask](buffered)) - counter := &atomic.Int32{} - blocker := make(chan struct{}) - started := make(chan struct{}) + var counter atomic.Int32 + blocker := make(chan struct{}) - go func() { - pool.Submit(testTask{ - fn: func() { - close(started) + pool.Submit(context.TODO(), testTask{ + fn: func(_ context.Context) { <-blocker }, value: 1, - counter: counter, + counter: &counter, }) - }() - <-started // worker has picked up task1 and is blocked + synctest.Wait() - // second Submit must block since channel is unbuffered and worker is busy - task2Submitted := make(chan struct{}) - go func() { - pool.Submit(testTask{ - value: 2, - counter: counter, - }) - close(task2Submitted) - }() + // fill the buffer while the worker is blocked + for range buffered { + pool.Submit(context.TODO(), testTask{ + value: 1, counter: &counter, + }) + } - // task2 goroutine is durably blocked on the channel send - synctest.Wait() + // call GracefulShutdown before unblocking + // it must not return until the buffer is fully drained - select { - case <-task2Submitted: - t.Error("submit should be blocked on unbuffered channel") - default: - // task2Submitted is not yet closed. Submit is still blocked - } + shutdownDone := make(chan error, 1) + go func() { + shutdownDone <- pool.GracefulShutdown() + }() - close(blocker) // release the worker + close(blocker) // release the worker to start draining - synctest.Wait() // worker processes both tasks and task2 goroutine exits + if err := <-shutdownDone; err != nil { + t.Errorf("unexpected shutdown error: %v", err) + } - select { - case <-task2Submitted: - // task2Submitted is closed and Submit completed after worker was unblocked, as expected - default: - t.Error("Submit should have completed after worker was unblocked") - } + // buffered+1 (started+buffered) + if got := counter.Load(); got != buffered+1 { + t.Errorf("expected counter = %d, got %d", buffered+1, got) + } + }) + }) - if err := pool.GracefulShutdown(); err != nil { - t.Errorf("unexpected shutdown error: %v", err) - } + t.Run("task observes context cancellation", func(t *testing.T) { + t.Parallel() - if got := counter.Load(); got != 3 { - t.Errorf("expected counter = 3, got %d", got) - } + synctest.Test(t, func(t *testing.T) { + workers := 1 + pool := New[testTask](t.Context(), workers) + + var counter atomic.Int32 + + observed := make(chan struct{}) + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + pool.Submit(ctx, testTask{ + fn: func(ctx context.Context) { + <-ctx.Done() + close(observed) + }, + value: 1, + counter: &counter, + }) + + cancel() + synctest.Wait() + + select { + case <-observed: + // task observed its cancellation + default: + t.Error("task should have observed cancelled ctx") + } + + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("unexpected shutdown error: %v", err) + } + }) }) } From 9118cfdc13ebdffc651cfcb3a567abf0eeec0d50 Mon Sep 17 00:00:00 2001 From: Alessandro Resta Date: Wed, 13 May 2026 01:18:53 +0300 Subject: [PATCH 3/3] Update example and readme --- README.md | 39 ++++++++++++++++++--------------------- example_test.go | 4 ++-- 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 8cf3bfc..6193e4f 100644 --- a/README.md +++ b/README.md @@ -21,33 +21,30 @@ type bazooka struct { bodyCount *atomic.Int32 } -// Do simulate some bazooking, and -// implement the generic constraint -func (b bazooka) Do(ctx context.Context) { +// Do simulate some bazooking +func (b *bazooka) Do(ctx context.Context) { b.ammo-- fmt.Fprintln(os.Stderr, "bazooking: "+b.targetID) b.bodyCount.Add(1) } -... -pool := workerpool.New[bazooka](3) -defer pool.Shutdown() +func main() { + pool := workerpool.New[*bazooka](context.TODO(), 3) -var bodyCount atomic.Int32 - -bazookas := []bazooka{ - {ammo: 69, targetID: "foo-id", bodyCount: &bodyCount}, - {ammo: 42, targetID: "bar-id", bodyCount: &bodyCount}, - {ammo: 11, targetID: "qux-id", bodyCount: &bodyCount}, -} - -for _, bazz := range bazookas { - task := workerpool.Task[bazooka]{ - Fn: func(input bazooka) { - input.Do(context.TODO()) - }, - Input: bazz, + var bodyCount atomic.Int32 + + bazookas := []bazooka{ + {ammo: 69, targetID: "foo-id", bodyCount: &bodyCount}, + {ammo: 42, targetID: "bar-id", bodyCount: &bodyCount}, + {ammo: 11, targetID: "qux-id", bodyCount: &bodyCount}, + } + + for i := range bazookas { + pool.Submit(context.TODO(), &bazookas[i]) } - pool.Submit(task) + + _ = pool.GracefulShutdown() + + fmt.Printf("Body count: %d\n", bodyCount.Load()) } ``` diff --git a/example_test.go b/example_test.go index 9376095..c3cf129 100644 --- a/example_test.go +++ b/example_test.go @@ -35,8 +35,8 @@ func Example() { {ammo: 11, targetID: "qux-id", bodyCount: &bodyCount}, } - for _, bazz := range bazookas { - pool.Submit(&bazz) + for i := range bazookas { + pool.Submit(context.TODO(), &bazookas[i]) } if err := pool.GracefulShutdown(); err != nil {