Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.coverage
20 changes: 13 additions & 7 deletions workerpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ func (p *Pool[T]) worker() {
case entry := <-p.entries:
entry.job.Do(entry.ctx)
default:
// channel is empty. since p.stop is closed,
// no more tasks can be submitted
return
}
}
Expand Down Expand Up @@ -115,15 +117,19 @@ func (p *Pool[T]) Submit(ctx context.Context, task T) bool {
// and waits for in-flight tasks to complete before returning.
// Returns an error if the ctx was cancelled before shutdown completed.
func (p *Pool[T]) GracefulShutdown() error {
p.shutdownOnce.Do(func() {
close(p.stop)
})

p.wg.Wait()
p.cancel()

if p.ungracefulStop.Load() {
return errors.New("pool was forcefully terminated before shutdown")
}

p.shutdownOnce.Do(func() {
close(p.stop)
p.wg.Wait()
p.cancel()

// only close(p.entries) with a lock here and
// a read lock in Submit otherwise senders will panic =]
// but it's just a good to have, since p.stop is closed
// and submit already checks for that
})
return nil
}
55 changes: 40 additions & 15 deletions workerpool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ func TestPool_Submit(t *testing.T) {
pool := New[testTask](context.TODO(), 3)

var counter atomic.Int32

// submit tasks
for i := 1; i <= 10; i++ {
task := testTask{
value: i,
Expand Down Expand Up @@ -318,23 +316,50 @@ func TestPool_GracefulShutdown(t *testing.T) {
t.Run("Submit task after graceful shutdown", func(t *testing.T) {
t.Parallel()

pool := New[testTask](context.TODO(), 2)
synctest.Test(t, func(t *testing.T) {
pool := New[testTask](context.TODO(), 2)

// immediately shutdown the pool
if err := pool.GracefulShutdown(); err != nil {
t.Errorf("shutdown error: %v", err)
}
var counter atomic.Int32
blocker := make(chan struct{})

var counter atomic.Int32
task1 := testTask{
value: 1,
counter: &counter,
fn: func(context.Context) {
<-blocker
},
}

task := testTask{
value: 1,
counter: &counter,
}
if !pool.Submit(context.TODO(), task1) {
t.Error("expected Submit to return true")
}

if pool.Submit(context.TODO(), task) {
t.Error("expected Submit to return false after shutdown")
}
// wait for worker to pick up the task and block
synctest.Wait()

// immediately shutdown the pool
// will close(p.stop) but block on p.wg.Wait()
shutdownDone := make(chan error, 1)
go func() {
shutdownDone <- pool.GracefulShutdown()
}()

synctest.Wait() // wait shutdown to propagate

task2 := testTask{
value: 1,
counter: &counter,
}

if pool.Submit(context.TODO(), task2) {
t.Error("expected Submit to return false after shutdown")
}
close(blocker)

if err := <-shutdownDone; err != nil {
t.Errorf("shutdown error: %v", err)
}
})
})

t.Run("Graceful shutdown waits for all queued tasks to be complete", func(t *testing.T) {
Expand Down
Loading