Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
.DEFAULT_GOAL := help

PROJECT_NAME := workerpool

.PHONY: help
help:
@echo "------------------------------------------------------------------------"
@echo "${PROJECT_NAME}"
@echo "------------------------------------------------------------------------"
@awk 'BEGIN {FS = ":.*?## "}; $$0 ~ "^[[:alnum:]_/%-]+:.*?## " {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' $(MAKEFILE_LIST) | sort

.PHONY: fmt
fmt: ## Format code
go fmt ./...

.PHONY: vet
vet: ## Vet code
go vet ./...

.PHONY: test
test: ## Run unit tests
go test -short -race -count=1 -v ./...

.PHONY: lint
lint: vet ## Lint code
@if command -v staticcheck >/dev/null 2>&1; then \
staticcheck ./...; \
else \
echo "staticcheck not installed, skipping (go install honnef.co/go/tools/cmd/staticcheck@latest)"; \
fi
21 changes: 14 additions & 7 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,42 @@ type bazooka struct {
}

// Do simulate some bazooking
func (b bazooka) Do(ctx context.Context) {
func (b *bazooka) Do(ctx context.Context) {
b.ammo--
fmt.Fprintln(os.Stderr, "bazooking: "+b.targetID)
b.bodyCount.Add(1)
}

func Example() {
pool := workerpool.New[bazooka](3)
defer pool.Shutdown()
// starts a pool with 3 workers
// use the context to cancel the pool without waiting for buffered tasks to complete
pool := workerpool.New[*bazooka](context.TODO(), 3)

var bodyCount atomic.Int32

// list of tasks to perform
bazookas := []bazooka{
{ammo: 69, targetID: "foo-id", bodyCount: &bodyCount},
{ammo: 42, targetID: "bar-id", bodyCount: &bodyCount},
{ammo: 11, targetID: "qux-id", bodyCount: &bodyCount},
}

// loop over the tasks initializing and
// submitting the tasks to the pool

for _, bazz := range bazookas {
task := workerpool.Task[bazooka]{
Fn: func(input bazooka) {
task := workerpool.Task[*bazooka]{
Fn: func(input *bazooka) {
input.Do(context.TODO())
},
Input: bazz,
Input: &bazz,
}
pool.Submit(task)
}

pool.Shutdown()
if err := pool.GracefulShutdown(); err != nil {
fmt.Fprintf(os.Stderr, "shutdown error: %v\n", err)
}

fmt.Printf("Body count: %d\n", bodyCount.Load())
// Output:
Expand Down
62 changes: 46 additions & 16 deletions workerpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package workerpool

import (
"context"
"errors"
"sync"
"sync/atomic"
)

// Input wraps a task's execution.
Expand All @@ -28,21 +30,34 @@ func WithBuffer[T Input](size int) Option[T] {

// Pool maintains fixed worker goroutines processing tasks from a channel.
type Pool[T Input] struct {
tasks chan Task[T]
wg sync.WaitGroup
tasks chan Task[T] // channel for tasks waiting to be processed
buffer int // size of the task channel
wg sync.WaitGroup // wait group for worker goroutines

// immediate termination
ctx context.Context
cancel context.CancelFunc
ungracefulStop atomic.Bool

// graceful shutdown
stop chan struct{}
shutdownOnce sync.Once
buffer int
}

// New creates a pool with size workers.
func New[T Input](size int, opts ...Option[T]) *Pool[T] {
if size <= 0 {
size = 1
// New creates a pool with numOfWorkers workers.
// The context can be used to stop the pool immediately, skipping any buffered
// tasks. In-flight tasks will still run to completion.
func New[T Input](ctx context.Context, numOfWorkers int, opts ...Option[T]) *Pool[T] {
if numOfWorkers <= 0 {
numOfWorkers = 1
}

ctx, cancel := context.WithCancel(ctx)

p := &Pool[T]{
stop: make(chan struct{}),
ctx: ctx,
cancel: cancel,
stop: make(chan struct{}),
}

for _, opt := range opts {
Expand All @@ -51,8 +66,8 @@ func New[T Input](size int, opts ...Option[T]) *Pool[T] {

p.tasks = make(chan Task[T], p.buffer)

p.wg.Add(size)
for range size {
p.wg.Add(numOfWorkers)
for range numOfWorkers {
go p.worker()
}
return p
Expand All @@ -62,8 +77,12 @@ func (p *Pool[T]) worker() {
defer p.wg.Done()
for {
select {
case <-p.ctx.Done():
// exit without draining buffered tasks
p.ungracefulStop.Store(true)
return
case <-p.stop:
// drain remaining buffered tasks
// drain remaining buffered tasks before exiting
for {
select {
case task := <-p.tasks:
Expand All @@ -78,21 +97,32 @@ func (p *Pool[T]) worker() {
}
}

// Submit sends a task to the pool. Blocks if all workers are busy.
// Returns false if pool is shut down.
// Submit sends a task to the pool. Blocks if the task channel is full.
// Returns false if the pool is shutting down or the context was cancelled.
func (p *Pool[T]) Submit(task Task[T]) bool {
select {
case <-p.stop:
case <-p.ctx.Done(): // forcefully terminate via ctx
return false
case <-p.stop: // terminated via graceful shutdown
return false
case p.tasks <- task:
return true
}
}

// Shutdown stops accepting tasks and waits for active tasks to complete.
func (p *Pool[T]) Shutdown() {
// GracefulShutdown stops accepting new tasks, drains all buffered tasks,
// and waits for in-flight tasks to complete before returning.
// Returns an error if the ctx was cancelled before shutdown completed.
func (p *Pool[T]) GracefulShutdown() error {
p.shutdownOnce.Do(func() {
close(p.stop)
})

p.wg.Wait()
p.cancel()

if p.ungracefulStop.Load() {
return errors.New("pool was forcefully terminated before shutdown")
}
return nil
}
Loading
Loading