diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..0d00e96 --- /dev/null +++ b/Makefile @@ -0,0 +1,30 @@ +.DEFAULT_GOAL := help + +PROJECT_NAME := workerpool + +.PHONY: help +help: + @echo "------------------------------------------------------------------------" + @echo "${PROJECT_NAME}" + @echo "------------------------------------------------------------------------" + @awk 'BEGIN {FS = ":.*?## "}; $$0 ~ "^[[:alnum:]_/%-]+:.*?## " {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' $(MAKEFILE_LIST) | sort + +.PHONY: fmt +fmt: ## Format code + go fmt ./... + +.PHONY: vet +vet: ## Vet code + go vet ./... + +.PHONY: test +test: ## Run unit tests + go test -short -race -count=1 -v ./... + +.PHONY: lint +lint: vet ## Lint code + @if command -v staticcheck >/dev/null 2>&1; then \ + staticcheck ./...; \ + else \ + echo "staticcheck not installed, skipping (go install honnef.co/go/tools/cmd/staticcheck@latest)"; \ + fi diff --git a/example_test.go b/example_test.go index 6f5c434..d020ab5 100644 --- a/example_test.go +++ b/example_test.go @@ -16,35 +16,42 @@ type bazooka struct { } // Do simulate some bazooking -func (b bazooka) Do(ctx context.Context) { +func (b *bazooka) Do(ctx context.Context) { b.ammo-- fmt.Fprintln(os.Stderr, "bazooking: "+b.targetID) b.bodyCount.Add(1) } func Example() { - pool := workerpool.New[bazooka](3) - defer pool.Shutdown() + // starts a pool with 3 workers + // use the context to cancel the pool without waiting for buffered tasks to complete + pool := workerpool.New[*bazooka](context.TODO(), 3) var bodyCount atomic.Int32 + // list of tasks to perform bazookas := []bazooka{ {ammo: 69, targetID: "foo-id", bodyCount: &bodyCount}, {ammo: 42, targetID: "bar-id", bodyCount: &bodyCount}, {ammo: 11, targetID: "qux-id", bodyCount: &bodyCount}, } + // loop over the tasks initializing and + // submitting the tasks to the pool + for _, bazz := range bazookas { - task := workerpool.Task[bazooka]{ - Fn: func(input bazooka) { + task := workerpool.Task[*bazooka]{ + Fn: func(input *bazooka) { input.Do(context.TODO()) }, - Input: bazz, + Input: &bazz, } pool.Submit(task) } - pool.Shutdown() + if err := pool.GracefulShutdown(); err != nil { + fmt.Fprintf(os.Stderr, "shutdown error: %v\n", err) + } fmt.Printf("Body count: %d\n", bodyCount.Load()) // Output: diff --git a/workerpool.go b/workerpool.go index bab352c..2cc2ae6 100644 --- a/workerpool.go +++ b/workerpool.go @@ -2,7 +2,9 @@ package workerpool import ( "context" + "errors" "sync" + "sync/atomic" ) // Input wraps a task's execution. @@ -28,21 +30,34 @@ func WithBuffer[T Input](size int) Option[T] { // Pool maintains fixed worker goroutines processing tasks from a channel. type Pool[T Input] struct { - tasks chan Task[T] - wg sync.WaitGroup + tasks chan Task[T] // channel for tasks waiting to be processed + buffer int // size of the task channel + wg sync.WaitGroup // wait group for worker goroutines + + // immediate termination + ctx context.Context + cancel context.CancelFunc + ungracefulStop atomic.Bool + + // graceful shutdown stop chan struct{} shutdownOnce sync.Once - buffer int } -// New creates a pool with size workers. -func New[T Input](size int, opts ...Option[T]) *Pool[T] { - if size <= 0 { - size = 1 +// New creates a pool with numOfWorkers workers. +// The context can be used to stop the pool immediately, skipping any buffered +// tasks. In-flight tasks will still run to completion. +func New[T Input](ctx context.Context, numOfWorkers int, opts ...Option[T]) *Pool[T] { + if numOfWorkers <= 0 { + numOfWorkers = 1 } + ctx, cancel := context.WithCancel(ctx) + p := &Pool[T]{ - stop: make(chan struct{}), + ctx: ctx, + cancel: cancel, + stop: make(chan struct{}), } for _, opt := range opts { @@ -51,8 +66,8 @@ func New[T Input](size int, opts ...Option[T]) *Pool[T] { p.tasks = make(chan Task[T], p.buffer) - p.wg.Add(size) - for range size { + p.wg.Add(numOfWorkers) + for range numOfWorkers { go p.worker() } return p @@ -62,8 +77,12 @@ func (p *Pool[T]) worker() { defer p.wg.Done() for { select { + case <-p.ctx.Done(): + // exit without draining buffered tasks + p.ungracefulStop.Store(true) + return case <-p.stop: - // drain remaining buffered tasks + // drain remaining buffered tasks before exiting for { select { case task := <-p.tasks: @@ -78,21 +97,32 @@ func (p *Pool[T]) worker() { } } -// Submit sends a task to the pool. Blocks if all workers are busy. -// Returns false if pool is shut down. +// Submit sends a task to the pool. Blocks if the task channel is full. +// Returns false if the pool is shutting down or the context was cancelled. func (p *Pool[T]) Submit(task Task[T]) bool { select { - case <-p.stop: + case <-p.ctx.Done(): // forcefully terminate via ctx + return false + case <-p.stop: // terminated via graceful shutdown return false case p.tasks <- task: return true } } -// Shutdown stops accepting tasks and waits for active tasks to complete. -func (p *Pool[T]) Shutdown() { +// GracefulShutdown stops accepting new tasks, drains all buffered tasks, +// and waits for in-flight tasks to complete before returning. +// Returns an error if the ctx was cancelled before shutdown completed. +func (p *Pool[T]) GracefulShutdown() error { p.shutdownOnce.Do(func() { close(p.stop) }) + p.wg.Wait() + p.cancel() + + if p.ungracefulStop.Load() { + return errors.New("pool was forcefully terminated before shutdown") + } + return nil } diff --git a/workerpool_test.go b/workerpool_test.go index 2f56e7d..2f93f8b 100644 --- a/workerpool_test.go +++ b/workerpool_test.go @@ -5,7 +5,7 @@ import ( "sync" "sync/atomic" "testing" - "time" + "testing/synctest" ) type testInput struct { @@ -13,15 +13,12 @@ type testInput struct { counter *atomic.Int32 } -func (t testInput) Do(ctx context.Context) { - t.counter.Add(int32(t.value)) -} +func (t testInput) Do(ctx context.Context) { t.counter.Add(int32(t.value)) } func TestPool_Submit(t *testing.T) { t.Parallel() - pool := New[testInput](3) - defer pool.Shutdown() + pool := New[testInput](context.TODO(), 3) var counter atomic.Int32 @@ -38,7 +35,9 @@ func TestPool_Submit(t *testing.T) { } } - pool.Shutdown() + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("shutdown error: %v", err) + } // sum of 1..10 = 55 if got := counter.Load(); got != 55 { @@ -46,34 +45,99 @@ func TestPool_Submit(t *testing.T) { } } -func TestPool_SubmitAfterShutdown(t *testing.T) { +func TestPool_GracefulShutdown(t *testing.T) { t.Parallel() - pool := New[testInput](2) - pool.Shutdown() + t.Run("Submit task after graceful shutdown", func(t *testing.T) { + t.Parallel() - var counter atomic.Int32 + pool := New[testInput](context.TODO(), 2) - task := Task[testInput]{ - Fn: func(input testInput) { - input.Do(context.TODO()) - }, - Input: testInput{value: 1, counter: &counter}, - } + // immediately shutdown the pool + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("shutdown error: %v", err) + } - if pool.Submit(task) { - t.Error("expected Submit to return false after shutdown") - } + var counter atomic.Int32 + + task := Task[testInput]{ + Fn: func(input testInput) { + input.Do(context.TODO()) + }, + Input: testInput{value: 1, counter: &counter}, + } + + if pool.Submit(task) { + t.Error("expected Submit to return false after shutdown") + } + }) + + t.Run("Graceful shutdown waits for all queued tasks to be complete", func(t *testing.T) { + t.Parallel() + + // use a single worker and a buffered channel + // + // block the worker with a task that won't complete until + // we say so, then fill the buffer with additional tasks, + // call GracefulShutdown while the worker is still blocked, + // then unblock it and verify that GracefulShutdown only returns + // after all buffered tasks have been processed (not just the in-flight one) + + const buffered, workers = 5, 1 + + pool := New(context.TODO(), workers, WithBuffer[testInput](buffered)) + + var counter atomic.Int32 + blocker := make(chan struct{}) + started := make(chan struct{}) + + // submit a task that signals when it starts and then blocks, + // keeping the worker occupied while we fill the buffer + pool.Submit(Task[testInput]{ + Fn: func(input testInput) { + close(started) + <-blocker + input.Do(context.TODO()) + }, + Input: testInput{value: 1, counter: &counter}, + }) + + <-started // worker is now blocked + + // fill the buffer while the worker is blocked + for range buffered { + pool.Submit(Task[testInput]{ + Fn: func(input testInput) { + input.Do(context.TODO()) + }, + Input: testInput{value: 1, counter: &counter}, + }) + } + + // call GracefulShutdown before unblocking + // it must not return until the buffer is fully drained + shutdownDone := make(chan error, 1) + go func() { + shutdownDone <- pool.GracefulShutdown() + }() + + close(blocker) // release the worker to start draining + + if err := <-shutdownDone; err != nil { + t.Errorf("unexpected shutdown error: %v", err) + } + + // buffered+1 (started+buffered tasks) + if got := counter.Load(); got != buffered+1 { + t.Errorf("expected counter = %d, got %d", buffered+1, got) + } + }) } func TestPool_Concurrency(t *testing.T) { - const ( - workers = 5 - tasks = 100 - ) + const workers, tasks = 5, 100 - pool := New[testInput](workers) - defer pool.Shutdown() + pool := New[testInput](context.TODO(), workers) var ( counter atomic.Int32 @@ -81,9 +145,11 @@ func TestPool_Concurrency(t *testing.T) { ) wg.Add(tasks) + for i := range tasks { go func(val int) { defer wg.Done() + task := Task[testInput]{ Fn: func(input testInput) { input.Do(context.TODO()) @@ -95,7 +161,10 @@ func TestPool_Concurrency(t *testing.T) { } wg.Wait() - pool.Shutdown() + + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("unexpected shutdown error: %v", err) + } if got := counter.Load(); got != tasks { t.Errorf("expected counter = %d, got %d", tasks, got) @@ -103,8 +172,7 @@ func TestPool_Concurrency(t *testing.T) { } func TestPool_ZeroSize(t *testing.T) { - pool := New[testInput](0) // should default to 1 worker - defer pool.Shutdown() + pool := New[testInput](context.TODO(), 0) // should default to 1 worker var counter atomic.Int32 @@ -119,7 +187,9 @@ func TestPool_ZeroSize(t *testing.T) { t.Error("failed to submit task") } - pool.Shutdown() + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("unexpected shutdown error: %v", err) + } if got := counter.Load(); got != 42 { t.Errorf("expected counter = 42, got %d", got) @@ -127,69 +197,73 @@ func TestPool_ZeroSize(t *testing.T) { } func TestPool_Backpressure(t *testing.T) { - pool := New[testInput](1) - defer pool.Shutdown() + t.Parallel() - var counter atomic.Int32 - blocker := make(chan struct{}) + synctest.Test(t, func(t *testing.T) { + pool := New[testInput](context.TODO(), 1) - // submit blocking task - task1 := Task[testInput]{ - Fn: func(input testInput) { - <-blocker // block here - input.Do(context.TODO()) - }, - Input: testInput{value: 1, counter: &counter}, - } + var counter atomic.Int32 + blocker := make(chan struct{}) - go pool.Submit(task1) - time.Sleep(100 * time.Millisecond) // let worker pick up task1 + // submit blocking task + task1 := Task[testInput]{ + Fn: func(input testInput) { + <-blocker // block here + input.Do(context.TODO()) + }, + Input: testInput{value: 1, counter: &counter}, + } - // try to submit second task should block since channel is unbuffered - submitted := make(chan bool) - task2 := Task[testInput]{ - Fn: func(input testInput) { - input.Do(context.TODO()) - }, - Input: testInput{value: 2, counter: &counter}, - } + go pool.Submit(task1) - go func() { - submitted <- pool.Submit(task2) - }() + synctest.Wait() // let worker pick up task1 - // verify Submit is blocked - select { - case <-submitted: - t.Error("Submit should be blocked") - case <-time.After(50 * time.Millisecond): - // expected Submit is blocked - } + // try to submit second task should block since channel is unbuffered + task2Submitted := make(chan struct{}) - close(blocker) // unblock worker + go func() { + pool.Submit(Task[testInput]{ + Fn: func(input testInput) { + input.Do(context.Background()) + }, + Input: testInput{value: 2, counter: &counter}, + }) + close(task2Submitted) + }() + + // worker still blocked on <-blocker + // task2 goroutine durably blocked on channel send + synctest.Wait() + select { + case <-task2Submitted: + t.Error("Submit should be blocked while worker is busy") + default: + // expected + } - // now Submit should complete - select { - case result := <-submitted: - if !result { - t.Error("expected Submit to succeed") + close(blocker) // release the worker + synctest.Wait() // worker processes both tasks; task2 goroutine exits + select { + case <-task2Submitted: + // expected: Submit completed + default: + t.Error("Submit should have completed after worker was unblocked") } - case <-time.After(100 * time.Millisecond): - t.Error("Submit should have completed") - } - pool.Shutdown() + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("unexpected shutdown error: %v", err) + } - if got := counter.Load(); got != 3 { - t.Errorf("expected counter = 3, got %d", got) - } + if got := counter.Load(); got != 3 { + t.Errorf("expected counter = 3, got %d", got) + } + }) } func TestPool_WithBuffer(t *testing.T) { t.Parallel() - pool := New(3, WithBuffer[testInput](10)) - defer pool.Shutdown() + pool := New(context.TODO(), 3, WithBuffer[testInput](10)) var counter atomic.Int32 @@ -206,7 +280,9 @@ func TestPool_WithBuffer(t *testing.T) { } } - pool.Shutdown() + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("unexpected shutdown error: %v", err) + } // sum of 1..20 = 210 if got := counter.Load(); got != 210 { @@ -215,58 +291,67 @@ func TestPool_WithBuffer(t *testing.T) { } func TestPool_UnbufferedDefault(t *testing.T) { - // without WithBuffer option, channel should be unbuffered - pool := New[testInput](1) - defer pool.Shutdown() + t.Parallel() - counter := &atomic.Int32{} - blocker := make(chan struct{}) - started := make(chan struct{}) + synctest.Test(t, func(t *testing.T) { + pool := New[testInput](t.Context(), 1) - // submit blocking task - go func() { - task := Task[testInput]{ - Fn: func(input testInput) { - close(started) - <-blocker - input.Do(context.TODO()) - }, - Input: testInput{value: 1, counter: counter}, - } - pool.Submit(task) - }() + counter := &atomic.Int32{} + blocker := make(chan struct{}) + started := make(chan struct{}) - <-started + go func() { + pool.Submit(Task[testInput]{ + Fn: func(input testInput) { + close(started) + <-blocker + input.Do(context.Background()) + }, + Input: testInput{value: 1, counter: counter}, + }) + }() - // try to submit another should block since unbuffered - submitted := make(chan bool) - go func() { - task := Task[testInput]{ - Fn: func(input testInput) { - input.Do(context.TODO()) - }, - Input: testInput{value: 2, counter: counter}, + <-started // worker has picked up task1 and is blocked + + // second Submit must block since channel is unbuffered and worker is busy + task2Submitted := make(chan struct{}) + go func() { + pool.Submit(Task[testInput]{ + Fn: func(input testInput) { + input.Do(context.Background()) + }, + Input: testInput{value: 2, counter: counter}, + }) + close(task2Submitted) + }() + + // task2 goroutine is durably blocked on the channel send + synctest.Wait() + + select { + case <-task2Submitted: + t.Error("submit should be blocked on unbuffered channel") + default: + // task2Submitted is not yet closed. Submit is still blocked } - pool.Submit(task) - submitted <- true - }() - - // verify submit is blocked - select { - case <-submitted: - t.Error("submit should be blocked on unbuffered channel") - case <-time.After(50 * time.Millisecond): - // expected - } - close(blocker) + close(blocker) // release the worker - // wait for second submit to complete - <-submitted + synctest.Wait() // worker processes both tasks and task2 goroutine exits - pool.Shutdown() + select { + case <-task2Submitted: + // task2Submitted is closed and Submit completed after worker was unblocked, as expected + default: + t.Error("Submit should have completed after worker was unblocked") + } - if got := counter.Load(); got != 3 { - t.Errorf("expected counter = 3, got %d", got) - } + if err := pool.GracefulShutdown(); err != nil { + t.Errorf("unexpected shutdown error: %v", err) + } + + if got := counter.Load(); got != 3 { + t.Errorf("expected counter = 3, got %d", got) + } + }) }