Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ debug
junk/

# Extension artifacts
*.zip
extensions/**/*.zip
extensions/**/assets/

Expand Down
62 changes: 62 additions & 0 deletions batch/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package batch

import (
"context"
"net/url"
"sync"
"time"
)

// DomainLimiter enforces a minimum delay between requests to the same domain.
// Its state is a per-domain timestamp map that Wait reads and updates while
// holding mu.
type DomainLimiter struct {
// mu is locked by Wait around every access to last.
mu sync.Mutex
// minDelay is the minimum spacing Wait enforces between requests to one domain.
minDelay time.Duration
// last maps a domain key (as produced by extractDomain) to the time of its
// most recently reserved request slot; Wait may store a future time here when
// a caller has to wait.
last map[string]time.Time
}

// NewDomainLimiter builds a DomainLimiter that spaces requests to any single
// domain at least minDelay apart. The returned limiter starts with no recorded
// domains, so the first request to each domain is admitted immediately.
func NewDomainLimiter(minDelay time.Duration) *DomainLimiter {
	limiter := new(DomainLimiter)
	limiter.minDelay = minDelay
	limiter.last = map[string]time.Time{}
	return limiter
}

// Wait blocks until it's safe to make a request to the given URL's domain.
//
// The caller's slot is reserved under the lock before any sleeping happens, so
// concurrent callers for one domain are handed consecutive slots minDelay
// apart rather than all waking at once. Different domains never wait on each
// other.
//
// Wait returns nil once the reserved slot is reached, or ctx.Err() if ctx is
// done first. A context that is already done on entry is rejected before a
// slot is reserved. A slot reserved by a caller that is cancelled while
// waiting is not handed back; later callers still queue behind it.
func (d *DomainLimiter) Wait(ctx context.Context, rawURL string) error {
	// Don't spend a slot on a request that can no longer be made.
	if err := ctx.Err(); err != nil {
		return err
	}

	domain := extractDomain(rawURL)

	d.mu.Lock()
	if d.last == nil {
		// Keep a zero-value DomainLimiter usable instead of panicking on the
		// map write below.
		d.last = make(map[string]time.Time)
	}
	now := time.Now()
	slot := now
	if lastReq, ok := d.last[domain]; ok {
		// lastReq may lie in the future when earlier callers are still
		// waiting on their own reservations; queue behind the latest one.
		if earliest := lastReq.Add(d.minDelay); earliest.After(now) {
			slot = earliest
		}
	}
	d.last[domain] = slot
	d.mu.Unlock()

	wait := slot.Sub(now)
	if wait <= 0 {
		return nil
	}

	// An explicit timer (rather than time.After) can be stopped as soon as
	// the context wins the race, instead of lingering for the full wait.
	timer := time.NewTimer(wait)
	defer timer.Stop()

	select {
	case <-timer.C:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

// extractDomain returns the key used to group rawURL for rate limiting: the
// URL's hostname (port stripped) when one can be parsed, otherwise rawURL
// itself. Falling back to the raw input keeps host-less or unparsable inputs
// from all collapsing onto a single shared empty key.
func extractDomain(rawURL string) string {
	if parsed, err := url.Parse(rawURL); err == nil {
		if host := parsed.Hostname(); host != "" {
			return host
		}
	}
	return rawURL
}
Comment on lines +56 to +62
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: extractDomain returns empty string for scheme-less URLs, causing incorrect rate limiting

Go's url.Parse is very permissive — it never returns an error for scheme-less inputs like "example.com/page". Instead, it parses the entire string as a relative path with an empty host, so u.Hostname() returns "". This means all scheme-less URLs (a realistic user input) share the same "" key in the rate limiter's last map, causing requests to completely unrelated domains to be unnecessarily serialized against each other.

The return rawURL fallback on error is effectively dead code since url.Parse almost never errors.

Suggested change
// extractDomain, as it stood before the suggested change: it falls back to
// rawURL only when url.Parse reports an error, and otherwise returns
// u.Hostname() even when that hostname is empty.
func extractDomain(rawURL string) string {
u, err := url.Parse(rawURL)
if err != nil {
return rawURL
}
return u.Hostname()
}
// extractDomain, as proposed by the suggested change: it returns the parsed
// hostname, and falls back to rawURL both when url.Parse reports an error and
// when the parsed hostname is empty.
func extractDomain(rawURL string) string {
u, err := url.Parse(rawURL)
if err != nil || u.Hostname() == "" {
return rawURL
}
return u.Hostname()
}

111 changes: 111 additions & 0 deletions batch/ratelimit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package batch

import (
"context"
"sync"
"testing"
"time"
)

// TestDomainLimiter_EnforcesDelay checks that two back-to-back requests to one
// domain are spaced by roughly the limiter's minimum delay.
func TestDomainLimiter_EnforcesDelay(t *testing.T) {
	dl := NewDomainLimiter(100 * time.Millisecond)
	bg := context.Background()
	began := time.Now()

	// The first URL is admitted immediately; the second shares its host and
	// therefore has to sit out the ~100ms delay.
	for _, target := range []string{"https://example.com/a", "https://example.com/b"} {
		if err := dl.Wait(bg, target); err != nil {
			t.Fatal(err)
		}
	}

	// The threshold leaves 10ms of slack below the nominal delay.
	if took := time.Since(began); took < 90*time.Millisecond {
		t.Errorf("expected >= 100ms delay, got %v", took)
	}
}

// TestDomainLimiter_DifferentDomains checks that requests to distinct domains
// are admitted without waiting on one another.
func TestDomainLimiter_DifferentDomains(t *testing.T) {
	dl := NewDomainLimiter(200 * time.Millisecond)
	bg := context.Background()
	began := time.Now()

	// Each URL has its own host, so neither call should block.
	for _, target := range []string{"https://a.com/page", "https://b.com/page"} {
		if err := dl.Wait(bg, target); err != nil {
			t.Fatal(err)
		}
	}

	if took := time.Since(began); took > 50*time.Millisecond {
		t.Errorf("different domains should not wait, got %v", took)
	}
}

// TestDomainLimiter_ContextCancellation checks that a caller blocked in Wait
// is released with the context's error when the context is cancelled.
func TestDomainLimiter_ContextCancellation(t *testing.T) {
	limiter := NewDomainLimiter(5 * time.Second)
	ctx, cancel := context.WithCancel(context.Background())
	// Release the context even if the test bails out early.
	defer cancel()

	// First request to establish the domain
	if err := limiter.Wait(ctx, "https://example.com/a"); err != nil {
		t.Fatal(err)
	}

	// Cancel while the second request is still waiting out the 5s delay.
	canceller := time.AfterFunc(50*time.Millisecond, cancel)
	defer canceller.Stop()

	// Require the context's own error, not just any failure, so an unrelated
	// error from Wait cannot make this test pass.
	err := limiter.Wait(ctx, "https://example.com/b")
	if err != context.Canceled {
		t.Errorf("expected %v, got %v", context.Canceled, err)
	}
}

// TestDomainLimiter_ConcurrentSameDomain checks that concurrent callers for
// one domain are serialized: three requests with a 50ms delay must take about
// 100ms in total.
func TestDomainLimiter_ConcurrentSameDomain(t *testing.T) {
	limiter := NewDomainLimiter(50 * time.Millisecond)
	ctx := context.Background()

	start := time.Now()
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// t.Errorf, unlike t.Fatal, may be called from a goroutine other
			// than the one running the test.
			if err := limiter.Wait(ctx, "https://same.com/page"); err != nil {
				t.Errorf("Wait returned unexpected error: %v", err)
			}
		}()
	}
	wg.Wait()

	elapsed := time.Since(start)
	// 3 requests with 50ms delay = at least ~100ms (first is free, second waits, third waits)
	if elapsed < 80*time.Millisecond {
		t.Errorf("expected >= 100ms for 3 concurrent same-domain requests, got %v", elapsed)
	}
}

