Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func main() {
}

for i := range bazookas {
pool.Submit(context.TODO(), &bazookas[i])
_ = pool.Submit(context.TODO(), &bazookas[i])
}

_ = pool.GracefulShutdown()
Expand Down
2 changes: 1 addition & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func Example() {
}

for i := range bazookas {
pool.Submit(context.TODO(), &bazookas[i])
_ = pool.Submit(context.TODO(), &bazookas[i])
}

if err := pool.GracefulShutdown(); err != nil {
Expand Down
15 changes: 10 additions & 5 deletions workerpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,23 @@ func (p *Pool[T]) worker() {
}
}

var (
ErrPoolClosed = errors.New("pool is closed")
ErrTaskCancelled = errors.New("task context cancelled")
)

// Submit sends a task to the pool. Blocks if the task channel is full.
// Returns false if the pool is shutting down or the context was cancelled.
func (p *Pool[T]) Submit(ctx context.Context, task T) bool {
func (p *Pool[T]) Submit(ctx context.Context, task T) error {
select {
case <-ctx.Done():
return false
return ErrTaskCancelled
case <-p.ctx.Done(): // forcefully terminate via ctx
return false
return ErrPoolClosed
case <-p.stop: // terminated via graceful shutdown
return false
return ErrPoolClosed
case p.entries <- entry[T]{ctx: ctx, job: task}:
return true
return nil
}
}

Expand Down
19 changes: 13 additions & 6 deletions workerpool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package workerpool

import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -36,7 +37,7 @@ func TestNew(t *testing.T) {
counter: &counter,
}

if !pool.Submit(context.TODO(), task) {
if err := pool.Submit(context.TODO(), task); err != nil {
t.Error("failed to submit task")
}

Expand Down Expand Up @@ -64,7 +65,7 @@ func TestPool_WithBuffer(t *testing.T) {
counter: &counter,
}

if !pool.Submit(context.TODO(), task) {
if err := pool.Submit(context.TODO(), task); err != nil {
t.Error("failed to submit task")
}
}
Expand Down Expand Up @@ -94,7 +95,7 @@ func TestPool_Submit(t *testing.T) {
counter: &counter,
}

if !pool.Submit(context.TODO(), task) {
if err := pool.Submit(context.TODO(), task); err != nil {
t.Error("failed to submit task")
}
}
Expand Down Expand Up @@ -330,7 +331,7 @@ func TestPool_GracefulShutdown(t *testing.T) {
},
}

if !pool.Submit(context.TODO(), task1) {
if err := pool.Submit(context.TODO(), task1); err != nil {
t.Error("expected Submit to return true")
}

Expand All @@ -351,9 +352,15 @@ func TestPool_GracefulShutdown(t *testing.T) {
counter: &counter,
}

if pool.Submit(context.TODO(), task2) {
t.Error("expected Submit to return false after shutdown")
err := pool.Submit(context.TODO(), task2)
if err == nil {
t.Error("expected Submit to return error after shutdown")
}

if !errors.Is(err, ErrPoolClosed) {
t.Errorf("expected error to be %v, got %v", ErrPoolClosed, err)
}

close(blocker)

if err := <-shutdownDone; err != nil {
Expand Down
Loading