diff --git a/README.md b/README.md index c23ad8f..374dfd7 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ func main() { } for i := range bazookas { - pool.Submit(context.TODO(), &bazookas[i]) + _ = pool.Submit(context.TODO(), &bazookas[i]) } _ = pool.GracefulShutdown() diff --git a/example_test.go b/example_test.go index c3cf129..c91917f 100644 --- a/example_test.go +++ b/example_test.go @@ -36,7 +36,7 @@ func Example() { } for i := range bazookas { - pool.Submit(context.TODO(), &bazookas[i]) + _ = pool.Submit(context.TODO(), &bazookas[i]) } if err := pool.GracefulShutdown(); err != nil { diff --git a/workerpool.go b/workerpool.go index 769250f..ae5e430 100644 --- a/workerpool.go +++ b/workerpool.go @@ -98,18 +98,23 @@ func (p *Pool[T]) worker() { } } +var ( + ErrPoolClosed = errors.New("pool is closed") + ErrTaskCancelled = errors.New("task context cancelled") +) + // Submit sends a task to the pool. Blocks if the task channel is full. // Returns false if the pool is shutting down or the context was cancelled. -func (p *Pool[T]) Submit(ctx context.Context, task T) bool { +func (p *Pool[T]) Submit(ctx context.Context, task T) error { select { case <-ctx.Done(): - return false + return ErrTaskCancelled case <-p.ctx.Done(): // forcefully terminate via ctx - return false + return ErrPoolClosed case <-p.stop: // terminated via graceful shutdown - return false + return ErrPoolClosed case p.entries <- entry[T]{ctx: ctx, job: task}: - return true + return nil } } diff --git a/workerpool_test.go b/workerpool_test.go index 0a06279..78fc437 100644 --- a/workerpool_test.go +++ b/workerpool_test.go @@ -2,6 +2,7 @@ package workerpool import ( "context" + "errors" "sync" "sync/atomic" "testing" @@ -36,7 +37,7 @@ func TestNew(t *testing.T) { counter: &counter, } - if !pool.Submit(context.TODO(), task) { + if err := pool.Submit(context.TODO(), task); err != nil { t.Error("failed to submit task") } @@ -64,7 +65,7 @@ func TestPool_WithBuffer(t *testing.T) { counter: &counter, } - if !pool.Submit(context.TODO(), task) { + if err := pool.Submit(context.TODO(), task); err != nil { t.Error("failed to submit task") } } @@ -94,7 +95,7 @@ func TestPool_Submit(t *testing.T) { counter: &counter, } - if !pool.Submit(context.TODO(), task) { + if err := pool.Submit(context.TODO(), task); err != nil { t.Error("failed to submit task") } } @@ -330,7 +331,7 @@ func TestPool_GracefulShutdown(t *testing.T) { }, } - if !pool.Submit(context.TODO(), task1) { + if err := pool.Submit(context.TODO(), task1); err != nil { t.Error("expected Submit to return true") } @@ -351,9 +352,15 @@ func TestPool_GracefulShutdown(t *testing.T) { counter: &counter, } - if pool.Submit(context.TODO(), task2) { - t.Error("expected Submit to return false after shutdown") + err := pool.Submit(context.TODO(), task2) + if err == nil { + t.Error("expected Submit to return error after shutdown") } + + if !errors.Is(err, ErrPoolClosed) { + t.Errorf("expected error to be %v, got %v", ErrPoolClosed, err) + } + close(blocker) if err := <-shutdownDone; err != nil {