// TestExtractDomain checks the key extractDomain derives for URLs with a
// host, and the raw-input fallback for inputs without one.
func TestExtractDomain(t *testing.T) {
	// Each pair is {input URL, expected key}.
	cases := [][2]string{
		{"https://example.com/path", "example.com"},
		{"http://sub.example.com:8080/page", "sub.example.com"},
		{"not-a-url", "not-a-url"},
		{"example.com/page", "example.com/page"},
	}
	for _, c := range cases {
		in, want := c[0], c[1]
		if got := extractDomain(in); got != want {
			t.Errorf("extractDomain(%q) = %q, want %q", in, got, want)
		}
	}
}
41 changes: 41 additions & 0 deletions batch/retry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package batch

import (
"context"
"time"
)

// RetryConfig controls exponential backoff retry behavior.
//
// The zero value is usable: it performs a single attempt with no retries.
type RetryConfig struct {
	// MaxAttempts is the total number of times fn may be called (1 = no
	// retry). Values below 1 are treated as 1 so that fn always runs.
	MaxAttempts int
	// InitDelay is the delay before the first retry; it doubles after each
	// further failed attempt.
	InitDelay time.Duration
}

// Do executes fn with retries on error. It uses exponential backoff and
// respects context cancellation. Only retries if shouldRetry returns true.
//
// Do returns nil as soon as fn succeeds. It returns fn's most recent error
// when shouldRetry rejects it or when the attempts are exhausted, and
// ctx.Err() when ctx is done while waiting between attempts. fn is always
// called at least once, even if MaxAttempts is zero or negative.
func (rc *RetryConfig) Do(ctx context.Context, fn func() error, shouldRetry func(error) bool) error {
	// A non-positive MaxAttempts must not skip fn entirely and report success.
	attempts := rc.MaxAttempts
	if attempts < 1 {
		attempts = 1
	}
	delay := rc.InitDelay

	for attempt := 1; ; attempt++ {
		lastErr := fn()
		if lastErr == nil {
			return nil
		}

		// shouldRetry is consulted on every failure, including the final
		// one. Don't sleep after the last attempt.
		if !shouldRetry(lastErr) || attempt == attempts {
			return lastErr
		}

		// An explicit timer (rather than time.After) can be stopped as soon
		// as the context is done, instead of lingering for the full delay.
		timer := time.NewTimer(delay)
		select {
		case <-timer.C:
		case <-ctx.Done():
			timer.Stop()
			return ctx.Err()
		}

		// Double the delay, but stop growing once doubling would overflow
		// time.Duration and wrap to a negative (i.e. immediate) delay.
		if doubled := delay * 2; doubled > delay {
			delay = doubled
		}
	}
}
98 changes: 98 additions & 0 deletions batch/retry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package batch

