diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ad6ef14..57e3cd2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -63,6 +63,24 @@ jobs: - name: Run go vet run: go vet ./... + lint: + name: Lint + needs: deps + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + + - name: Install staticcheck + run: go install honnef.co/go/tools/cmd/staticcheck@latest + + - name: Run staticcheck + run: staticcheck ./... + test: name: Test needs: deps diff --git a/example_test.go b/example_test.go index d020ab5..9376095 100644 --- a/example_test.go +++ b/example_test.go @@ -29,24 +29,14 @@ func Example() { var bodyCount atomic.Int32 - // list of tasks to perform bazookas := []bazooka{ {ammo: 69, targetID: "foo-id", bodyCount: &bodyCount}, {ammo: 42, targetID: "bar-id", bodyCount: &bodyCount}, {ammo: 11, targetID: "qux-id", bodyCount: &bodyCount}, } - // loop over the tasks initializing and - // submitting the tasks to the pool - for _, bazz := range bazookas { - task := workerpool.Task[*bazooka]{ - Fn: func(input *bazooka) { - input.Do(context.TODO()) - }, - Input: &bazz, - } - pool.Submit(task) + pool.Submit(&bazz) } if err := pool.GracefulShutdown(); err != nil { diff --git a/workerpool.go b/workerpool.go index 2cc2ae6..c97dd57 100644 --- a/workerpool.go +++ b/workerpool.go @@ -7,30 +7,24 @@ import ( "sync/atomic" ) -// Input wraps a task's execution. -type Input interface { +// Task defines the interface for a task to be executed by a worker pool. +type Task interface { Do(context.Context) } -// Task bundles a function with its input. -type Task[T Input] struct { - Fn func(T) - Input T -} - // Option configures a Pool. -type Option[T Input] func(*Pool[T]) +type Option[T Task] func(*Pool[T]) // WithBuffer sets the task channel buffer size. -func WithBuffer[T Input](size int) Option[T] { +func WithBuffer[T Task](size int) Option[T] { return func(p *Pool[T]) { p.buffer = size } } // Pool maintains fixed worker goroutines processing tasks from a channel. -type Pool[T Input] struct { - tasks chan Task[T] // channel for tasks waiting to be processed +type Pool[T Task] struct { + tasks chan T // channel for tasks waiting to be processed buffer int // size of the task channel wg sync.WaitGroup // wait group for worker goroutines @@ -47,7 +41,7 @@ type Pool[T Input] struct { // New creates a pool with numOfWorkers workers. // The context can be used to stop the pool immediately, skipping any buffered // tasks. In-flight tasks will still run to completion. -func New[T Input](ctx context.Context, numOfWorkers int, opts ...Option[T]) *Pool[T] { +func New[T Task](ctx context.Context, numOfWorkers int, opts ...Option[T]) *Pool[T] { if numOfWorkers <= 0 { numOfWorkers = 1 } @@ -64,7 +58,7 @@ func New[T Input](ctx context.Context, numOfWorkers int, opts ...Option[T]) *Poo opt(p) } - p.tasks = make(chan Task[T], p.buffer) + p.tasks = make(chan T, p.buffer) p.wg.Add(numOfWorkers) for range numOfWorkers { @@ -86,20 +80,20 @@ func (p *Pool[T]) worker() { for { select { case task := <-p.tasks: - task.Fn(task.Input) + task.Do(p.ctx) default: return } } case task := <-p.tasks: - task.Fn(task.Input) + task.Do(p.ctx) } } } // Submit sends a task to the pool. Blocks if the task channel is full. // Returns false if the pool is shutting down or the context was cancelled. -func (p *Pool[T]) Submit(task Task[T]) bool { +func (p *Pool[T]) Submit(task T) bool { select { case <-p.ctx.Done(): // forcefully terminate via ctx return false diff --git a/workerpool_test.go b/workerpool_test.go index 2f93f8b..87ca323 100644 --- a/workerpool_test.go +++ b/workerpool_test.go @@ -8,28 +8,33 @@ import ( "testing/synctest" ) -type testInput struct { +type testTask struct { value int counter *atomic.Int32 + fn func() } -func (t testInput) Do(ctx context.Context) { t.counter.Add(int32(t.value)) } +func (t testTask) Do(ctx context.Context) { + if t.fn != nil { + t.fn() + } + t.counter.Add(int32(t.value)) +} func TestPool_Submit(t *testing.T) { t.Parallel() - pool := New[testInput](context.TODO(), 3) + pool := New[testTask](context.TODO(), 3) var counter atomic.Int32 // submit tasks for i := 1; i <= 10; i++ { - task := Task[testInput]{ - Fn: func(input testInput) { - input.Do(context.TODO()) - }, - Input: testInput{value: i, counter: &counter}, + task := testTask{ + value: i, + counter: &counter, } + if !pool.Submit(task) { t.Error("failed to submit task") } @@ -51,7 +56,7 @@ func TestPool_GracefulShutdown(t *testing.T) { t.Run("Submit task after graceful shutdown", func(t *testing.T) { t.Parallel() - pool := New[testInput](context.TODO(), 2) + pool := New[testTask](context.TODO(), 2) // immediately shutdown the pool if err := pool.GracefulShutdown(); err != nil { @@ -60,11 +65,9 @@ func TestPool_GracefulShutdown(t *testing.T) { var counter atomic.Int32 - task := Task[testInput]{ - Fn: func(input testInput) { - input.Do(context.TODO()) - }, - Input: testInput{value: 1, counter: &counter}, + task := testTask{ + value: 1, + counter: &counter, } if pool.Submit(task) { @@ -85,7 +88,7 @@ func TestPool_GracefulShutdown(t *testing.T) { const buffered, workers = 5, 1 - pool := New(context.TODO(), workers, WithBuffer[testInput](buffered)) + pool := New(context.TODO(), workers, WithBuffer[testTask](buffered)) var counter atomic.Int32 blocker := make(chan struct{}) @@ -93,24 +96,21 @@ func TestPool_GracefulShutdown(t *testing.T) { // submit a task that signals when it starts and then blocks, // keeping the worker occupied while we fill the buffer - pool.Submit(Task[testInput]{ - Fn: func(input testInput) { + pool.Submit(testTask{ + fn: func() { close(started) <-blocker - input.Do(context.TODO()) }, - Input: testInput{value: 1, counter: &counter}, + value: 1, + counter: &counter, }) <-started // worker is now blocked // fill the buffer while the worker is blocked for range buffered { - pool.Submit(Task[testInput]{ - Fn: func(input testInput) { - input.Do(context.TODO()) - }, - Input: testInput{value: 1, counter: &counter}, + pool.Submit(testTask{ + value: 1, counter: &counter, }) } @@ -132,12 +132,45 @@ func TestPool_GracefulShutdown(t *testing.T) { t.Errorf("expected counter = %d, got %d", buffered+1, got) } }) + + t.Run("Context cancellation terminates pool without draining buffered tasks", func(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + + pool := New[testTask](ctx, 1) + + var counter atomic.Int32 + blocker := make(chan struct{}) + started := make(chan struct{}) + + pool.Submit(testTask{ + fn: func() { + close(started) + <-blocker + }, + value: 1, + counter: &counter, + }) + + <-started // worker is blocked + cancel() // cancel the caller's context + close(blocker) // unblock the worker + + synctest.Wait() + + if err := pool.GracefulShutdown(); err == nil { + t.Error("expected GracefulShutdown to return an error after context cancellation") + } + }) + }) } func TestPool_Concurrency(t *testing.T) { const workers, tasks = 5, 100 - pool := New[testInput](context.TODO(), workers) + pool := New[testTask](context.TODO(), workers) var ( counter atomic.Int32 @@ -150,11 +183,9 @@ func TestPool_Concurrency(t *testing.T) { go func(val int) { defer wg.Done() - task := Task[testInput]{ - Fn: func(input testInput) { - input.Do(context.TODO()) - }, - Input: testInput{value: 1, counter: &counter}, + task := testTask{ + value: 1, + counter: &counter, } pool.Submit(task) }(i) @@ -172,15 +203,13 @@ func TestPool_Concurrency(t *testing.T) { } func TestPool_ZeroSize(t *testing.T) { - pool := New[testInput](context.TODO(), 0) // should default to 1 worker + pool := New[testTask](context.TODO(), 0) // should default to 1 worker var counter atomic.Int32 - task := Task[testInput]{ - Fn: func(input testInput) { - input.Do(context.TODO()) - }, - Input: testInput{value: 42, counter: &counter}, + task := testTask{ + value: 42, + counter: &counter, } if !pool.Submit(task) { @@ -200,18 +229,18 @@ func TestPool_Backpressure(t *testing.T) { t.Parallel() synctest.Test(t, func(t *testing.T) { - pool := New[testInput](context.TODO(), 1) + pool := New[testTask](context.TODO(), 1) var counter atomic.Int32 blocker := make(chan struct{}) // submit blocking task - task1 := Task[testInput]{ - Fn: func(input testInput) { + task1 := testTask{ + fn: func() { <-blocker // block here - input.Do(context.TODO()) }, - Input: testInput{value: 1, counter: &counter}, + value: 1, + counter: &counter, } go pool.Submit(task1) @@ -222,11 +251,9 @@ func TestPool_Backpressure(t *testing.T) { task2Submitted := make(chan struct{}) go func() { - pool.Submit(Task[testInput]{ - Fn: func(input testInput) { - input.Do(context.Background()) - }, - Input: testInput{value: 2, counter: &counter}, + pool.Submit(testTask{ + value: 2, + counter: &counter, }) close(task2Submitted) }() @@ -263,18 +290,17 @@ func TestPool_Backpressure(t *testing.T) { func TestPool_WithBuffer(t *testing.T) { t.Parallel() - pool := New(context.TODO(), 3, WithBuffer[testInput](10)) + pool := New(context.TODO(), 3, WithBuffer[testTask](10)) var counter atomic.Int32 // submit 20 tasks for i := 1; i <= 20; i++ { - task := Task[testInput]{ - Fn: func(input testInput) { - input.Do(context.TODO()) - }, - Input: testInput{value: i, counter: &counter}, + task := testTask{ + value: i, + counter: &counter, } + if !pool.Submit(task) { t.Error("failed to submit task") } @@ -294,20 +320,20 @@ func TestPool_UnbufferedDefault(t *testing.T) { t.Parallel() synctest.Test(t, func(t *testing.T) { - pool := New[testInput](t.Context(), 1) + pool := New[testTask](t.Context(), 1) counter := &atomic.Int32{} blocker := make(chan struct{}) started := make(chan struct{}) go func() { - pool.Submit(Task[testInput]{ - Fn: func(input testInput) { + pool.Submit(testTask{ + fn: func() { close(started) <-blocker - input.Do(context.Background()) }, - Input: testInput{value: 1, counter: counter}, + value: 1, + counter: counter, }) }() @@ -316,11 +342,9 @@ func TestPool_UnbufferedDefault(t *testing.T) { // second Submit must block since channel is unbuffered and worker is busy task2Submitted := make(chan struct{}) go func() { - pool.Submit(Task[testInput]{ - Fn: func(input testInput) { - input.Do(context.Background()) - }, - Input: testInput{value: 2, counter: counter}, + pool.Submit(testTask{ + value: 2, + counter: counter, }) close(task2Submitted) }()