Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 18 additions & 21 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,33 +21,30 @@ type bazooka struct {
bodyCount *atomic.Int32
}

// Do simulate some bazooking, and
// implement the generic constraint
func (b bazooka) Do(ctx context.Context) {
// Do simulate some bazooking
func (b *bazooka) Do(ctx context.Context) {
b.ammo--
fmt.Fprintln(os.Stderr, "bazooking: "+b.targetID)
b.bodyCount.Add(1)
}

...
pool := workerpool.New[bazooka](3)
defer pool.Shutdown()
func main() {
pool := workerpool.New[*bazooka](context.TODO(), 3)

var bodyCount atomic.Int32

bazookas := []bazooka{
{ammo: 69, targetID: "foo-id", bodyCount: &bodyCount},
{ammo: 42, targetID: "bar-id", bodyCount: &bodyCount},
{ammo: 11, targetID: "qux-id", bodyCount: &bodyCount},
}

for _, bazz := range bazookas {
task := workerpool.Task[bazooka]{
Fn: func(input bazooka) {
input.Do(context.TODO())
},
Input: bazz,
var bodyCount atomic.Int32

bazookas := []bazooka{
{ammo: 69, targetID: "foo-id", bodyCount: &bodyCount},
{ammo: 42, targetID: "bar-id", bodyCount: &bodyCount},
{ammo: 11, targetID: "qux-id", bodyCount: &bodyCount},
}

for i := range bazookas {
pool.Submit(context.TODO(), &bazookas[i])
}
pool.Submit(task)

_ = pool.GracefulShutdown()

fmt.Printf("Body count: %d\n", bodyCount.Load())
}
```
4 changes: 2 additions & 2 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ func Example() {
{ammo: 11, targetID: "qux-id", bodyCount: &bodyCount},
}

for _, bazz := range bazookas {
pool.Submit(&bazz)
for i := range bazookas {
pool.Submit(context.TODO(), &bazookas[i])
}

if err := pool.GracefulShutdown(); err != nil {
Expand Down
43 changes: 25 additions & 18 deletions workerpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ type Task interface {
Do(context.Context)
}

type entry[T Task] struct {
ctx context.Context
job T
}

// Option configures a Pool.
type Option[T Task] func(*Pool[T])

Expand All @@ -24,9 +29,9 @@ func WithBuffer[T Task](size int) Option[T] {

// Pool maintains fixed worker goroutines processing tasks from a channel.
type Pool[T Task] struct {
tasks chan T // channel for tasks waiting to be processed
buffer int // size of the task channel
wg sync.WaitGroup // wait group for worker goroutines
entries chan entry[T] // channel for jobs waiting to be processed
buffer int // size of the task channel
wg sync.WaitGroup // wait group for worker goroutines

// immediate termination
ctx context.Context
Expand All @@ -38,18 +43,18 @@ type Pool[T Task] struct {
shutdownOnce sync.Once
}

// New creates a pool with numOfWorkers workers.
// New creates a pool with number of available workers.
// The context can be used to stop the pool immediately, skipping any buffered
// tasks. In-flight tasks will still run to completion.
func New[T Task](ctx context.Context, numOfWorkers int, opts ...Option[T]) *Pool[T] {
if numOfWorkers <= 0 {
numOfWorkers = 1
func New[T Task](ctx context.Context, workers int, opts ...Option[T]) *Pool[T] {
if workers <= 0 {
workers = 1
}

ctx, cancel := context.WithCancel(ctx)
poolCtx, cancel := context.WithCancel(ctx)

p := &Pool[T]{
ctx: ctx,
ctx: poolCtx,
cancel: cancel,
stop: make(chan struct{}),
}
Expand All @@ -58,10 +63,10 @@ func New[T Task](ctx context.Context, numOfWorkers int, opts ...Option[T]) *Pool
opt(p)
}

p.tasks = make(chan T, p.buffer)
p.entries = make(chan entry[T], p.buffer)

p.wg.Add(numOfWorkers)
for range numOfWorkers {
p.wg.Add(workers)
for range workers {
go p.worker()
}
return p
Expand All @@ -79,27 +84,29 @@ func (p *Pool[T]) worker() {
// drain remaining buffered tasks before exiting
for {
select {
case task := <-p.tasks:
task.Do(p.ctx)
case entry := <-p.entries:
entry.job.Do(entry.ctx)
default:
return
}
}
case task := <-p.tasks:
task.Do(p.ctx)
case entry := <-p.entries:
entry.job.Do(entry.ctx)
}
}
}

// Submit sends a task to the pool. Blocks if the task channel is full.
// Returns false if the pool is shutting down or the context was cancelled.
func (p *Pool[T]) Submit(task T) bool {
func (p *Pool[T]) Submit(ctx context.Context, task T) bool {
select {
case <-ctx.Done():
return false
case <-p.ctx.Done(): // forcefully terminate via ctx
return false
case <-p.stop: // terminated via graceful shutdown
return false
case p.tasks <- task:
case p.entries <- entry[T]{ctx: ctx, job: task}:
return true
}
}
Expand Down
Loading
Loading