From 74477d6f47b29419caa531c8b1a5a23b68c8e6cc Mon Sep 17 00:00:00 2001 From: Alessandro Resta Date: Tue, 12 May 2026 19:35:45 +0300 Subject: [PATCH 1/4] Add context-based pool termination Pool initialization now accepts a ctx for immediate termination: workers exit without draining buffered tasks, though any in-flight task runs to completion. The existing Shutdown method is renamed to GracefulShutdown, which drains all buffered tasks and waits for workers to finish. It returns an error if the ctx was cancelled before shutdown completed. --- workerpool.go | 62 ++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/workerpool.go b/workerpool.go index bab352c..2cc2ae6 100644 --- a/workerpool.go +++ b/workerpool.go @@ -2,7 +2,9 @@ package workerpool import ( "context" + "errors" "sync" + "sync/atomic" ) // Input wraps a task's execution. @@ -28,21 +30,34 @@ func WithBuffer[T Input](size int) Option[T] { // Pool maintains fixed worker goroutines processing tasks from a channel. type Pool[T Input] struct { - tasks chan Task[T] - wg sync.WaitGroup + tasks chan Task[T] // channel for tasks waiting to be processed + buffer int // size of the task channel + wg sync.WaitGroup // wait group for worker goroutines + + // immediate termination + ctx context.Context + cancel context.CancelFunc + ungracefulStop atomic.Bool + + // graceful shutdown stop chan struct{} shutdownOnce sync.Once - buffer int } -// New creates a pool with size workers. -func New[T Input](size int, opts ...Option[T]) *Pool[T] { - if size <= 0 { - size = 1 +// New creates a pool with numOfWorkers workers. +// The context can be used to stop the pool immediately, skipping any buffered +// tasks. In-flight tasks will still run to completion. +func New[T Input](ctx context.Context, numOfWorkers int, opts ...Option[T]) *Pool[T] { + if numOfWorkers <= 0 { + numOfWorkers = 1 } + ctx, cancel := context.WithCancel(ctx) + p := &Pool[T]{ - stop: make(chan struct{}), + ctx: ctx, + cancel: cancel, + stop: make(chan struct{}), } for _, opt := range opts { @@ -51,8 +66,8 @@ func New[T Input](size int, opts ...Option[T]) *Pool[T] { p.tasks = make(chan Task[T], p.buffer) - p.wg.Add(size) - for range size { + p.wg.Add(numOfWorkers) + for range numOfWorkers { go p.worker() } return p @@ -62,8 +77,12 @@ func (p *Pool[T]) worker() { defer p.wg.Done() for { select { + case <-p.ctx.Done(): + // exit without draining buffered tasks + p.ungracefulStop.Store(true) + return case <-p.stop: - // drain remaining buffered tasks + // drain remaining buffered tasks before exiting for { select { case task := <-p.tasks: @@ -78,21 +97,32 @@ func (p *Pool[T]) worker() { } } -// Submit sends a task to the pool. Blocks if all workers are busy. -// Returns false if pool is shut down. +// Submit sends a task to the pool. Blocks if the task channel is full. +// Returns false if the pool is shutting down or the context was cancelled. func (p *Pool[T]) Submit(task Task[T]) bool { select { - case <-p.stop: + case <-p.ctx.Done(): // forcefully terminate via ctx + return false + case <-p.stop: // terminated via graceful shutdown return false case p.tasks <- task: return true } } -// Shutdown stops accepting tasks and waits for active tasks to complete. -func (p *Pool[T]) Shutdown() { +// GracefulShutdown stops accepting new tasks, drains all buffered tasks, +// and waits for in-flight tasks to complete before returning. +// Returns an error if the ctx was cancelled before shutdown completed. +func (p *Pool[T]) GracefulShutdown() error { p.shutdownOnce.Do(func() { close(p.stop) }) + p.wg.Wait() + p.cancel() + + if p.ungracefulStop.Load() { + return errors.New("pool was forcefully terminated before shutdown") + } + return nil } From 6218b7816254a5ac23bf178058b85b3d004559a2 Mon Sep 17 00:00:00 2001 From: Alessandro Resta Date: Tue, 12 May 2026 19:38:08 +0300 Subject: [PATCH 2/4] Add ctx (for cancel), comments and fix bazooka pointer for correctness --- example_test.go | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/example_test.go b/example_test.go index 6f5c434..d020ab5 100644 --- a/example_test.go +++ b/example_test.go @@ -16,35 +16,42 @@ type bazooka struct { } // Do simulate some bazooking -func (b bazooka) Do(ctx context.Context) { +func (b *bazooka) Do(ctx context.Context) { b.ammo-- fmt.Fprintln(os.Stderr, "bazooking: "+b.targetID) b.bodyCount.Add(1) } func Example() { - pool := workerpool.New[bazooka](3) - defer pool.Shutdown() + // starts a pool with 3 workers + // use the context to cancel the pool without waiting for buffered tasks to complete + pool := workerpool.New[*bazooka](context.TODO(), 3) var bodyCount atomic.Int32 + // list of tasks to perform bazookas := []bazooka{ {ammo: 69, targetID: "foo-id", bodyCount: &bodyCount}, {ammo: 42, targetID: "bar-id", bodyCount: &bodyCount}, {ammo: 11, targetID: "qux-id", bodyCount: &bodyCount}, } + // loop over the tasks initializing and + // submitting the tasks to the pool + for _, bazz := range bazookas { - task := workerpool.Task[bazooka]{ - Fn: func(input bazooka) { + task := workerpool.Task[*bazooka]{ + Fn: func(input *bazooka) { input.Do(context.TODO()) }, - Input: bazz, + Input: &bazz, } pool.Submit(task) } - pool.Shutdown() + if err := pool.GracefulShutdown(); err != nil { + fmt.Fprintf(os.Stderr, "shutdown error: %v\n", err) + } fmt.Printf("Body count: %d\n", bodyCount.Load()) // Output: From ebb4f2de2d54572a4a5c260237bc12fed9e7b0cd Mon Sep 17 00:00:00 2001 From: Alessandro Resta Date: Tue, 12 May 2026 21:44:37 +0300 Subject: [PATCH 3/4] Update test for reliability and coverage Also use sync test to avoid waiting --- workerpool_test.go | 331 ++++++++++++++++++++++++++++----------------- 1 file changed, 208 insertions(+), 123 deletions(-) diff --git a/workerpool_test.go b/workerpool_test.go index 2f56e7d..2f93f8b 100644 --- a/workerpool_test.go +++ b/workerpool_test.go @@ -5,7 +5,7 @@ import ( "sync" "sync/atomic" "testing" - "time" + "testing/synctest" ) type testInput struct { @@ -13,15 +13,12 @@ type testInput struct { counter *atomic.Int32 } -func (t testInput) Do(ctx context.Context) { - t.counter.Add(int32(t.value)) -} +func (t testInput) Do(ctx context.Context) { t.counter.Add(int32(t.value)) } func TestPool_Submit(t *testing.T) { t.Parallel() - pool := New[testInput](3) - defer pool.Shutdown() + pool := New[testInput](context.TODO(), 3) var counter atomic.Int32 @@ -38,7 +35,9 @@ func TestPool_Submit(t *testing.T) { } } - pool.Shutdown() + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("shutdown error: %v", err) + } // sum of 1..10 = 55 if got := counter.Load(); got != 55 { @@ -46,34 +45,99 @@ func TestPool_Submit(t *testing.T) { } } -func TestPool_SubmitAfterShutdown(t *testing.T) { +func TestPool_GracefulShutdown(t *testing.T) { t.Parallel() - pool := New[testInput](2) - pool.Shutdown() + t.Run("Submit task after graceful shutdown", func(t *testing.T) { + t.Parallel() - var counter atomic.Int32 + pool := New[testInput](context.TODO(), 2) - task := Task[testInput]{ - Fn: func(input testInput) { - input.Do(context.TODO()) - }, - Input: testInput{value: 1, counter: &counter}, - } + // immediately shutdown the pool + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("shutdown error: %v", err) + } - if pool.Submit(task) { - t.Error("expected Submit to return false after shutdown") - } + var counter atomic.Int32 + + task := Task[testInput]{ + Fn: func(input testInput) { + input.Do(context.TODO()) + }, + Input: testInput{value: 1, counter: &counter}, + } + + if pool.Submit(task) { + t.Error("expected Submit to return false after shutdown") + } + }) + + t.Run("Graceful shutdown waits for all queued tasks to be complete", func(t *testing.T) { + t.Parallel() + + // use a single worker and a buffered channel + // + // block the worker with a task that won't complete until + // we say so, then fill the buffer with additional tasks, + // call GracefulShutdown while the worker is still blocked, + // then unblock it and verify that GracefulShutdown only returns + // after all buffered tasks have been processed (not just the in-flight one) + + const buffered, workers = 5, 1 + + pool := New(context.TODO(), workers, WithBuffer[testInput](buffered)) + + var counter atomic.Int32 + blocker := make(chan struct{}) + started := make(chan struct{}) + + // submit a task that signals when it starts and then blocks, + // keeping the worker occupied while we fill the buffer + pool.Submit(Task[testInput]{ + Fn: func(input testInput) { + close(started) + <-blocker + input.Do(context.TODO()) + }, + Input: testInput{value: 1, counter: &counter}, + }) + + <-started // worker is now blocked + + // fill the buffer while the worker is blocked + for range buffered { + pool.Submit(Task[testInput]{ + Fn: func(input testInput) { + input.Do(context.TODO()) + }, + Input: testInput{value: 1, counter: &counter}, + }) + } + + // call GracefulShutdown before unblocking + // it must not return until the buffer is fully drained + shutdownDone := make(chan error, 1) + go func() { + shutdownDone <- pool.GracefulShutdown() + }() + + close(blocker) // release the worker to start draining + + if err := <-shutdownDone; err != nil { + t.Errorf("unexpected shutdown error: %v", err) + } + + // buffered+1 (started+buffered tasks) + if got := counter.Load(); got != buffered+1 { + t.Errorf("expected counter = %d, got %d", buffered+1, got) + } + }) } func TestPool_Concurrency(t *testing.T) { - const ( - workers = 5 - tasks = 100 - ) + const workers, tasks = 5, 100 - pool := New[testInput](workers) - defer pool.Shutdown() + pool := New[testInput](context.TODO(), workers) var ( counter atomic.Int32 @@ -81,9 +145,11 @@ func TestPool_Concurrency(t *testing.T) { ) wg.Add(tasks) + for i := range tasks { go func(val int) { defer wg.Done() + task := Task[testInput]{ Fn: func(input testInput) { input.Do(context.TODO()) @@ -95,7 +161,10 @@ func TestPool_Concurrency(t *testing.T) { } wg.Wait() - pool.Shutdown() + + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("unexpected shutdown error: %v", err) + } if got := counter.Load(); got != tasks { t.Errorf("expected counter = %d, got %d", tasks, got) @@ -103,8 +172,7 @@ func TestPool_Concurrency(t *testing.T) { } func TestPool_ZeroSize(t *testing.T) { - pool := New[testInput](0) // should default to 1 worker - defer pool.Shutdown() + pool := New[testInput](context.TODO(), 0) // should default to 1 worker var counter atomic.Int32 @@ -119,7 +187,9 @@ func TestPool_ZeroSize(t *testing.T) { t.Error("failed to submit task") } - pool.Shutdown() + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("unexpected shutdown error: %v", err) + } if got := counter.Load(); got != 42 { t.Errorf("expected counter = 42, got %d", got) @@ -127,69 +197,73 @@ func TestPool_ZeroSize(t *testing.T) { } func TestPool_Backpressure(t *testing.T) { - pool := New[testInput](1) - defer pool.Shutdown() + t.Parallel() - var counter atomic.Int32 - blocker := make(chan struct{}) + synctest.Test(t, func(t *testing.T) { + pool := New[testInput](context.TODO(), 1) - // submit blocking task - task1 := Task[testInput]{ - Fn: func(input testInput) { - <-blocker // block here - input.Do(context.TODO()) - }, - Input: testInput{value: 1, counter: &counter}, - } + var counter atomic.Int32 + blocker := make(chan struct{}) - go pool.Submit(task1) - time.Sleep(100 * time.Millisecond) // let worker pick up task1 + // submit blocking task + task1 := Task[testInput]{ + Fn: func(input testInput) { + <-blocker // block here + input.Do(context.TODO()) + }, + Input: testInput{value: 1, counter: &counter}, + } - // try to submit second task should block since channel is unbuffered - submitted := make(chan bool) - task2 := Task[testInput]{ - Fn: func(input testInput) { - input.Do(context.TODO()) - }, - Input: testInput{value: 2, counter: &counter}, - } + go pool.Submit(task1) - go func() { - submitted <- pool.Submit(task2) - }() + synctest.Wait() // let worker pick up task1 - // verify Submit is blocked - select { - case <-submitted: - t.Error("Submit should be blocked") - case <-time.After(50 * time.Millisecond): - // expected Submit is blocked - } + // try to submit second task should block since channel is unbuffered + task2Submitted := make(chan struct{}) - close(blocker) // unblock worker + go func() { + pool.Submit(Task[testInput]{ + Fn: func(input testInput) { + input.Do(context.Background()) + }, + Input: testInput{value: 2, counter: &counter}, + }) + close(task2Submitted) + }() + + // worker still blocked on <-blocker + // task2 goroutine durably blocked on channel send + synctest.Wait() + select { + case <-task2Submitted: + t.Error("Submit should be blocked while worker is busy") + default: + // expected + } - // now Submit should complete - select { - case result := <-submitted: - if !result { - t.Error("expected Submit to succeed") + close(blocker) // release the worker + synctest.Wait() // worker processes both tasks; task2 goroutine exits + select { + case <-task2Submitted: + // expected: Submit completed + default: + t.Error("Submit should have completed after worker was unblocked") } - case <-time.After(100 * time.Millisecond): - t.Error("Submit should have completed") - } - pool.Shutdown() + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("unexpected shutdown error: %v", err) + } - if got := counter.Load(); got != 3 { - t.Errorf("expected counter = 3, got %d", got) - } + if got := counter.Load(); got != 3 { + t.Errorf("expected counter = 3, got %d", got) + } + }) } func TestPool_WithBuffer(t *testing.T) { t.Parallel() - pool := New(3, WithBuffer[testInput](10)) - defer pool.Shutdown() + pool := New(context.TODO(), 3, WithBuffer[testInput](10)) var counter atomic.Int32 @@ -206,7 +280,9 @@ func TestPool_WithBuffer(t *testing.T) { } } - pool.Shutdown() + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("unexpected shutdown error: %v", err) + } // sum of 1..20 = 210 if got := counter.Load(); got != 210 { @@ -215,58 +291,67 @@ func TestPool_WithBuffer(t *testing.T) { } func TestPool_UnbufferedDefault(t *testing.T) { - // without WithBuffer option, channel should be unbuffered - pool := New[testInput](1) - defer pool.Shutdown() + t.Parallel() - counter := &atomic.Int32{} - blocker := make(chan struct{}) - started := make(chan struct{}) + synctest.Test(t, func(t *testing.T) { + pool := New[testInput](t.Context(), 1) - // submit blocking task - go func() { - task := Task[testInput]{ - Fn: func(input testInput) { - close(started) - <-blocker - input.Do(context.TODO()) - }, - Input: testInput{value: 1, counter: counter}, - } - pool.Submit(task) - }() + counter := &atomic.Int32{} + blocker := make(chan struct{}) + started := make(chan struct{}) - <-started + go func() { + pool.Submit(Task[testInput]{ + Fn: func(input testInput) { + close(started) + <-blocker + input.Do(context.Background()) + }, + Input: testInput{value: 1, counter: counter}, + }) + }() - // try to submit another should block since unbuffered - submitted := make(chan bool) - go func() { - task := Task[testInput]{ - Fn: func(input testInput) { - input.Do(context.TODO()) - }, - Input: testInput{value: 2, counter: counter}, + <-started // worker has picked up task1 and is blocked + + // second Submit must block since channel is unbuffered and worker is busy + task2Submitted := make(chan struct{}) + go func() { + pool.Submit(Task[testInput]{ + Fn: func(input testInput) { + input.Do(context.Background()) + }, + Input: testInput{value: 2, counter: counter}, + }) + close(task2Submitted) + }() + + // task2 goroutine is durably blocked on the channel send + synctest.Wait() + + select { + case <-task2Submitted: + t.Error("submit should be blocked on unbuffered channel") + default: + // task2Submitted is not yet closed. Submit is still blocked } - pool.Submit(task) - submitted <- true - }() - - // verify submit is blocked - select { - case <-submitted: - t.Error("submit should be blocked on unbuffered channel") - case <-time.After(50 * time.Millisecond): - // expected - } - close(blocker) + close(blocker) // release the worker - // wait for second submit to complete - <-submitted + synctest.Wait() // worker processes both tasks and task2 goroutine exits - pool.Shutdown() + select { + case <-task2Submitted: + // task2Submitted is closed and Submit completed after worker was unblocked, as expected + default: + t.Error("Submit should have completed after worker was unblocked") + } - if got := counter.Load(); got != 3 { - t.Errorf("expected counter = 3, got %d", got) - } + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("unexpected shutdown error: %v", err) + } + + if got := counter.Load(); got != 3 { + t.Errorf("expected counter = 3, got %d", got) + } + }) } From 469e817073aaec883dc8512a1ca8bf5351efe78f Mon Sep 17 00:00:00 2001 From: Alessandro Resta Date: Tue, 12 May 2026 21:45:45 +0300 Subject: [PATCH 4/4] Add Makefile with fmt, vet, lint, and test targets --- Makefile | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 Makefile diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..0d00e96 --- /dev/null +++ b/Makefile @@ -0,0 +1,30 @@ +.DEFAULT_GOAL := help + +PROJECT_NAME := workerpool + +.PHONY: help +help: + @echo "------------------------------------------------------------------------" + @echo "${PROJECT_NAME}" + @echo "------------------------------------------------------------------------" + @awk 'BEGIN {FS = ":.*?## "}; $$0 ~ "^[[:alnum:]_/%-]+:.*?## " {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' $(MAKEFILE_LIST) | sort + +.PHONY: fmt +fmt: ## Format code + go fmt ./... + +.PHONY: vet +vet: ## Vet code + go vet ./... + +.PHONY: test +test: ## Run unit tests + go test -short -race -count=1 -v ./... + +.PHONY: lint +lint: vet ## Lint code + @if command -v staticcheck >/dev/null 2>&1; then \ + staticcheck ./...; \ + else \ + echo "staticcheck not installed, skipping (go install honnef.co/go/tools/cmd/staticcheck@latest)"; \ + fi