import (
"context"
"errors"
"testing"
"time"
)

// TestRetryConfig_NoRetryOnSuccess checks that a function which succeeds on
// its first call is invoked exactly once.
func TestRetryConfig_NoRetryOnSuccess(t *testing.T) {
	var invocations int
	succeed := func() error {
		invocations++
		return nil
	}
	alwaysRetry := func(error) bool { return true }

	cfg := RetryConfig{MaxAttempts: 3, InitDelay: 10 * time.Millisecond}
	if err := cfg.Do(context.Background(), succeed, alwaysRetry); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if invocations != 1 {
		t.Errorf("expected 1 call, got %d", invocations)
	}
}

// TestRetryConfig_RetriesOnTransientError checks that Do keeps calling a
// function that fails twice and then succeeds, and reports the final success.
func TestRetryConfig_RetriesOnTransientError(t *testing.T) {
	var invocations int
	// Fails on the first two calls, succeeds on the third.
	flaky := func() error {
		invocations++
		if invocations >= 3 {
			return nil
		}
		return errors.New("transient")
	}
	alwaysRetry := func(error) bool { return true }

	cfg := RetryConfig{MaxAttempts: 3, InitDelay: 10 * time.Millisecond}
	if err := cfg.Do(context.Background(), flaky, alwaysRetry); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if invocations != 3 {
		t.Errorf("expected 3 calls, got %d", invocations)
	}
}

// TestRetryConfig_StopsOnNonRetryable checks that Do returns immediately, with
// the original error value, when shouldRetry rejects the error.
func TestRetryConfig_StopsOnNonRetryable(t *testing.T) {
	permanent := errors.New("permanent")

	var invocations int
	alwaysFail := func() error {
		invocations++
		return permanent
	}
	// Anything except the "permanent" error is considered retryable.
	retryable := func(err error) bool { return err.Error() != "permanent" }

	cfg := RetryConfig{MaxAttempts: 5, InitDelay: 10 * time.Millisecond}
	if err := cfg.Do(context.Background(), alwaysFail, retryable); err != permanent {
		t.Fatalf("expected permanent error, got %v", err)
	}
	if invocations != 1 {
		t.Errorf("expected 1 call (no retry), got %d", invocations)
	}
}

// TestRetryConfig_ExponentialBackoff checks that a function which always
// fails is attempted MaxAttempts times, that its error is returned, and that
// the waits between attempts add up to the doubling schedule.
func TestRetryConfig_ExponentialBackoff(t *testing.T) {
	rc := &RetryConfig{MaxAttempts: 3, InitDelay: 50 * time.Millisecond}
	calls := 0

	start := time.Now()
	err := rc.Do(context.Background(), func() error {
		calls++
		return errors.New("fail")
	}, func(error) bool { return true })
	elapsed := time.Since(start)

	if err == nil {
		t.Error("expected the last attempt's error, got nil")
	}
	if calls != 3 {
		t.Errorf("expected 3 calls, got %d", calls)
	}
	// Expected: 50ms + 100ms = 150ms minimum; the threshold leaves 20ms of slack.
	if elapsed < 130*time.Millisecond {
		t.Errorf("expected >= 150ms for exponential backoff, got %v", elapsed)
	}
}

// TestRetryConfig_ContextCancellation checks that cancelling the context
// during a backoff wait makes Do return the cancellation error instead of
// sleeping out the remaining delay.
func TestRetryConfig_ContextCancellation(t *testing.T) {
	rc := &RetryConfig{MaxAttempts: 10, InitDelay: 5 * time.Second}
	ctx, cancel := context.WithCancel(context.Background())
	// Release the context even if the test bails out early.
	defer cancel()

	// Cancel while Do is still inside its first 5s backoff wait.
	canceller := time.AfterFunc(50*time.Millisecond, cancel)
	defer canceller.Stop()

	err := rc.Do(ctx, func() error {
		return errors.New("fail")
	}, func(error) bool { return true })

	// Require the cancellation error specifically: fn's own "fail" error is
	// also non-nil and must not satisfy this test.
	if !errors.Is(err, context.Canceled) {
		t.Errorf("expected context.Canceled, got %v", err)
	}
}
Loading