From 4d60b4796e99fcb62d0523f9c1dc25ed8ffb65c6 Mon Sep 17 00:00:00 2001 From: Jeffrey D <1289344+verygoodsoftwarenotvirus@users.noreply.github.com> Date: Fri, 10 Apr 2026 16:52:21 -0500 Subject: [PATCH 01/12] test: trying out new deps --- cache/cache_test.go | 6 +- cache/config/config_test.go | 44 +- cache/config/metricsprovider_mock_test.go | 515 ++++++++++++++++++++++ cache/config/mocks_gen_test.go | 5 + cache/memory/memory_test.go | 40 +- cache/redis/circuitbreaker_mock_test.go | 180 ++++++++ cache/redis/metricsprovider_mock_test.go | 515 ++++++++++++++++++++++ cache/redis/mocks_gen_test.go | 7 + cache/redis/redis_test.go | 30 +- cache/redis/redisclient_mock_test.go | 240 ++++++++++ cache/redis/unit_test.go | 379 ++++++++-------- go.mod | 12 +- go.sum | 18 +- 13 files changed, 1747 insertions(+), 244 deletions(-) create mode 100644 cache/config/metricsprovider_mock_test.go create mode 100644 cache/config/mocks_gen_test.go create mode 100644 cache/redis/circuitbreaker_mock_test.go create mode 100644 cache/redis/metricsprovider_mock_test.go create mode 100644 cache/redis/mocks_gen_test.go create mode 100644 cache/redis/redisclient_mock_test.go diff --git a/cache/cache_test.go b/cache/cache_test.go index 69052f8..8275e1b 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -3,7 +3,7 @@ package cache import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test/must" ) func TestErrNotFound(T *testing.T) { @@ -12,7 +12,7 @@ func TestErrNotFound(T *testing.T) { T.Run("is not nil", func(t *testing.T) { t.Parallel() - assert.NotNil(t, ErrNotFound) - assert.Equal(t, "not found", ErrNotFound.Error()) + must.NotNil(t, ErrNotFound) + must.EqOp(t, "not found", ErrNotFound.Error()) }) } diff --git a/cache/config/config_test.go b/cache/config/config_test.go index 7f34279..0c7668a 100644 --- a/cache/config/config_test.go +++ b/cache/config/config_test.go @@ -1,18 +1,17 @@ package config import ( - "fmt" + "errors" "testing" "github.com/verygoodsoftwarenotvirus/platform/v5/cache/redis" circuitbreakingcfg "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking/config" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" - mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -31,7 +30,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderMemory, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("redis provider with config", func(t *testing.T) { @@ -42,21 +41,21 @@ func TestConfig_ValidateWithContext(T *testing.T) { Redis: &redis.Config{QueueAddresses: []string{"localhost:6379"}}, } - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("redis provider missing config", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderRedis} - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("invalid provider name", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: "vault"} - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) } @@ -70,8 +69,8 @@ func TestProvideCache(T *testing.T) { Provider: ProviderMemory, }, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - require.NoError(t, err) - assert.NotNil(t, c) + must.NoError(t, err) + test.NotNil(t, c) }) T.Run("redis provider", func(t *testing.T) { @@ -91,8 +90,8 @@ func TestProvideCache(T *testing.T) { metrics.NewNoopMetricsProvider(), ) - require.NoError(t, err) - assert.NotNil(t, c) + must.NoError(t, err) + test.NotNil(t, c) }) T.Run("redis provider with cluster addresses", func(t *testing.T) { @@ -112,8 +111,8 @@ func TestProvideCache(T *testing.T) { metrics.NewNoopMetricsProvider(), ) - require.NoError(t, err) - assert.NotNil(t, c) + must.NoError(t, err) + test.NotNil(t, c) }) T.Run("redis provider with circuit breaker error", func(t *testing.T) { @@ -129,9 +128,12 @@ func TestProvideCache(T *testing.T) { }, } - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", "redis-cache-breaker_circuit_breaker_tripped", []metric.Int64CounterOption(nil)). - Return(&mockmetrics.Int64Counter{}, fmt.Errorf("counter init failure")) + mp := &ProviderMock{ + NewInt64CounterFunc: func(name string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, "redis-cache-breaker_circuit_breaker_tripped", name) + return nil, errors.New("counter init failure") + }, + } c, err := ProvideCache[example]( t.Context(), @@ -141,9 +143,9 @@ func TestProvideCache(T *testing.T) { mp, ) - require.Error(t, err) - assert.Nil(t, c) - mp.AssertExpectations(t) + must.Error(t, err) + test.Nil(t, c) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("invalid provider", func(t *testing.T) { @@ -151,6 +153,6 @@ func TestProvideCache(T *testing.T) { _, err := ProvideCache[example](t.Context(), &Config{}, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - assert.Error(t, err) + test.Error(t, err) }) } diff --git a/cache/config/metricsprovider_mock_test.go b/cache/config/metricsprovider_mock_test.go new file mode 100644 index 0000000..c225860 --- /dev/null +++ b/cache/config/metricsprovider_mock_test.go @@ -0,0 +1,515 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package config + +import ( + "context" + "sync" + + "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" + "go.opentelemetry.io/otel/metric" +) + +// Ensure, that ProviderMock does implement metrics.Provider. +// If this is not the case, regenerate this file with moq. +var _ metrics.Provider = &ProviderMock{} + +// ProviderMock is a mock implementation of metrics.Provider. +// +// func TestSomethingThatUsesProvider(t *testing.T) { +// +// // make and configure a mocked metrics.Provider +// mockedProvider := &ProviderMock{ +// MeterProviderFunc: func() metric.MeterProvider { +// panic("mock out the MeterProvider method") +// }, +// NewFloat64CounterFunc: func(name string, options ...metric.Float64CounterOption) (metrics.Float64Counter, error) { +// panic("mock out the NewFloat64Counter method") +// }, +// NewFloat64GaugeFunc: func(name string, options ...metric.Float64GaugeOption) (metrics.Float64Gauge, error) { +// panic("mock out the NewFloat64Gauge method") +// }, +// NewFloat64HistogramFunc: func(name string, options ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { +// panic("mock out the NewFloat64Histogram method") +// }, +// NewFloat64UpDownCounterFunc: func(name string, options ...metric.Float64UpDownCounterOption) (metrics.Float64UpDownCounter, error) { +// panic("mock out the NewFloat64UpDownCounter method") +// }, +// NewInt64CounterFunc: func(name string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) { +// panic("mock out the NewInt64Counter method") +// }, +// NewInt64GaugeFunc: func(name string, options ...metric.Int64GaugeOption) (metrics.Int64Gauge, error) { +// panic("mock out the NewInt64Gauge method") +// }, +// NewInt64HistogramFunc: func(name string, options ...metric.Int64HistogramOption) (metrics.Int64Histogram, error) { +// panic("mock out the NewInt64Histogram method") +// }, +// NewInt64UpDownCounterFunc: func(name string, options ...metric.Int64UpDownCounterOption) (metrics.Int64UpDownCounter, error) { +// panic("mock out the NewInt64UpDownCounter method") +// }, +// ShutdownFunc: func(ctx context.Context) error { +// panic("mock out the Shutdown method") +// }, +// } +// +// // use mockedProvider in code that requires metrics.Provider +// // and then make assertions. +// +// } +type ProviderMock struct { + // MeterProviderFunc mocks the MeterProvider method. + MeterProviderFunc func() metric.MeterProvider + + // NewFloat64CounterFunc mocks the NewFloat64Counter method. + NewFloat64CounterFunc func(name string, options ...metric.Float64CounterOption) (metrics.Float64Counter, error) + + // NewFloat64GaugeFunc mocks the NewFloat64Gauge method. + NewFloat64GaugeFunc func(name string, options ...metric.Float64GaugeOption) (metrics.Float64Gauge, error) + + // NewFloat64HistogramFunc mocks the NewFloat64Histogram method. + NewFloat64HistogramFunc func(name string, options ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) + + // NewFloat64UpDownCounterFunc mocks the NewFloat64UpDownCounter method. + NewFloat64UpDownCounterFunc func(name string, options ...metric.Float64UpDownCounterOption) (metrics.Float64UpDownCounter, error) + + // NewInt64CounterFunc mocks the NewInt64Counter method. + NewInt64CounterFunc func(name string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) + + // NewInt64GaugeFunc mocks the NewInt64Gauge method. + NewInt64GaugeFunc func(name string, options ...metric.Int64GaugeOption) (metrics.Int64Gauge, error) + + // NewInt64HistogramFunc mocks the NewInt64Histogram method. + NewInt64HistogramFunc func(name string, options ...metric.Int64HistogramOption) (metrics.Int64Histogram, error) + + // NewInt64UpDownCounterFunc mocks the NewInt64UpDownCounter method. + NewInt64UpDownCounterFunc func(name string, options ...metric.Int64UpDownCounterOption) (metrics.Int64UpDownCounter, error) + + // ShutdownFunc mocks the Shutdown method. + ShutdownFunc func(ctx context.Context) error + + // calls tracks calls to the methods. + calls struct { + // MeterProvider holds details about calls to the MeterProvider method. + MeterProvider []struct { + } + // NewFloat64Counter holds details about calls to the NewFloat64Counter method. + NewFloat64Counter []struct { + // Name is the name argument value. + Name string + // Options is the options argument value. + Options []metric.Float64CounterOption + } + // NewFloat64Gauge holds details about calls to the NewFloat64Gauge method. + NewFloat64Gauge []struct { + // Name is the name argument value. + Name string + // Options is the options argument value. + Options []metric.Float64GaugeOption + } + // NewFloat64Histogram holds details about calls to the NewFloat64Histogram method. + NewFloat64Histogram []struct { + // Name is the name argument value. + Name string + // Options is the options argument value. + Options []metric.Float64HistogramOption + } + // NewFloat64UpDownCounter holds details about calls to the NewFloat64UpDownCounter method. + NewFloat64UpDownCounter []struct { + // Name is the name argument value. + Name string + // Options is the options argument value. + Options []metric.Float64UpDownCounterOption + } + // NewInt64Counter holds details about calls to the NewInt64Counter method. + NewInt64Counter []struct { + // Name is the name argument value. + Name string + // Options is the options argument value. + Options []metric.Int64CounterOption + } + // NewInt64Gauge holds details about calls to the NewInt64Gauge method. + NewInt64Gauge []struct { + // Name is the name argument value. + Name string + // Options is the options argument value. + Options []metric.Int64GaugeOption + } + // NewInt64Histogram holds details about calls to the NewInt64Histogram method. + NewInt64Histogram []struct { + // Name is the name argument value. + Name string + // Options is the options argument value. + Options []metric.Int64HistogramOption + } + // NewInt64UpDownCounter holds details about calls to the NewInt64UpDownCounter method. + NewInt64UpDownCounter []struct { + // Name is the name argument value. + Name string + // Options is the options argument value. + Options []metric.Int64UpDownCounterOption + } + // Shutdown holds details about calls to the Shutdown method. + Shutdown []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + } + lockMeterProvider sync.RWMutex + lockNewFloat64Counter sync.RWMutex + lockNewFloat64Gauge sync.RWMutex + lockNewFloat64Histogram sync.RWMutex + lockNewFloat64UpDownCounter sync.RWMutex + lockNewInt64Counter sync.RWMutex + lockNewInt64Gauge sync.RWMutex + lockNewInt64Histogram sync.RWMutex + lockNewInt64UpDownCounter sync.RWMutex + lockShutdown sync.RWMutex +} + +// MeterProvider calls MeterProviderFunc. +func (mock *ProviderMock) MeterProvider() metric.MeterProvider { + if mock.MeterProviderFunc == nil { + panic("ProviderMock.MeterProviderFunc: method is nil but Provider.MeterProvider was just called") + } + callInfo := struct { + }{} + mock.lockMeterProvider.Lock() + mock.calls.MeterProvider = append(mock.calls.MeterProvider, callInfo) + mock.lockMeterProvider.Unlock() + return mock.MeterProviderFunc() +} + +// MeterProviderCalls gets all the calls that were made to MeterProvider. +// Check the length with: +// +// len(mockedProvider.MeterProviderCalls()) +func (mock *ProviderMock) MeterProviderCalls() []struct { +} { + var calls []struct { + } + mock.lockMeterProvider.RLock() + calls = mock.calls.MeterProvider + mock.lockMeterProvider.RUnlock() + return calls +} + +// NewFloat64Counter calls NewFloat64CounterFunc. +func (mock *ProviderMock) NewFloat64Counter(name string, options ...metric.Float64CounterOption) (metrics.Float64Counter, error) { + if mock.NewFloat64CounterFunc == nil { + panic("ProviderMock.NewFloat64CounterFunc: method is nil but Provider.NewFloat64Counter was just called") + } + callInfo := struct { + Name string + Options []metric.Float64CounterOption + }{ + Name: name, + Options: options, + } + mock.lockNewFloat64Counter.Lock() + mock.calls.NewFloat64Counter = append(mock.calls.NewFloat64Counter, callInfo) + mock.lockNewFloat64Counter.Unlock() + return mock.NewFloat64CounterFunc(name, options...) +} + +// NewFloat64CounterCalls gets all the calls that were made to NewFloat64Counter. +// Check the length with: +// +// len(mockedProvider.NewFloat64CounterCalls()) +func (mock *ProviderMock) NewFloat64CounterCalls() []struct { + Name string + Options []metric.Float64CounterOption +} { + var calls []struct { + Name string + Options []metric.Float64CounterOption + } + mock.lockNewFloat64Counter.RLock() + calls = mock.calls.NewFloat64Counter + mock.lockNewFloat64Counter.RUnlock() + return calls +} + +// NewFloat64Gauge calls NewFloat64GaugeFunc. +func (mock *ProviderMock) NewFloat64Gauge(name string, options ...metric.Float64GaugeOption) (metrics.Float64Gauge, error) { + if mock.NewFloat64GaugeFunc == nil { + panic("ProviderMock.NewFloat64GaugeFunc: method is nil but Provider.NewFloat64Gauge was just called") + } + callInfo := struct { + Name string + Options []metric.Float64GaugeOption + }{ + Name: name, + Options: options, + } + mock.lockNewFloat64Gauge.Lock() + mock.calls.NewFloat64Gauge = append(mock.calls.NewFloat64Gauge, callInfo) + mock.lockNewFloat64Gauge.Unlock() + return mock.NewFloat64GaugeFunc(name, options...) +} + +// NewFloat64GaugeCalls gets all the calls that were made to NewFloat64Gauge. +// Check the length with: +// +// len(mockedProvider.NewFloat64GaugeCalls()) +func (mock *ProviderMock) NewFloat64GaugeCalls() []struct { + Name string + Options []metric.Float64GaugeOption +} { + var calls []struct { + Name string + Options []metric.Float64GaugeOption + } + mock.lockNewFloat64Gauge.RLock() + calls = mock.calls.NewFloat64Gauge + mock.lockNewFloat64Gauge.RUnlock() + return calls +} + +// NewFloat64Histogram calls NewFloat64HistogramFunc. +func (mock *ProviderMock) NewFloat64Histogram(name string, options ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { + if mock.NewFloat64HistogramFunc == nil { + panic("ProviderMock.NewFloat64HistogramFunc: method is nil but Provider.NewFloat64Histogram was just called") + } + callInfo := struct { + Name string + Options []metric.Float64HistogramOption + }{ + Name: name, + Options: options, + } + mock.lockNewFloat64Histogram.Lock() + mock.calls.NewFloat64Histogram = append(mock.calls.NewFloat64Histogram, callInfo) + mock.lockNewFloat64Histogram.Unlock() + return mock.NewFloat64HistogramFunc(name, options...) +} + +// NewFloat64HistogramCalls gets all the calls that were made to NewFloat64Histogram. +// Check the length with: +// +// len(mockedProvider.NewFloat64HistogramCalls()) +func (mock *ProviderMock) NewFloat64HistogramCalls() []struct { + Name string + Options []metric.Float64HistogramOption +} { + var calls []struct { + Name string + Options []metric.Float64HistogramOption + } + mock.lockNewFloat64Histogram.RLock() + calls = mock.calls.NewFloat64Histogram + mock.lockNewFloat64Histogram.RUnlock() + return calls +} + +// NewFloat64UpDownCounter calls NewFloat64UpDownCounterFunc. +func (mock *ProviderMock) NewFloat64UpDownCounter(name string, options ...metric.Float64UpDownCounterOption) (metrics.Float64UpDownCounter, error) { + if mock.NewFloat64UpDownCounterFunc == nil { + panic("ProviderMock.NewFloat64UpDownCounterFunc: method is nil but Provider.NewFloat64UpDownCounter was just called") + } + callInfo := struct { + Name string + Options []metric.Float64UpDownCounterOption + }{ + Name: name, + Options: options, + } + mock.lockNewFloat64UpDownCounter.Lock() + mock.calls.NewFloat64UpDownCounter = append(mock.calls.NewFloat64UpDownCounter, callInfo) + mock.lockNewFloat64UpDownCounter.Unlock() + return mock.NewFloat64UpDownCounterFunc(name, options...) +} + +// NewFloat64UpDownCounterCalls gets all the calls that were made to NewFloat64UpDownCounter. +// Check the length with: +// +// len(mockedProvider.NewFloat64UpDownCounterCalls()) +func (mock *ProviderMock) NewFloat64UpDownCounterCalls() []struct { + Name string + Options []metric.Float64UpDownCounterOption +} { + var calls []struct { + Name string + Options []metric.Float64UpDownCounterOption + } + mock.lockNewFloat64UpDownCounter.RLock() + calls = mock.calls.NewFloat64UpDownCounter + mock.lockNewFloat64UpDownCounter.RUnlock() + return calls +} + +// NewInt64Counter calls NewInt64CounterFunc. +func (mock *ProviderMock) NewInt64Counter(name string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + if mock.NewInt64CounterFunc == nil { + panic("ProviderMock.NewInt64CounterFunc: method is nil but Provider.NewInt64Counter was just called") + } + callInfo := struct { + Name string + Options []metric.Int64CounterOption + }{ + Name: name, + Options: options, + } + mock.lockNewInt64Counter.Lock() + mock.calls.NewInt64Counter = append(mock.calls.NewInt64Counter, callInfo) + mock.lockNewInt64Counter.Unlock() + return mock.NewInt64CounterFunc(name, options...) +} + +// NewInt64CounterCalls gets all the calls that were made to NewInt64Counter. +// Check the length with: +// +// len(mockedProvider.NewInt64CounterCalls()) +func (mock *ProviderMock) NewInt64CounterCalls() []struct { + Name string + Options []metric.Int64CounterOption +} { + var calls []struct { + Name string + Options []metric.Int64CounterOption + } + mock.lockNewInt64Counter.RLock() + calls = mock.calls.NewInt64Counter + mock.lockNewInt64Counter.RUnlock() + return calls +} + +// NewInt64Gauge calls NewInt64GaugeFunc. +func (mock *ProviderMock) NewInt64Gauge(name string, options ...metric.Int64GaugeOption) (metrics.Int64Gauge, error) { + if mock.NewInt64GaugeFunc == nil { + panic("ProviderMock.NewInt64GaugeFunc: method is nil but Provider.NewInt64Gauge was just called") + } + callInfo := struct { + Name string + Options []metric.Int64GaugeOption + }{ + Name: name, + Options: options, + } + mock.lockNewInt64Gauge.Lock() + mock.calls.NewInt64Gauge = append(mock.calls.NewInt64Gauge, callInfo) + mock.lockNewInt64Gauge.Unlock() + return mock.NewInt64GaugeFunc(name, options...) +} + +// NewInt64GaugeCalls gets all the calls that were made to NewInt64Gauge. +// Check the length with: +// +// len(mockedProvider.NewInt64GaugeCalls()) +func (mock *ProviderMock) NewInt64GaugeCalls() []struct { + Name string + Options []metric.Int64GaugeOption +} { + var calls []struct { + Name string + Options []metric.Int64GaugeOption + } + mock.lockNewInt64Gauge.RLock() + calls = mock.calls.NewInt64Gauge + mock.lockNewInt64Gauge.RUnlock() + return calls +} + +// NewInt64Histogram calls NewInt64HistogramFunc. +func (mock *ProviderMock) NewInt64Histogram(name string, options ...metric.Int64HistogramOption) (metrics.Int64Histogram, error) { + if mock.NewInt64HistogramFunc == nil { + panic("ProviderMock.NewInt64HistogramFunc: method is nil but Provider.NewInt64Histogram was just called") + } + callInfo := struct { + Name string + Options []metric.Int64HistogramOption + }{ + Name: name, + Options: options, + } + mock.lockNewInt64Histogram.Lock() + mock.calls.NewInt64Histogram = append(mock.calls.NewInt64Histogram, callInfo) + mock.lockNewInt64Histogram.Unlock() + return mock.NewInt64HistogramFunc(name, options...) +} + +// NewInt64HistogramCalls gets all the calls that were made to NewInt64Histogram. +// Check the length with: +// +// len(mockedProvider.NewInt64HistogramCalls()) +func (mock *ProviderMock) NewInt64HistogramCalls() []struct { + Name string + Options []metric.Int64HistogramOption +} { + var calls []struct { + Name string + Options []metric.Int64HistogramOption + } + mock.lockNewInt64Histogram.RLock() + calls = mock.calls.NewInt64Histogram + mock.lockNewInt64Histogram.RUnlock() + return calls +} + +// NewInt64UpDownCounter calls NewInt64UpDownCounterFunc. +func (mock *ProviderMock) NewInt64UpDownCounter(name string, options ...metric.Int64UpDownCounterOption) (metrics.Int64UpDownCounter, error) { + if mock.NewInt64UpDownCounterFunc == nil { + panic("ProviderMock.NewInt64UpDownCounterFunc: method is nil but Provider.NewInt64UpDownCounter was just called") + } + callInfo := struct { + Name string + Options []metric.Int64UpDownCounterOption + }{ + Name: name, + Options: options, + } + mock.lockNewInt64UpDownCounter.Lock() + mock.calls.NewInt64UpDownCounter = append(mock.calls.NewInt64UpDownCounter, callInfo) + mock.lockNewInt64UpDownCounter.Unlock() + return mock.NewInt64UpDownCounterFunc(name, options...) +} + +// NewInt64UpDownCounterCalls gets all the calls that were made to NewInt64UpDownCounter. +// Check the length with: +// +// len(mockedProvider.NewInt64UpDownCounterCalls()) +func (mock *ProviderMock) NewInt64UpDownCounterCalls() []struct { + Name string + Options []metric.Int64UpDownCounterOption +} { + var calls []struct { + Name string + Options []metric.Int64UpDownCounterOption + } + mock.lockNewInt64UpDownCounter.RLock() + calls = mock.calls.NewInt64UpDownCounter + mock.lockNewInt64UpDownCounter.RUnlock() + return calls +} + +// Shutdown calls ShutdownFunc. +func (mock *ProviderMock) Shutdown(ctx context.Context) error { + if mock.ShutdownFunc == nil { + panic("ProviderMock.ShutdownFunc: method is nil but Provider.Shutdown was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockShutdown.Lock() + mock.calls.Shutdown = append(mock.calls.Shutdown, callInfo) + mock.lockShutdown.Unlock() + return mock.ShutdownFunc(ctx) +} + +// ShutdownCalls gets all the calls that were made to Shutdown. +// Check the length with: +// +// len(mockedProvider.ShutdownCalls()) +func (mock *ProviderMock) ShutdownCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockShutdown.RLock() + calls = mock.calls.Shutdown + mock.lockShutdown.RUnlock() + return calls +} diff --git a/cache/config/mocks_gen_test.go b/cache/config/mocks_gen_test.go new file mode 100644 index 0000000..7ad25d3 --- /dev/null +++ b/cache/config/mocks_gen_test.go @@ -0,0 +1,5 @@ +package config + +// Regenerate mocks via `go generate ./cache/config/...`. + +//go:generate go tool github.com/matryer/moq -out metricsprovider_mock_test.go -pkg config -rm -fmt goimports ../../observability/metrics Provider diff --git a/cache/memory/memory_test.go b/cache/memory/memory_test.go index 6cababa..96c8209 100644 --- a/cache/memory/memory_test.go +++ b/cache/memory/memory_test.go @@ -3,8 +3,8 @@ package memory import ( "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) const ( @@ -22,8 +22,8 @@ func Test_newInMemoryCache(T *testing.T) { t.Parallel() actual, err := NewInMemoryCache[example](nil, nil, nil) - require.NoError(t, err) - assert.NotNil(t, actual) + must.NoError(t, err) + test.NotNil(t, actual) }) } @@ -35,14 +35,14 @@ func Test_inMemoryCacheImpl_Get(T *testing.T) { ctx := t.Context() c, err := NewInMemoryCache[example](nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) expected := &example{Name: t.Name()} - assert.NoError(t, c.Set(ctx, exampleKey, expected)) + test.NoError(t, c.Set(ctx, exampleKey, expected)) actual, err := c.Get(ctx, exampleKey) - assert.Equal(t, expected, actual) - assert.NoError(t, err) + test.Eq(t, expected, actual) + test.NoError(t, err) }) } @@ -54,11 +54,11 @@ func Test_inMemoryCacheImpl_Set(T *testing.T) { ctx := t.Context() c, err := NewInMemoryCache[example](nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) - assert.Len(t, c.(*inMemoryCacheImpl[example]).cache, 0) - assert.NoError(t, c.Set(ctx, exampleKey, &example{Name: t.Name()})) - assert.Len(t, c.(*inMemoryCacheImpl[example]).cache, 1) + test.MapLen(t, 0, c.(*inMemoryCacheImpl[example]).cache) + test.NoError(t, c.Set(ctx, exampleKey, &example{Name: t.Name()})) + test.MapLen(t, 1, c.(*inMemoryCacheImpl[example]).cache) }) } @@ -70,13 +70,13 @@ func Test_inMemoryCacheImpl_Delete(T *testing.T) { ctx := t.Context() c, err := NewInMemoryCache[example](nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) - assert.Len(t, c.(*inMemoryCacheImpl[example]).cache, 0) - assert.NoError(t, c.Set(ctx, exampleKey, &example{Name: t.Name()})) - assert.Len(t, c.(*inMemoryCacheImpl[example]).cache, 1) - assert.NoError(t, c.Delete(ctx, exampleKey)) - assert.Len(t, c.(*inMemoryCacheImpl[example]).cache, 0) + test.MapLen(t, 0, c.(*inMemoryCacheImpl[example]).cache) + test.NoError(t, c.Set(ctx, exampleKey, &example{Name: t.Name()})) + test.MapLen(t, 1, c.(*inMemoryCacheImpl[example]).cache) + test.NoError(t, c.Delete(ctx, exampleKey)) + test.MapLen(t, 0, c.(*inMemoryCacheImpl[example]).cache) }) } @@ -87,7 +87,7 @@ func Test_inMemoryCacheImpl_Ping(T *testing.T) { t.Parallel() c, err := NewInMemoryCache[example](nil, nil, nil) - require.NoError(t, err) - assert.NoError(t, c.Ping(t.Context())) + must.NoError(t, err) + test.NoError(t, c.Ping(t.Context())) }) } diff --git a/cache/redis/circuitbreaker_mock_test.go b/cache/redis/circuitbreaker_mock_test.go new file mode 100644 index 0000000..8b474ee --- /dev/null +++ b/cache/redis/circuitbreaker_mock_test.go @@ -0,0 +1,180 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package redis + +import ( + "sync" + + "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking" +) + +// Ensure, that CircuitBreakerMock does implement circuitbreaking.CircuitBreaker. +// If this is not the case, regenerate this file with moq. +var _ circuitbreaking.CircuitBreaker = &CircuitBreakerMock{} + +// CircuitBreakerMock is a mock implementation of circuitbreaking.CircuitBreaker. +// +// func TestSomethingThatUsesCircuitBreaker(t *testing.T) { +// +// // make and configure a mocked circuitbreaking.CircuitBreaker +// mockedCircuitBreaker := &CircuitBreakerMock{ +// CanProceedFunc: func() bool { +// panic("mock out the CanProceed method") +// }, +// CannotProceedFunc: func() bool { +// panic("mock out the CannotProceed method") +// }, +// FailedFunc: func() { +// panic("mock out the Failed method") +// }, +// SucceededFunc: func() { +// panic("mock out the Succeeded method") +// }, +// } +// +// // use mockedCircuitBreaker in code that requires circuitbreaking.CircuitBreaker +// // and then make assertions. +// +// } +type CircuitBreakerMock struct { + // CanProceedFunc mocks the CanProceed method. + CanProceedFunc func() bool + + // CannotProceedFunc mocks the CannotProceed method. + CannotProceedFunc func() bool + + // FailedFunc mocks the Failed method. + FailedFunc func() + + // SucceededFunc mocks the Succeeded method. + SucceededFunc func() + + // calls tracks calls to the methods. + calls struct { + // CanProceed holds details about calls to the CanProceed method. + CanProceed []struct { + } + // CannotProceed holds details about calls to the CannotProceed method. + CannotProceed []struct { + } + // Failed holds details about calls to the Failed method. + Failed []struct { + } + // Succeeded holds details about calls to the Succeeded method. + Succeeded []struct { + } + } + lockCanProceed sync.RWMutex + lockCannotProceed sync.RWMutex + lockFailed sync.RWMutex + lockSucceeded sync.RWMutex +} + +// CanProceed calls CanProceedFunc. +func (mock *CircuitBreakerMock) CanProceed() bool { + if mock.CanProceedFunc == nil { + panic("CircuitBreakerMock.CanProceedFunc: method is nil but CircuitBreaker.CanProceed was just called") + } + callInfo := struct { + }{} + mock.lockCanProceed.Lock() + mock.calls.CanProceed = append(mock.calls.CanProceed, callInfo) + mock.lockCanProceed.Unlock() + return mock.CanProceedFunc() +} + +// CanProceedCalls gets all the calls that were made to CanProceed. +// Check the length with: +// +// len(mockedCircuitBreaker.CanProceedCalls()) +func (mock *CircuitBreakerMock) CanProceedCalls() []struct { +} { + var calls []struct { + } + mock.lockCanProceed.RLock() + calls = mock.calls.CanProceed + mock.lockCanProceed.RUnlock() + return calls +} + +// CannotProceed calls CannotProceedFunc. +func (mock *CircuitBreakerMock) CannotProceed() bool { + if mock.CannotProceedFunc == nil { + panic("CircuitBreakerMock.CannotProceedFunc: method is nil but CircuitBreaker.CannotProceed was just called") + } + callInfo := struct { + }{} + mock.lockCannotProceed.Lock() + mock.calls.CannotProceed = append(mock.calls.CannotProceed, callInfo) + mock.lockCannotProceed.Unlock() + return mock.CannotProceedFunc() +} + +// CannotProceedCalls gets all the calls that were made to CannotProceed. +// Check the length with: +// +// len(mockedCircuitBreaker.CannotProceedCalls()) +func (mock *CircuitBreakerMock) CannotProceedCalls() []struct { +} { + var calls []struct { + } + mock.lockCannotProceed.RLock() + calls = mock.calls.CannotProceed + mock.lockCannotProceed.RUnlock() + return calls +} + +// Failed calls FailedFunc. +func (mock *CircuitBreakerMock) Failed() { + if mock.FailedFunc == nil { + panic("CircuitBreakerMock.FailedFunc: method is nil but CircuitBreaker.Failed was just called") + } + callInfo := struct { + }{} + mock.lockFailed.Lock() + mock.calls.Failed = append(mock.calls.Failed, callInfo) + mock.lockFailed.Unlock() + mock.FailedFunc() +} + +// FailedCalls gets all the calls that were made to Failed. +// Check the length with: +// +// len(mockedCircuitBreaker.FailedCalls()) +func (mock *CircuitBreakerMock) FailedCalls() []struct { +} { + var calls []struct { + } + mock.lockFailed.RLock() + calls = mock.calls.Failed + mock.lockFailed.RUnlock() + return calls +} + +// Succeeded calls SucceededFunc. +func (mock *CircuitBreakerMock) Succeeded() { + if mock.SucceededFunc == nil { + panic("CircuitBreakerMock.SucceededFunc: method is nil but CircuitBreaker.Succeeded was just called") + } + callInfo := struct { + }{} + mock.lockSucceeded.Lock() + mock.calls.Succeeded = append(mock.calls.Succeeded, callInfo) + mock.lockSucceeded.Unlock() + mock.SucceededFunc() +} + +// SucceededCalls gets all the calls that were made to Succeeded. +// Check the length with: +// +// len(mockedCircuitBreaker.SucceededCalls()) +func (mock *CircuitBreakerMock) SucceededCalls() []struct { +} { + var calls []struct { + } + mock.lockSucceeded.RLock() + calls = mock.calls.Succeeded + mock.lockSucceeded.RUnlock() + return calls +} diff --git a/cache/redis/metricsprovider_mock_test.go b/cache/redis/metricsprovider_mock_test.go new file mode 100644 index 0000000..247d6e8 --- /dev/null +++ b/cache/redis/metricsprovider_mock_test.go @@ -0,0 +1,515 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package redis + +import ( + "context" + "sync" + + "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" + "go.opentelemetry.io/otel/metric" +) + +// Ensure, that ProviderMock does implement metrics.Provider. +// If this is not the case, regenerate this file with moq. +var _ metrics.Provider = &ProviderMock{} + +// ProviderMock is a mock implementation of metrics.Provider. +// +// func TestSomethingThatUsesProvider(t *testing.T) { +// +// // make and configure a mocked metrics.Provider +// mockedProvider := &ProviderMock{ +// MeterProviderFunc: func() metric.MeterProvider { +// panic("mock out the MeterProvider method") +// }, +// NewFloat64CounterFunc: func(name string, options ...metric.Float64CounterOption) (metrics.Float64Counter, error) { +// panic("mock out the NewFloat64Counter method") +// }, +// NewFloat64GaugeFunc: func(name string, options ...metric.Float64GaugeOption) (metrics.Float64Gauge, error) { +// panic("mock out the NewFloat64Gauge method") +// }, +// NewFloat64HistogramFunc: func(name string, options ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { +// panic("mock out the NewFloat64Histogram method") +// }, +// NewFloat64UpDownCounterFunc: func(name string, options ...metric.Float64UpDownCounterOption) (metrics.Float64UpDownCounter, error) { +// panic("mock out the NewFloat64UpDownCounter method") +// }, +// NewInt64CounterFunc: func(name string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) { +// panic("mock out the NewInt64Counter method") +// }, +// NewInt64GaugeFunc: func(name string, options ...metric.Int64GaugeOption) (metrics.Int64Gauge, error) { +// panic("mock out the NewInt64Gauge method") +// }, +// NewInt64HistogramFunc: func(name string, options ...metric.Int64HistogramOption) (metrics.Int64Histogram, error) { +// panic("mock out the NewInt64Histogram method") +// }, +// NewInt64UpDownCounterFunc: func(name string, options ...metric.Int64UpDownCounterOption) (metrics.Int64UpDownCounter, error) { +// panic("mock out the NewInt64UpDownCounter method") +// }, +// ShutdownFunc: func(ctx context.Context) error { +// panic("mock out the Shutdown method") +// }, +// } +// +// // use mockedProvider in code that requires metrics.Provider +// // and then make assertions. +// +// } +type ProviderMock struct { + // MeterProviderFunc mocks the MeterProvider method. + MeterProviderFunc func() metric.MeterProvider + + // NewFloat64CounterFunc mocks the NewFloat64Counter method. + NewFloat64CounterFunc func(name string, options ...metric.Float64CounterOption) (metrics.Float64Counter, error) + + // NewFloat64GaugeFunc mocks the NewFloat64Gauge method. + NewFloat64GaugeFunc func(name string, options ...metric.Float64GaugeOption) (metrics.Float64Gauge, error) + + // NewFloat64HistogramFunc mocks the NewFloat64Histogram method. + NewFloat64HistogramFunc func(name string, options ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) + + // NewFloat64UpDownCounterFunc mocks the NewFloat64UpDownCounter method. + NewFloat64UpDownCounterFunc func(name string, options ...metric.Float64UpDownCounterOption) (metrics.Float64UpDownCounter, error) + + // NewInt64CounterFunc mocks the NewInt64Counter method. + NewInt64CounterFunc func(name string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) + + // NewInt64GaugeFunc mocks the NewInt64Gauge method. + NewInt64GaugeFunc func(name string, options ...metric.Int64GaugeOption) (metrics.Int64Gauge, error) + + // NewInt64HistogramFunc mocks the NewInt64Histogram method. + NewInt64HistogramFunc func(name string, options ...metric.Int64HistogramOption) (metrics.Int64Histogram, error) + + // NewInt64UpDownCounterFunc mocks the NewInt64UpDownCounter method. + NewInt64UpDownCounterFunc func(name string, options ...metric.Int64UpDownCounterOption) (metrics.Int64UpDownCounter, error) + + // ShutdownFunc mocks the Shutdown method. + ShutdownFunc func(ctx context.Context) error + + // calls tracks calls to the methods. + calls struct { + // MeterProvider holds details about calls to the MeterProvider method. + MeterProvider []struct { + } + // NewFloat64Counter holds details about calls to the NewFloat64Counter method. + NewFloat64Counter []struct { + // Name is the name argument value. + Name string + // Options is the options argument value. + Options []metric.Float64CounterOption + } + // NewFloat64Gauge holds details about calls to the NewFloat64Gauge method. + NewFloat64Gauge []struct { + // Name is the name argument value. + Name string + // Options is the options argument value. + Options []metric.Float64GaugeOption + } + // NewFloat64Histogram holds details about calls to the NewFloat64Histogram method. + NewFloat64Histogram []struct { + // Name is the name argument value. + Name string + // Options is the options argument value. + Options []metric.Float64HistogramOption + } + // NewFloat64UpDownCounter holds details about calls to the NewFloat64UpDownCounter method. + NewFloat64UpDownCounter []struct { + // Name is the name argument value. + Name string + // Options is the options argument value. + Options []metric.Float64UpDownCounterOption + } + // NewInt64Counter holds details about calls to the NewInt64Counter method. + NewInt64Counter []struct { + // Name is the name argument value. + Name string + // Options is the options argument value. + Options []metric.Int64CounterOption + } + // NewInt64Gauge holds details about calls to the NewInt64Gauge method. + NewInt64Gauge []struct { + // Name is the name argument value. + Name string + // Options is the options argument value. + Options []metric.Int64GaugeOption + } + // NewInt64Histogram holds details about calls to the NewInt64Histogram method. + NewInt64Histogram []struct { + // Name is the name argument value. + Name string + // Options is the options argument value. + Options []metric.Int64HistogramOption + } + // NewInt64UpDownCounter holds details about calls to the NewInt64UpDownCounter method. + NewInt64UpDownCounter []struct { + // Name is the name argument value. + Name string + // Options is the options argument value. + Options []metric.Int64UpDownCounterOption + } + // Shutdown holds details about calls to the Shutdown method. + Shutdown []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + } + lockMeterProvider sync.RWMutex + lockNewFloat64Counter sync.RWMutex + lockNewFloat64Gauge sync.RWMutex + lockNewFloat64Histogram sync.RWMutex + lockNewFloat64UpDownCounter sync.RWMutex + lockNewInt64Counter sync.RWMutex + lockNewInt64Gauge sync.RWMutex + lockNewInt64Histogram sync.RWMutex + lockNewInt64UpDownCounter sync.RWMutex + lockShutdown sync.RWMutex +} + +// MeterProvider calls MeterProviderFunc. +func (mock *ProviderMock) MeterProvider() metric.MeterProvider { + if mock.MeterProviderFunc == nil { + panic("ProviderMock.MeterProviderFunc: method is nil but Provider.MeterProvider was just called") + } + callInfo := struct { + }{} + mock.lockMeterProvider.Lock() + mock.calls.MeterProvider = append(mock.calls.MeterProvider, callInfo) + mock.lockMeterProvider.Unlock() + return mock.MeterProviderFunc() +} + +// MeterProviderCalls gets all the calls that were made to MeterProvider. +// Check the length with: +// +// len(mockedProvider.MeterProviderCalls()) +func (mock *ProviderMock) MeterProviderCalls() []struct { +} { + var calls []struct { + } + mock.lockMeterProvider.RLock() + calls = mock.calls.MeterProvider + mock.lockMeterProvider.RUnlock() + return calls +} + +// NewFloat64Counter calls NewFloat64CounterFunc. +func (mock *ProviderMock) NewFloat64Counter(name string, options ...metric.Float64CounterOption) (metrics.Float64Counter, error) { + if mock.NewFloat64CounterFunc == nil { + panic("ProviderMock.NewFloat64CounterFunc: method is nil but Provider.NewFloat64Counter was just called") + } + callInfo := struct { + Name string + Options []metric.Float64CounterOption + }{ + Name: name, + Options: options, + } + mock.lockNewFloat64Counter.Lock() + mock.calls.NewFloat64Counter = append(mock.calls.NewFloat64Counter, callInfo) + mock.lockNewFloat64Counter.Unlock() + return mock.NewFloat64CounterFunc(name, options...) +} + +// NewFloat64CounterCalls gets all the calls that were made to NewFloat64Counter. +// Check the length with: +// +// len(mockedProvider.NewFloat64CounterCalls()) +func (mock *ProviderMock) NewFloat64CounterCalls() []struct { + Name string + Options []metric.Float64CounterOption +} { + var calls []struct { + Name string + Options []metric.Float64CounterOption + } + mock.lockNewFloat64Counter.RLock() + calls = mock.calls.NewFloat64Counter + mock.lockNewFloat64Counter.RUnlock() + return calls +} + +// NewFloat64Gauge calls NewFloat64GaugeFunc. +func (mock *ProviderMock) NewFloat64Gauge(name string, options ...metric.Float64GaugeOption) (metrics.Float64Gauge, error) { + if mock.NewFloat64GaugeFunc == nil { + panic("ProviderMock.NewFloat64GaugeFunc: method is nil but Provider.NewFloat64Gauge was just called") + } + callInfo := struct { + Name string + Options []metric.Float64GaugeOption + }{ + Name: name, + Options: options, + } + mock.lockNewFloat64Gauge.Lock() + mock.calls.NewFloat64Gauge = append(mock.calls.NewFloat64Gauge, callInfo) + mock.lockNewFloat64Gauge.Unlock() + return mock.NewFloat64GaugeFunc(name, options...) +} + +// NewFloat64GaugeCalls gets all the calls that were made to NewFloat64Gauge. +// Check the length with: +// +// len(mockedProvider.NewFloat64GaugeCalls()) +func (mock *ProviderMock) NewFloat64GaugeCalls() []struct { + Name string + Options []metric.Float64GaugeOption +} { + var calls []struct { + Name string + Options []metric.Float64GaugeOption + } + mock.lockNewFloat64Gauge.RLock() + calls = mock.calls.NewFloat64Gauge + mock.lockNewFloat64Gauge.RUnlock() + return calls +} + +// NewFloat64Histogram calls NewFloat64HistogramFunc. +func (mock *ProviderMock) NewFloat64Histogram(name string, options ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { + if mock.NewFloat64HistogramFunc == nil { + panic("ProviderMock.NewFloat64HistogramFunc: method is nil but Provider.NewFloat64Histogram was just called") + } + callInfo := struct { + Name string + Options []metric.Float64HistogramOption + }{ + Name: name, + Options: options, + } + mock.lockNewFloat64Histogram.Lock() + mock.calls.NewFloat64Histogram = append(mock.calls.NewFloat64Histogram, callInfo) + mock.lockNewFloat64Histogram.Unlock() + return mock.NewFloat64HistogramFunc(name, options...) +} + +// NewFloat64HistogramCalls gets all the calls that were made to NewFloat64Histogram. +// Check the length with: +// +// len(mockedProvider.NewFloat64HistogramCalls()) +func (mock *ProviderMock) NewFloat64HistogramCalls() []struct { + Name string + Options []metric.Float64HistogramOption +} { + var calls []struct { + Name string + Options []metric.Float64HistogramOption + } + mock.lockNewFloat64Histogram.RLock() + calls = mock.calls.NewFloat64Histogram + mock.lockNewFloat64Histogram.RUnlock() + return calls +} + +// NewFloat64UpDownCounter calls NewFloat64UpDownCounterFunc. +func (mock *ProviderMock) NewFloat64UpDownCounter(name string, options ...metric.Float64UpDownCounterOption) (metrics.Float64UpDownCounter, error) { + if mock.NewFloat64UpDownCounterFunc == nil { + panic("ProviderMock.NewFloat64UpDownCounterFunc: method is nil but Provider.NewFloat64UpDownCounter was just called") + } + callInfo := struct { + Name string + Options []metric.Float64UpDownCounterOption + }{ + Name: name, + Options: options, + } + mock.lockNewFloat64UpDownCounter.Lock() + mock.calls.NewFloat64UpDownCounter = append(mock.calls.NewFloat64UpDownCounter, callInfo) + mock.lockNewFloat64UpDownCounter.Unlock() + return mock.NewFloat64UpDownCounterFunc(name, options...) +} + +// NewFloat64UpDownCounterCalls gets all the calls that were made to NewFloat64UpDownCounter. +// Check the length with: +// +// len(mockedProvider.NewFloat64UpDownCounterCalls()) +func (mock *ProviderMock) NewFloat64UpDownCounterCalls() []struct { + Name string + Options []metric.Float64UpDownCounterOption +} { + var calls []struct { + Name string + Options []metric.Float64UpDownCounterOption + } + mock.lockNewFloat64UpDownCounter.RLock() + calls = mock.calls.NewFloat64UpDownCounter + mock.lockNewFloat64UpDownCounter.RUnlock() + return calls +} + +// NewInt64Counter calls NewInt64CounterFunc. +func (mock *ProviderMock) NewInt64Counter(name string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + if mock.NewInt64CounterFunc == nil { + panic("ProviderMock.NewInt64CounterFunc: method is nil but Provider.NewInt64Counter was just called") + } + callInfo := struct { + Name string + Options []metric.Int64CounterOption + }{ + Name: name, + Options: options, + } + mock.lockNewInt64Counter.Lock() + mock.calls.NewInt64Counter = append(mock.calls.NewInt64Counter, callInfo) + mock.lockNewInt64Counter.Unlock() + return mock.NewInt64CounterFunc(name, options...) +} + +// NewInt64CounterCalls gets all the calls that were made to NewInt64Counter. +// Check the length with: +// +// len(mockedProvider.NewInt64CounterCalls()) +func (mock *ProviderMock) NewInt64CounterCalls() []struct { + Name string + Options []metric.Int64CounterOption +} { + var calls []struct { + Name string + Options []metric.Int64CounterOption + } + mock.lockNewInt64Counter.RLock() + calls = mock.calls.NewInt64Counter + mock.lockNewInt64Counter.RUnlock() + return calls +} + +// NewInt64Gauge calls NewInt64GaugeFunc. +func (mock *ProviderMock) NewInt64Gauge(name string, options ...metric.Int64GaugeOption) (metrics.Int64Gauge, error) { + if mock.NewInt64GaugeFunc == nil { + panic("ProviderMock.NewInt64GaugeFunc: method is nil but Provider.NewInt64Gauge was just called") + } + callInfo := struct { + Name string + Options []metric.Int64GaugeOption + }{ + Name: name, + Options: options, + } + mock.lockNewInt64Gauge.Lock() + mock.calls.NewInt64Gauge = append(mock.calls.NewInt64Gauge, callInfo) + mock.lockNewInt64Gauge.Unlock() + return mock.NewInt64GaugeFunc(name, options...) +} + +// NewInt64GaugeCalls gets all the calls that were made to NewInt64Gauge. +// Check the length with: +// +// len(mockedProvider.NewInt64GaugeCalls()) +func (mock *ProviderMock) NewInt64GaugeCalls() []struct { + Name string + Options []metric.Int64GaugeOption +} { + var calls []struct { + Name string + Options []metric.Int64GaugeOption + } + mock.lockNewInt64Gauge.RLock() + calls = mock.calls.NewInt64Gauge + mock.lockNewInt64Gauge.RUnlock() + return calls +} + +// NewInt64Histogram calls NewInt64HistogramFunc. +func (mock *ProviderMock) NewInt64Histogram(name string, options ...metric.Int64HistogramOption) (metrics.Int64Histogram, error) { + if mock.NewInt64HistogramFunc == nil { + panic("ProviderMock.NewInt64HistogramFunc: method is nil but Provider.NewInt64Histogram was just called") + } + callInfo := struct { + Name string + Options []metric.Int64HistogramOption + }{ + Name: name, + Options: options, + } + mock.lockNewInt64Histogram.Lock() + mock.calls.NewInt64Histogram = append(mock.calls.NewInt64Histogram, callInfo) + mock.lockNewInt64Histogram.Unlock() + return mock.NewInt64HistogramFunc(name, options...) +} + +// NewInt64HistogramCalls gets all the calls that were made to NewInt64Histogram. +// Check the length with: +// +// len(mockedProvider.NewInt64HistogramCalls()) +func (mock *ProviderMock) NewInt64HistogramCalls() []struct { + Name string + Options []metric.Int64HistogramOption +} { + var calls []struct { + Name string + Options []metric.Int64HistogramOption + } + mock.lockNewInt64Histogram.RLock() + calls = mock.calls.NewInt64Histogram + mock.lockNewInt64Histogram.RUnlock() + return calls +} + +// NewInt64UpDownCounter calls NewInt64UpDownCounterFunc. +func (mock *ProviderMock) NewInt64UpDownCounter(name string, options ...metric.Int64UpDownCounterOption) (metrics.Int64UpDownCounter, error) { + if mock.NewInt64UpDownCounterFunc == nil { + panic("ProviderMock.NewInt64UpDownCounterFunc: method is nil but Provider.NewInt64UpDownCounter was just called") + } + callInfo := struct { + Name string + Options []metric.Int64UpDownCounterOption + }{ + Name: name, + Options: options, + } + mock.lockNewInt64UpDownCounter.Lock() + mock.calls.NewInt64UpDownCounter = append(mock.calls.NewInt64UpDownCounter, callInfo) + mock.lockNewInt64UpDownCounter.Unlock() + return mock.NewInt64UpDownCounterFunc(name, options...) +} + +// NewInt64UpDownCounterCalls gets all the calls that were made to NewInt64UpDownCounter. +// Check the length with: +// +// len(mockedProvider.NewInt64UpDownCounterCalls()) +func (mock *ProviderMock) NewInt64UpDownCounterCalls() []struct { + Name string + Options []metric.Int64UpDownCounterOption +} { + var calls []struct { + Name string + Options []metric.Int64UpDownCounterOption + } + mock.lockNewInt64UpDownCounter.RLock() + calls = mock.calls.NewInt64UpDownCounter + mock.lockNewInt64UpDownCounter.RUnlock() + return calls +} + +// Shutdown calls ShutdownFunc. +func (mock *ProviderMock) Shutdown(ctx context.Context) error { + if mock.ShutdownFunc == nil { + panic("ProviderMock.ShutdownFunc: method is nil but Provider.Shutdown was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockShutdown.Lock() + mock.calls.Shutdown = append(mock.calls.Shutdown, callInfo) + mock.lockShutdown.Unlock() + return mock.ShutdownFunc(ctx) +} + +// ShutdownCalls gets all the calls that were made to Shutdown. +// Check the length with: +// +// len(mockedProvider.ShutdownCalls()) +func (mock *ProviderMock) ShutdownCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockShutdown.RLock() + calls = mock.calls.Shutdown + mock.lockShutdown.RUnlock() + return calls +} diff --git a/cache/redis/mocks_gen_test.go b/cache/redis/mocks_gen_test.go new file mode 100644 index 0000000..b0a8e9d --- /dev/null +++ b/cache/redis/mocks_gen_test.go @@ -0,0 +1,7 @@ +package redis + +// Regenerate mocks via `go generate ./cache/redis/...`. + +//go:generate go tool github.com/matryer/moq -out redisclient_mock_test.go -pkg redis -rm -fmt goimports . redisClient +//go:generate go tool github.com/matryer/moq -out metricsprovider_mock_test.go -pkg redis -rm -fmt goimports ../../observability/metrics Provider +//go:generate go tool github.com/matryer/moq -out circuitbreaker_mock_test.go -pkg redis -rm -fmt goimports ../../circuitbreaking CircuitBreaker diff --git a/cache/redis/redis_test.go b/cache/redis/redis_test.go index 0bcbeff..53299d6 100644 --- a/cache/redis/redis_test.go +++ b/cache/redis/redis_test.go @@ -6,8 +6,8 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" rediscontainers "github.com/testcontainers/testcontainers-go/modules/redis" ) @@ -38,7 +38,7 @@ func buildContainerBackedRedisConfig(t *testing.T) (config *Config, shutdownFunc time.Sleep(100 * time.Millisecond) redisAddress, err := redisContainer.ConnectionString(containerCtx) - require.NoError(t, err) + must.NoError(t, err) cfg := &Config{ QueueAddresses: []string{ @@ -63,17 +63,17 @@ func Test_redisCacheImpl_Get(T *testing.T) { cfg, containerShutdown := buildContainerBackedRedisConfig(t) defer func() { - assert.NoError(t, containerShutdown(ctx)) + test.NoError(t, containerShutdown(ctx)) }() c, err := NewRedisCache[example](cfg, 0, nil, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) exampleContent := &example{Name: t.Name()} - assert.NoError(t, c.Set(ctx, exampleKey, exampleContent)) + test.NoError(t, c.Set(ctx, exampleKey, exampleContent)) actual, getErr := c.Get(ctx, exampleKey) - assert.Equal(t, exampleContent, actual) - assert.NoError(t, getErr) + test.Eq(t, exampleContent, actual) + test.NoError(t, getErr) }) } @@ -87,13 +87,13 @@ func Test_redisCacheImpl_Set(T *testing.T) { cfg, containerShutdown := buildContainerBackedRedisConfig(t) defer func() { - assert.NoError(t, containerShutdown(ctx)) + test.NoError(t, containerShutdown(ctx)) }() c, err := NewRedisCache[example](cfg, 0, nil, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) exampleContent := &example{Name: t.Name()} - assert.NoError(t, c.Set(ctx, exampleKey, exampleContent)) + test.NoError(t, c.Set(ctx, exampleKey, exampleContent)) }) } @@ -107,14 +107,14 @@ func Test_redisCacheImpl_Delete(T *testing.T) { cfg, containerShutdown := buildContainerBackedRedisConfig(t) defer func() { - assert.NoError(t, containerShutdown(ctx)) + test.NoError(t, containerShutdown(ctx)) }() c, err := NewRedisCache[example](cfg, 0, nil, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) exampleContent := &example{Name: t.Name()} - assert.NoError(t, c.Set(ctx, exampleKey, exampleContent)) + test.NoError(t, c.Set(ctx, exampleKey, exampleContent)) - assert.NoError(t, c.Delete(ctx, exampleKey)) + test.NoError(t, c.Delete(ctx, exampleKey)) }) } diff --git a/cache/redis/redisclient_mock_test.go b/cache/redis/redisclient_mock_test.go new file mode 100644 index 0000000..005a6d8 --- /dev/null +++ b/cache/redis/redisclient_mock_test.go @@ -0,0 +1,240 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package redis + +import ( + "context" + "sync" + "time" + + "github.com/go-redis/redis/v8" +) + +// Ensure, that redisClientMock does implement redisClient. +// If this is not the case, regenerate this file with moq. +var _ redisClient = &redisClientMock{} + +// redisClientMock is a mock implementation of redisClient. +// +// func TestSomethingThatUsesredisClient(t *testing.T) { +// +// // make and configure a mocked redisClient +// mockedredisClient := &redisClientMock{ +// DelFunc: func(ctx context.Context, keys ...string) *redis.IntCmd { +// panic("mock out the Del method") +// }, +// GetFunc: func(ctx context.Context, key string) *redis.StringCmd { +// panic("mock out the Get method") +// }, +// PingFunc: func(ctx context.Context) *redis.StatusCmd { +// panic("mock out the Ping method") +// }, +// SetFunc: func(ctx context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd { +// panic("mock out the Set method") +// }, +// } +// +// // use mockedredisClient in code that requires redisClient +// // and then make assertions. +// +// } +type redisClientMock struct { + // DelFunc mocks the Del method. + DelFunc func(ctx context.Context, keys ...string) *redis.IntCmd + + // GetFunc mocks the Get method. + GetFunc func(ctx context.Context, key string) *redis.StringCmd + + // PingFunc mocks the Ping method. + PingFunc func(ctx context.Context) *redis.StatusCmd + + // SetFunc mocks the Set method. + SetFunc func(ctx context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd + + // calls tracks calls to the methods. + calls struct { + // Del holds details about calls to the Del method. + Del []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Keys is the keys argument value. + Keys []string + } + // Get holds details about calls to the Get method. + Get []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Key is the key argument value. + Key string + } + // Ping holds details about calls to the Ping method. + Ping []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // Set holds details about calls to the Set method. + Set []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Key is the key argument value. + Key string + // Value is the value argument value. + Value any + // Expiration is the expiration argument value. + Expiration time.Duration + } + } + lockDel sync.RWMutex + lockGet sync.RWMutex + lockPing sync.RWMutex + lockSet sync.RWMutex +} + +// Del calls DelFunc. +func (mock *redisClientMock) Del(ctx context.Context, keys ...string) *redis.IntCmd { + if mock.DelFunc == nil { + panic("redisClientMock.DelFunc: method is nil but redisClient.Del was just called") + } + callInfo := struct { + Ctx context.Context + Keys []string + }{ + Ctx: ctx, + Keys: keys, + } + mock.lockDel.Lock() + mock.calls.Del = append(mock.calls.Del, callInfo) + mock.lockDel.Unlock() + return mock.DelFunc(ctx, keys...) +} + +// DelCalls gets all the calls that were made to Del. +// Check the length with: +// +// len(mockedredisClient.DelCalls()) +func (mock *redisClientMock) DelCalls() []struct { + Ctx context.Context + Keys []string +} { + var calls []struct { + Ctx context.Context + Keys []string + } + mock.lockDel.RLock() + calls = mock.calls.Del + mock.lockDel.RUnlock() + return calls +} + +// Get calls GetFunc. +func (mock *redisClientMock) Get(ctx context.Context, key string) *redis.StringCmd { + if mock.GetFunc == nil { + panic("redisClientMock.GetFunc: method is nil but redisClient.Get was just called") + } + callInfo := struct { + Ctx context.Context + Key string + }{ + Ctx: ctx, + Key: key, + } + mock.lockGet.Lock() + mock.calls.Get = append(mock.calls.Get, callInfo) + mock.lockGet.Unlock() + return mock.GetFunc(ctx, key) +} + +// GetCalls gets all the calls that were made to Get. +// Check the length with: +// +// len(mockedredisClient.GetCalls()) +func (mock *redisClientMock) GetCalls() []struct { + Ctx context.Context + Key string +} { + var calls []struct { + Ctx context.Context + Key string + } + mock.lockGet.RLock() + calls = mock.calls.Get + mock.lockGet.RUnlock() + return calls +} + +// Ping calls PingFunc. +func (mock *redisClientMock) Ping(ctx context.Context) *redis.StatusCmd { + if mock.PingFunc == nil { + panic("redisClientMock.PingFunc: method is nil but redisClient.Ping was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockPing.Lock() + mock.calls.Ping = append(mock.calls.Ping, callInfo) + mock.lockPing.Unlock() + return mock.PingFunc(ctx) +} + +// PingCalls gets all the calls that were made to Ping. +// Check the length with: +// +// len(mockedredisClient.PingCalls()) +func (mock *redisClientMock) PingCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockPing.RLock() + calls = mock.calls.Ping + mock.lockPing.RUnlock() + return calls +} + +// Set calls SetFunc. +func (mock *redisClientMock) Set(ctx context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd { + if mock.SetFunc == nil { + panic("redisClientMock.SetFunc: method is nil but redisClient.Set was just called") + } + callInfo := struct { + Ctx context.Context + Key string + Value any + Expiration time.Duration + }{ + Ctx: ctx, + Key: key, + Value: value, + Expiration: expiration, + } + mock.lockSet.Lock() + mock.calls.Set = append(mock.calls.Set, callInfo) + mock.lockSet.Unlock() + return mock.SetFunc(ctx, key, value, expiration) +} + +// SetCalls gets all the calls that were made to Set. +// Check the length with: +// +// len(mockedredisClient.SetCalls()) +func (mock *redisClientMock) SetCalls() []struct { + Ctx context.Context + Key string + Value any + Expiration time.Duration +} { + var calls []struct { + Ctx context.Context + Key string + Value any + Expiration time.Duration + } + mock.lockSet.RLock() + calls = mock.calls.Set + mock.lockSet.RUnlock() + return calls +} diff --git a/cache/redis/unit_test.go b/cache/redis/unit_test.go index 3bf2847..e954292 100644 --- a/cache/redis/unit_test.go +++ b/cache/redis/unit_test.go @@ -9,74 +9,50 @@ import ( "time" "github.com/verygoodsoftwarenotvirus/platform/v5/cache" - mockcircuitbreaking "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" - mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/verygoodsoftwarenotvirus/platform/v5/testutils" "github.com/go-redis/redis/v8" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) -type mockRedisClient struct { - mock.Mock -} - -func (m *mockRedisClient) Get(ctx context.Context, key string) *redis.StringCmd { - return m.Called(ctx, key).Get(0).(*redis.StringCmd) -} - -func (m *mockRedisClient) Set(ctx context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd { - return m.Called(ctx, key, value, expiration).Get(0).(*redis.StatusCmd) -} - -func (m *mockRedisClient) Del(ctx context.Context, keys ...string) *redis.IntCmd { - return m.Called(ctx, keys).Get(0).(*redis.IntCmd) -} - -func (m *mockRedisClient) Ping(ctx context.Context) *redis.StatusCmd { - return m.Called(ctx).Get(0).(*redis.StatusCmd) -} - func gobEncodeExample(t *testing.T, e *example) string { t.Helper() var buf bytes.Buffer - require.NoError(t, gob.NewEncoder(&buf).Encode(e)) + must.NoError(t, gob.NewEncoder(&buf).Encode(e)) return buf.String() } -func buildTestImpl(t *testing.T) (*redisCacheImpl[example], *mockRedisClient, *mockcircuitbreaking.MockCircuitBreaker) { +func buildTestImpl(t *testing.T) (*redisCacheImpl[example], *redisClientMock, *CircuitBreakerMock) { t.Helper() mp := metrics.NewNoopMetricsProvider() hitCounter, err := mp.NewInt64Counter("test_hits") - require.NoError(t, err) + must.NoError(t, err) missCounter, err := mp.NewInt64Counter("test_misses") - require.NoError(t, err) + must.NoError(t, err) setCounter, err := mp.NewInt64Counter("test_sets") - require.NoError(t, err) + must.NoError(t, err) delCounter, err := mp.NewInt64Counter("test_deletes") - require.NoError(t, err) + must.NoError(t, err) errCounter, err := mp.NewInt64Counter("test_errors") - require.NoError(t, err) + must.NoError(t, err) latencyHist, err := mp.NewFloat64Histogram("test_latency") - require.NoError(t, err) + must.NoError(t, err) - client := &mockRedisClient{} - cb := &mockcircuitbreaking.MockCircuitBreaker{} + client := &redisClientMock{} + cb := &CircuitBreakerMock{} return &redisCacheImpl[example]{ logger: logging.NewNoopLogger(), @@ -93,6 +69,27 @@ func buildTestImpl(t *testing.T) (*redisCacheImpl[example], *mockRedisClient, *m }, client, cb } +// counterResult bundles the values a mocked NewInt64Counter call returns. +type counterResult struct { + counter metrics.Int64Counter + err error +} + +// newCounterProviderMock returns a ProviderMock whose NewInt64Counter implementation +// looks up the result keyed on the counter name. Unknown names fail the test. +func newCounterProviderMock(t *testing.T, results map[string]counterResult) *ProviderMock { + t.Helper() + return &ProviderMock{ + NewInt64CounterFunc: func(metricName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + res, ok := results[metricName] + if !ok { + t.Fatalf("unexpected NewInt64Counter call: %q", metricName) + } + return res.counter, res.err + }, + } +} + func TestConfig_ValidateWithContext(T *testing.T) { T.Parallel() @@ -105,7 +102,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { QueueAddresses: []string{"localhost:6379"}, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with empty addresses", func(t *testing.T) { @@ -117,7 +114,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { QueueAddresses: []string{}, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with nil addresses", func(t *testing.T) { @@ -127,21 +124,23 @@ func TestConfig_ValidateWithContext(T *testing.T) { cfg := &Config{} - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } func TestNewRedisCache(T *testing.T) { T.Parallel() + okCounter := func() metrics.Int64Counter { return metrics.Int64CounterForTest(T, "x") } + T.Run("with single address", func(t *testing.T) { t.Parallel() cfg := &Config{QueueAddresses: []string{"localhost:6379"}} c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, nil, nil) - require.NoError(t, err) - assert.NotNil(t, c) + must.NoError(t, err) + test.NotNil(t, c) }) T.Run("with multiple addresses", func(t *testing.T) { @@ -150,8 +149,8 @@ func TestNewRedisCache(T *testing.T) { cfg := &Config{QueueAddresses: []string{"localhost:6379", "localhost:6380"}} c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, nil, nil) - require.NoError(t, err) - assert.NotNil(t, c) + must.NoError(t, err) + test.NotNil(t, c) }) T.Run("with error creating cache hit counter", func(t *testing.T) { @@ -159,14 +158,14 @@ func TestNewRedisCache(T *testing.T) { cfg := &Config{QueueAddresses: []string{"localhost:6379"}} - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_cache_hits", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("counter error")) + mp := newCounterProviderMock(t, map[string]counterResult{ + name + "_cache_hits": {counter: okCounter(), err: errors.New("counter error")}, + }) c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, mp, nil) - assert.Error(t, err) - assert.Nil(t, c) - - mock.AssertExpectationsForObjects(t, mp) + test.Error(t, err) + test.Nil(t, c) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("with error creating cache miss counter", func(t *testing.T) { @@ -174,15 +173,15 @@ func TestNewRedisCache(T *testing.T) { cfg := &Config{QueueAddresses: []string{"localhost:6379"}} - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_cache_hits", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_cache_misses", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("counter error")) + mp := newCounterProviderMock(t, map[string]counterResult{ + name + "_cache_hits": {counter: okCounter()}, + name + "_cache_misses": {counter: okCounter(), err: errors.New("counter error")}, + }) c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, mp, nil) - assert.Error(t, err) - assert.Nil(t, c) - - mock.AssertExpectationsForObjects(t, mp) + test.Error(t, err) + test.Nil(t, c) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) T.Run("with error creating cache set counter", func(t *testing.T) { @@ -190,16 +189,16 @@ func TestNewRedisCache(T *testing.T) { cfg := &Config{QueueAddresses: []string{"localhost:6379"}} - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_cache_hits", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_cache_misses", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_cache_sets", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("counter error")) + mp := newCounterProviderMock(t, map[string]counterResult{ + name + "_cache_hits": {counter: okCounter()}, + name + "_cache_misses": {counter: okCounter()}, + name + "_cache_sets": {counter: okCounter(), err: errors.New("counter error")}, + }) c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, mp, nil) - assert.Error(t, err) - assert.Nil(t, c) - - mock.AssertExpectationsForObjects(t, mp) + test.Error(t, err) + test.Nil(t, c) + test.SliceLen(t, 3, mp.NewInt64CounterCalls()) }) T.Run("with error creating cache delete counter", func(t *testing.T) { @@ -207,17 +206,17 @@ func TestNewRedisCache(T *testing.T) { cfg := &Config{QueueAddresses: []string{"localhost:6379"}} - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_cache_hits", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_cache_misses", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_cache_sets", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_cache_deletes", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("counter error")) + mp := newCounterProviderMock(t, map[string]counterResult{ + name + "_cache_hits": {counter: okCounter()}, + name + "_cache_misses": {counter: okCounter()}, + name + "_cache_sets": {counter: okCounter()}, + name + "_cache_deletes": {counter: okCounter(), err: errors.New("counter error")}, + }) c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, mp, nil) - assert.Error(t, err) - assert.Nil(t, c) - - mock.AssertExpectationsForObjects(t, mp) + test.Error(t, err) + test.Nil(t, c) + test.SliceLen(t, 4, mp.NewInt64CounterCalls()) }) T.Run("with error creating cache error counter", func(t *testing.T) { @@ -225,18 +224,18 @@ func TestNewRedisCache(T *testing.T) { cfg := &Config{QueueAddresses: []string{"localhost:6379"}} - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_cache_hits", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_cache_misses", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_cache_sets", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_cache_deletes", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_cache_errors", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("counter error")) + mp := newCounterProviderMock(t, map[string]counterResult{ + name + "_cache_hits": {counter: okCounter()}, + name + "_cache_misses": {counter: okCounter()}, + name + "_cache_sets": {counter: okCounter()}, + name + "_cache_deletes": {counter: okCounter()}, + name + "_cache_errors": {counter: okCounter(), err: errors.New("counter error")}, + }) c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, mp, nil) - assert.Error(t, err) - assert.Nil(t, c) - - mock.AssertExpectationsForObjects(t, mp) + test.Error(t, err) + test.Nil(t, c) + test.SliceLen(t, 5, mp.NewInt64CounterCalls()) }) T.Run("with error creating latency histogram", func(t *testing.T) { @@ -246,21 +245,23 @@ func TestNewRedisCache(T *testing.T) { noopMP := metrics.NewNoopMetricsProvider() h, histErr := noopMP.NewFloat64Histogram("test") - require.NoError(t, histErr) - - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_cache_hits", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_cache_misses", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_cache_sets", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_cache_deletes", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_cache_errors", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewFloat64Histogram", name+"_cache_latency_ms", []metric.Float64HistogramOption(nil)).Return(h, errors.New("histogram error")) + must.NoError(t, histErr) + + mp := &ProviderMock{ + NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metrics.Int64CounterForTest(t, "x"), nil + }, + NewFloat64HistogramFunc: func(metricName string, _ ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { + test.EqOp(t, name+"_cache_latency_ms", metricName) + return h, errors.New("histogram error") + }, + } c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, mp, nil) - assert.Error(t, err) - assert.Nil(t, c) - - mock.AssertExpectationsForObjects(t, mp) + test.Error(t, err) + test.Nil(t, c) + test.SliceLen(t, 5, mp.NewInt64CounterCalls()) + test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) }) } @@ -276,18 +277,23 @@ func Test_redisCacheImpl_Get_Unit(T *testing.T) { expected := &example{Name: t.Name()} encoded := gobEncodeExample(t, expected) - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() + cb.CannotProceedFunc = func() bool { return false } + cb.SucceededFunc = func() {} - cmd := redis.NewStringCmd(ctx) - cmd.SetVal(encoded) - client.On("Get", testutils.ContextMatcher, exampleKey).Return(cmd) + client.GetFunc = func(_ context.Context, key string) *redis.StringCmd { + test.EqOp(t, exampleKey, key) + cmd := redis.NewStringCmd(ctx) + cmd.SetVal(encoded) + return cmd + } actual, err := impl.Get(ctx, exampleKey) - assert.NoError(t, err) - assert.Equal(t, expected, actual) + test.NoError(t, err) + test.Eq(t, expected, actual) - mock.AssertExpectationsForObjects(t, client, cb) + test.SliceLen(t, 1, client.GetCalls()) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) }) T.Run("when circuit breaker cannot proceed", func(t *testing.T) { @@ -296,13 +302,13 @@ func Test_redisCacheImpl_Get_Unit(T *testing.T) { ctx := t.Context() impl, _, cb := buildTestImpl(t) - cb.On("CannotProceed").Return(true) + cb.CannotProceedFunc = func() bool { return true } actual, err := impl.Get(ctx, exampleKey) - assert.ErrorIs(t, err, cache.ErrNotFound) - assert.Nil(t, actual) + test.ErrorIs(t, err, cache.ErrNotFound) + test.Nil(t, actual) - mock.AssertExpectationsForObjects(t, cb) + test.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("with redis error", func(t *testing.T) { @@ -311,18 +317,22 @@ func Test_redisCacheImpl_Get_Unit(T *testing.T) { ctx := t.Context() impl, client, cb := buildTestImpl(t) - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb.CannotProceedFunc = func() bool { return false } + cb.FailedFunc = func() {} - cmd := redis.NewStringCmd(ctx) - cmd.SetErr(errors.New("redis error")) - client.On("Get", testutils.ContextMatcher, exampleKey).Return(cmd) + client.GetFunc = func(_ context.Context, key string) *redis.StringCmd { + test.EqOp(t, exampleKey, key) + cmd := redis.NewStringCmd(ctx) + cmd.SetErr(errors.New("redis error")) + return cmd + } actual, err := impl.Get(ctx, exampleKey) - assert.Error(t, err) - assert.Nil(t, actual) + test.Error(t, err) + test.Nil(t, actual) - mock.AssertExpectationsForObjects(t, client, cb) + test.SliceLen(t, 1, client.GetCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) }) T.Run("with decode error", func(t *testing.T) { @@ -331,17 +341,20 @@ func Test_redisCacheImpl_Get_Unit(T *testing.T) { ctx := t.Context() impl, client, cb := buildTestImpl(t) - cb.On("CannotProceed").Return(false) + cb.CannotProceedFunc = func() bool { return false } - cmd := redis.NewStringCmd(ctx) - cmd.SetVal("not valid gob data") - client.On("Get", testutils.ContextMatcher, exampleKey).Return(cmd) + client.GetFunc = func(_ context.Context, key string) *redis.StringCmd { + test.EqOp(t, exampleKey, key) + cmd := redis.NewStringCmd(ctx) + cmd.SetVal("not valid gob data") + return cmd + } actual, err := impl.Get(ctx, exampleKey) - assert.Error(t, err) - assert.Nil(t, actual) + test.Error(t, err) + test.Nil(t, actual) - mock.AssertExpectationsForObjects(t, client, cb) + test.SliceLen(t, 1, client.GetCalls()) }) } @@ -354,17 +367,24 @@ func Test_redisCacheImpl_Set_Unit(T *testing.T) { ctx := t.Context() impl, client, cb := buildTestImpl(t) - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() - - cmd := redis.NewStatusCmd(ctx) - cmd.SetVal("OK") - client.On("Set", testutils.ContextMatcher, exampleKey, mock.AnythingOfType("string"), time.Minute).Return(cmd) + cb.CannotProceedFunc = func() bool { return false } + cb.SucceededFunc = func() {} + + client.SetFunc = func(_ context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd { + test.EqOp(t, exampleKey, key) + test.EqOp(t, time.Minute, expiration) + _, isString := value.(string) + test.True(t, isString) + cmd := redis.NewStatusCmd(ctx) + cmd.SetVal("OK") + return cmd + } err := impl.Set(ctx, exampleKey, &example{Name: t.Name()}) - assert.NoError(t, err) + test.NoError(t, err) - mock.AssertExpectationsForObjects(t, client, cb) + test.SliceLen(t, 1, client.SetCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) }) T.Run("when circuit breaker cannot proceed", func(t *testing.T) { @@ -373,12 +393,12 @@ func Test_redisCacheImpl_Set_Unit(T *testing.T) { ctx := t.Context() impl, _, cb := buildTestImpl(t) - cb.On("CannotProceed").Return(true) + cb.CannotProceedFunc = func() bool { return true } err := impl.Set(ctx, exampleKey, &example{Name: t.Name()}) - assert.NoError(t, err) + test.NoError(t, err) - mock.AssertExpectationsForObjects(t, cb) + test.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("with redis error", func(t *testing.T) { @@ -387,17 +407,21 @@ func Test_redisCacheImpl_Set_Unit(T *testing.T) { ctx := t.Context() impl, client, cb := buildTestImpl(t) - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb.CannotProceedFunc = func() bool { return false } + cb.FailedFunc = func() {} - cmd := redis.NewStatusCmd(ctx) - cmd.SetErr(errors.New("redis error")) - client.On("Set", testutils.ContextMatcher, exampleKey, mock.AnythingOfType("string"), time.Minute).Return(cmd) + client.SetFunc = func(_ context.Context, key string, _ any, _ time.Duration) *redis.StatusCmd { + test.EqOp(t, exampleKey, key) + cmd := redis.NewStatusCmd(ctx) + cmd.SetErr(errors.New("redis error")) + return cmd + } err := impl.Set(ctx, exampleKey, &example{Name: t.Name()}) - assert.Error(t, err) + test.Error(t, err) - mock.AssertExpectationsForObjects(t, client, cb) + test.SliceLen(t, 1, client.SetCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) }) } @@ -410,17 +434,21 @@ func Test_redisCacheImpl_Delete_Unit(T *testing.T) { ctx := t.Context() impl, client, cb := buildTestImpl(t) - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() + cb.CannotProceedFunc = func() bool { return false } + cb.SucceededFunc = func() {} - cmd := redis.NewIntCmd(ctx) - cmd.SetVal(1) - client.On("Del", testutils.ContextMatcher, []string{exampleKey}).Return(cmd) + client.DelFunc = func(_ context.Context, keys ...string) *redis.IntCmd { + test.Eq(t, []string{exampleKey}, keys) + cmd := redis.NewIntCmd(ctx) + cmd.SetVal(1) + return cmd + } err := impl.Delete(ctx, exampleKey) - assert.NoError(t, err) + test.NoError(t, err) - mock.AssertExpectationsForObjects(t, client, cb) + test.SliceLen(t, 1, client.DelCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) }) T.Run("when circuit breaker cannot proceed", func(t *testing.T) { @@ -429,12 +457,12 @@ func Test_redisCacheImpl_Delete_Unit(T *testing.T) { ctx := t.Context() impl, _, cb := buildTestImpl(t) - cb.On("CannotProceed").Return(true) + cb.CannotProceedFunc = func() bool { return true } err := impl.Delete(ctx, exampleKey) - assert.NoError(t, err) + test.NoError(t, err) - mock.AssertExpectationsForObjects(t, cb) + test.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("with redis error", func(t *testing.T) { @@ -443,17 +471,20 @@ func Test_redisCacheImpl_Delete_Unit(T *testing.T) { ctx := t.Context() impl, client, cb := buildTestImpl(t) - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb.CannotProceedFunc = func() bool { return false } + cb.FailedFunc = func() {} - cmd := redis.NewIntCmd(ctx) - cmd.SetErr(errors.New("redis error")) - client.On("Del", testutils.ContextMatcher, []string{exampleKey}).Return(cmd) + client.DelFunc = func(_ context.Context, _ ...string) *redis.IntCmd { + cmd := redis.NewIntCmd(ctx) + cmd.SetErr(errors.New("redis error")) + return cmd + } err := impl.Delete(ctx, exampleKey) - assert.Error(t, err) + test.Error(t, err) - mock.AssertExpectationsForObjects(t, client, cb) + test.SliceLen(t, 1, client.DelCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) }) } @@ -466,13 +497,14 @@ func Test_redisCacheImpl_Ping_Unit(T *testing.T) { ctx := t.Context() impl, client, _ := buildTestImpl(t) - cmd := redis.NewStatusCmd(ctx) - cmd.SetVal("PONG") - client.On("Ping", testutils.ContextMatcher).Return(cmd) - - assert.NoError(t, impl.Ping(ctx)) + client.PingFunc = func(_ context.Context) *redis.StatusCmd { + cmd := redis.NewStatusCmd(ctx) + cmd.SetVal("PONG") + return cmd + } - mock.AssertExpectationsForObjects(t, client) + test.NoError(t, impl.Ping(ctx)) + test.SliceLen(t, 1, client.PingCalls()) }) T.Run("with error", func(t *testing.T) { @@ -481,13 +513,14 @@ func Test_redisCacheImpl_Ping_Unit(T *testing.T) { ctx := t.Context() impl, client, _ := buildTestImpl(t) - cmd := redis.NewStatusCmd(ctx) - cmd.SetErr(errors.New("connection refused")) - client.On("Ping", testutils.ContextMatcher).Return(cmd) - - assert.Error(t, impl.Ping(ctx)) + client.PingFunc = func(_ context.Context) *redis.StatusCmd { + cmd := redis.NewStatusCmd(ctx) + cmd.SetErr(errors.New("connection refused")) + return cmd + } - mock.AssertExpectationsForObjects(t, client) + test.Error(t, impl.Ping(ctx)) + test.SliceLen(t, 1, client.PingCalls()) }) } @@ -504,7 +537,7 @@ func Test_buildRedisClient(T *testing.T) { } c := buildRedisClient(cfg) - assert.NotNil(t, c) + test.NotNil(t, c) }) T.Run("with multiple addresses", func(t *testing.T) { @@ -517,7 +550,7 @@ func Test_buildRedisClient(T *testing.T) { } c := buildRedisClient(cfg) - assert.NotNil(t, c) + test.NotNil(t, c) }) T.Run("with no addresses", func(t *testing.T) { @@ -528,6 +561,6 @@ func Test_buildRedisClient(T *testing.T) { } c := buildRedisClient(cfg) - assert.Nil(t, c) + test.Nil(t, c) }) } diff --git a/go.mod b/go.mod index 51eab5b..a9f6108 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/aws/aws-sdk-go-v2 v1.41.5 github.com/aws/aws-sdk-go-v2/config v1.32.12 github.com/aws/aws-sdk-go-v2/service/s3 v1.97.1 + github.com/aws/aws-sdk-go-v2/service/sesv2 v1.60.2 github.com/aws/aws-sdk-go-v2/service/sqs v1.42.24 github.com/aws/aws-sdk-go-v2/service/ssm v1.68.3 github.com/boombuler/barcode v1.1.0 @@ -58,8 +59,10 @@ require ( github.com/samber/do/v2 v2.0.0 github.com/samber/slog-multi v1.7.1 github.com/segmentio/analytics-go/v3 v3.3.0 + github.com/segmentio/kafka-go v0.4.50 github.com/sendgrid/rest v2.6.9+incompatible github.com/sendgrid/sendgrid-go v3.16.1+incompatible + github.com/shoenig/test v1.12.2 github.com/sideshow/apns2 v0.25.0 github.com/stretchr/testify v1.11.1 github.com/stripe/stripe-go/v75 v75.11.0 @@ -102,7 +105,6 @@ require ( require ( filippo.io/edwards25519 v1.1.1 // indirect - github.com/aws/aws-sdk-go-v2/service/sesv2 v1.60.2 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/emicklei/go-restful/v3 v3.12.2 // indirect github.com/fxamacker/cbor/v2 v2.9.0 // indirect @@ -110,12 +112,13 @@ require ( github.com/go-openapi/jsonreference v0.20.2 // indirect github.com/go-openapi/swag v0.23.0 // indirect github.com/google/gnostic-models v0.7.0 // indirect + github.com/google/go-cmp v0.7.0 // indirect github.com/google/wire v0.7.0 // indirect + github.com/matryer/moq v0.7.1 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect github.com/pierrec/lz4/v4 v4.1.18 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - github.com/segmentio/kafka-go v0.4.50 // indirect github.com/x448/float16 v0.8.4 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect golang.org/x/term v0.41.0 // indirect @@ -242,10 +245,10 @@ require ( github.com/launchdarkly/ccache v1.1.0 // indirect github.com/launchdarkly/eventsource v1.11.0 // indirect github.com/launchdarkly/go-jsonstream/v3 v3.1.1 // indirect - github.com/launchdarkly/go-sdk-common/v3 v3.5.0 // indirect + github.com/launchdarkly/go-sdk-common/v3 v3.5.0 github.com/launchdarkly/go-sdk-events/v2 v2.0.2 // indirect github.com/launchdarkly/go-semver v1.0.3 // indirect - github.com/launchdarkly/go-server-sdk-evaluation/v2 v2.0.2 // indirect + github.com/launchdarkly/go-server-sdk-evaluation/v2 v2.0.2 github.com/lufia/plan9stats v0.0.0-20260216142805-b3301c5f2a88 // indirect github.com/magiconair/properties v1.8.10 // indirect github.com/mailgun/errors v0.5.0 // indirect @@ -329,6 +332,7 @@ require ( tool ( github.com/4meepo/tagalign/cmd/tagalign github.com/daixiang0/gci + github.com/matryer/moq golang.org/x/tools/cmd/goimports golang.org/x/tools/go/analysis/passes/fieldalignment/cmd/fieldalignment ) diff --git a/go.sum b/go.sum index ed6440a..8eef85c 100644 --- a/go.sum +++ b/go.sum @@ -94,8 +94,6 @@ github.com/anthropics/anthropic-sdk-go v1.27.1/go.mod h1:qUKmaW+uuPB64iy1l+4kOSv github.com/aokoli/goutils v1.0.1/go.mod h1:SijmP0QR8LtwsmDs8Yii5Z/S4trXFGFC2oO5g9DP+DQ= github.com/asaskevich/govalidator v0.0.0-20200108200545-475eaeb16496 h1:zV3ejI06GQ59hwDQAvmK1qxOQGB3WuVTRoY0okPTAv0= github.com/asaskevich/govalidator v0.0.0-20200108200545-475eaeb16496/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg= -github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k= -github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY= github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.7 h1:3kGOqnh1pPeddVa/E37XNTaWJ8W6vrbYV9lJEkCnhuY= @@ -108,18 +106,12 @@ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 h1:zOgq3uezl5nznfoK3ODuqb github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20/go.mod h1:z/MVwUARehy6GAg/yQ1GO2IMl0k++cu1ohP9zo887wE= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.22.8 h1:nuc44j+otOY0d1e+CWwB6zul57d2YEGlgCyiq3SL0lI= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.22.8/go.mod h1:qSFgGCN8fjdhvlLhTPZdWRWXbwfeZZWF2FEaIplYPhE= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 h1:CNXO7mvgThFGqOFgbNAP2nol2qAWBOGfqR/7tQlvLmc= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20/go.mod h1:oydPDJKcfMhgfcgBUZaG+toBbwy8yPWubJXBVERtI4o= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 h1:tN6W/hg+pkM+tf9XDkWUbDEjGLb+raoBMFsTodcoYKw= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20/go.mod h1:YJ898MhD067hSHA6xYCx5ts/jEd8BSOLtQDL3iZsvbc= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.21 h1:SwGMTMLIlvDNyhMteQ6r8IJSBPlRdXX5d4idhIGbkXA= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.21/go.mod h1:UUxgWxofmOdAMuqEsSppbDtGKLfR04HGsD0HXzvhI1k= github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22 h1:rWyie/PxDRIdhNf4DzRk0lvjVOqFJuNnO8WwaIRVxzQ= github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22/go.mod h1:zd/JsJ4P7oGfUhXn1VyLqaRZwPmZwg44Jf2dS84Dm3Y= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY= @@ -494,6 +486,8 @@ github.com/mailru/easyjson v0.9.2 h1:dX8U45hQsZpxd80nLvDGihsQ/OxlvTkVUXH2r/8cb2M github.com/mailru/easyjson v0.9.2/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/matcornic/hermes/v2 v2.1.0 h1:9TDYFBPFv6mcXanaDmRDEp/RTWj0dTTi+LpFnnnfNWc= github.com/matcornic/hermes/v2 v2.1.0/go.mod h1:2+ziJeoyRfaLiATIL8VZ7f9hpzH4oDHqTmn0bhrsgVI= +github.com/matryer/moq v0.7.1 h1:/QaXqMAdOrLqlshW2z7SMS21jDi7aVrbW0wJrR+hhJk= +github.com/matryer/moq v0.7.1/go.mod h1:IabIiFkaKCyHxej25INgFR+fnOxSZFMv2LYrU+ioyDs= github.com/mattn/go-colorable v0.0.10-0.20170816031813-ad5389df28cd/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= @@ -646,6 +640,8 @@ github.com/sendgrid/sendgrid-go v3.16.1+incompatible h1:zWhTmB0Y8XCDzeWIm2/BIt1G github.com/sendgrid/sendgrid-go v3.16.1+incompatible/go.mod h1:QRQt+LX/NmgVEvmdRw0VT/QgUn499+iza2FnDca9fg8= github.com/shirou/gopsutil/v4 v4.26.2 h1:X8i6sicvUFih4BmYIGT1m2wwgw2VG9YgrDTi7cIRGUI= github.com/shirou/gopsutil/v4 v4.26.2/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ= +github.com/shoenig/test v1.12.2 h1:ZVT8NeIUwGWpZcKaepPmFMoNQ3sVpxvqUh/MAqwFiJI= +github.com/shoenig/test v1.12.2/go.mod h1:UxJ6u/x2v/TNs/LoLxBNJRV9DiwBBKYxXSyczsBHFoI= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sideshow/apns2 v0.25.0 h1:XOzanncO9MQxkb03T/2uU2KcdVjYiIf0TMLzec0FTW4= github.com/sideshow/apns2 v0.25.0/go.mod h1:7Fceu+sL0XscxrfLSkAoH6UtvKefq3Kq1n4W3ayQZqE= @@ -725,6 +721,12 @@ github.com/wsxiaoys/terminal v0.0.0-20160513160801-0940f3fc43a0 h1:3UeQBvD0TFrlV github.com/wsxiaoys/terminal v0.0.0-20160513160801-0940f3fc43a0/go.mod h1:IXCdmsXIht47RaVFLEdVnh1t+pgYtTAhQGj73kz+2DM= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= +github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= From e0ef5b4454be2e2df0616236b627380a9fab0291 Mon Sep 17 00:00:00 2001 From: Jeffrey D <1289344+verygoodsoftwarenotvirus@users.noreply.github.com> Date: Fri, 10 Apr 2026 17:27:39 -0500 Subject: [PATCH 02/12] test: minor refactor --- cache/config/config_test.go | 3 +- cache/config/mock/doc.go | 7 ++ .../metricsprovider_mock.go} | 72 +++++++++---------- cache/config/mocks_gen_test.go | 5 -- .../circuitbreaker_mock.go} | 2 +- cache/redis/mock/doc.go | 10 +++ .../metricsprovider_mock.go} | 72 +++++++++---------- .../redisclient_mock.go} | 40 +++++------ cache/redis/mocks_gen_test.go | 7 -- cache/redis/unit_test.go | 18 ++--- 10 files changed, 120 insertions(+), 116 deletions(-) create mode 100644 cache/config/mock/doc.go rename cache/config/{metricsprovider_mock_test.go => mock/metricsprovider_mock.go} (81%) delete mode 100644 cache/config/mocks_gen_test.go rename cache/redis/{circuitbreaker_mock_test.go => mock/circuitbreaker_mock.go} (99%) create mode 100644 cache/redis/mock/doc.go rename cache/redis/{metricsprovider_mock_test.go => mock/metricsprovider_mock.go} (81%) rename cache/redis/{redisclient_mock_test.go => mock/redisclient_mock.go} (81%) delete mode 100644 cache/redis/mocks_gen_test.go diff --git a/cache/config/config_test.go b/cache/config/config_test.go index 0c7668a..be0192d 100644 --- a/cache/config/config_test.go +++ b/cache/config/config_test.go @@ -4,6 +4,7 @@ import ( "errors" "testing" + "github.com/verygoodsoftwarenotvirus/platform/v5/cache/config/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/cache/redis" circuitbreakingcfg "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking/config" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" @@ -128,7 +129,7 @@ func TestProvideCache(T *testing.T) { }, } - mp := &ProviderMock{ + mp := &mock.MetricsProviderMock{ NewInt64CounterFunc: func(name string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { test.EqOp(t, "redis-cache-breaker_circuit_breaker_tripped", name) return nil, errors.New("counter init failure") diff --git a/cache/config/mock/doc.go b/cache/config/mock/doc.go new file mode 100644 index 0000000..3dd676f --- /dev/null +++ b/cache/config/mock/doc.go @@ -0,0 +1,7 @@ +// Package mock provides moq-generated mock implementations of the interfaces +// that cache/config depends on for unit testing. +package mock + +// Regenerate via `go generate ./cache/config/mock/`. + +//go:generate go tool github.com/matryer/moq -out metricsprovider_mock.go -pkg mock -rm -fmt goimports ../../../observability/metrics Provider:MetricsProviderMock diff --git a/cache/config/metricsprovider_mock_test.go b/cache/config/mock/metricsprovider_mock.go similarity index 81% rename from cache/config/metricsprovider_mock_test.go rename to cache/config/mock/metricsprovider_mock.go index c225860..b739ea9 100644 --- a/cache/config/metricsprovider_mock_test.go +++ b/cache/config/mock/metricsprovider_mock.go @@ -1,7 +1,7 @@ // Code generated by moq; DO NOT EDIT. // github.com/matryer/moq -package config +package mock import ( "context" @@ -11,16 +11,16 @@ import ( "go.opentelemetry.io/otel/metric" ) -// Ensure, that ProviderMock does implement metrics.Provider. +// Ensure, that MetricsProviderMock does implement metrics.Provider. // If this is not the case, regenerate this file with moq. -var _ metrics.Provider = &ProviderMock{} +var _ metrics.Provider = &MetricsProviderMock{} -// ProviderMock is a mock implementation of metrics.Provider. +// MetricsProviderMock is a mock implementation of metrics.Provider. // // func TestSomethingThatUsesProvider(t *testing.T) { // // // make and configure a mocked metrics.Provider -// mockedProvider := &ProviderMock{ +// mockedProvider := &MetricsProviderMock{ // MeterProviderFunc: func() metric.MeterProvider { // panic("mock out the MeterProvider method") // }, @@ -57,7 +57,7 @@ var _ metrics.Provider = &ProviderMock{} // // and then make assertions. // // } -type ProviderMock struct { +type MetricsProviderMock struct { // MeterProviderFunc mocks the MeterProvider method. MeterProviderFunc func() metric.MeterProvider @@ -168,9 +168,9 @@ type ProviderMock struct { } // MeterProvider calls MeterProviderFunc. -func (mock *ProviderMock) MeterProvider() metric.MeterProvider { +func (mock *MetricsProviderMock) MeterProvider() metric.MeterProvider { if mock.MeterProviderFunc == nil { - panic("ProviderMock.MeterProviderFunc: method is nil but Provider.MeterProvider was just called") + panic("MetricsProviderMock.MeterProviderFunc: method is nil but Provider.MeterProvider was just called") } callInfo := struct { }{} @@ -184,7 +184,7 @@ func (mock *ProviderMock) MeterProvider() metric.MeterProvider { // Check the length with: // // len(mockedProvider.MeterProviderCalls()) -func (mock *ProviderMock) MeterProviderCalls() []struct { +func (mock *MetricsProviderMock) MeterProviderCalls() []struct { } { var calls []struct { } @@ -195,9 +195,9 @@ func (mock *ProviderMock) MeterProviderCalls() []struct { } // NewFloat64Counter calls NewFloat64CounterFunc. -func (mock *ProviderMock) NewFloat64Counter(name string, options ...metric.Float64CounterOption) (metrics.Float64Counter, error) { +func (mock *MetricsProviderMock) NewFloat64Counter(name string, options ...metric.Float64CounterOption) (metrics.Float64Counter, error) { if mock.NewFloat64CounterFunc == nil { - panic("ProviderMock.NewFloat64CounterFunc: method is nil but Provider.NewFloat64Counter was just called") + panic("MetricsProviderMock.NewFloat64CounterFunc: method is nil but Provider.NewFloat64Counter was just called") } callInfo := struct { Name string @@ -216,7 +216,7 @@ func (mock *ProviderMock) NewFloat64Counter(name string, options ...metric.Float // Check the length with: // // len(mockedProvider.NewFloat64CounterCalls()) -func (mock *ProviderMock) NewFloat64CounterCalls() []struct { +func (mock *MetricsProviderMock) NewFloat64CounterCalls() []struct { Name string Options []metric.Float64CounterOption } { @@ -231,9 +231,9 @@ func (mock *ProviderMock) NewFloat64CounterCalls() []struct { } // NewFloat64Gauge calls NewFloat64GaugeFunc. -func (mock *ProviderMock) NewFloat64Gauge(name string, options ...metric.Float64GaugeOption) (metrics.Float64Gauge, error) { +func (mock *MetricsProviderMock) NewFloat64Gauge(name string, options ...metric.Float64GaugeOption) (metrics.Float64Gauge, error) { if mock.NewFloat64GaugeFunc == nil { - panic("ProviderMock.NewFloat64GaugeFunc: method is nil but Provider.NewFloat64Gauge was just called") + panic("MetricsProviderMock.NewFloat64GaugeFunc: method is nil but Provider.NewFloat64Gauge was just called") } callInfo := struct { Name string @@ -252,7 +252,7 @@ func (mock *ProviderMock) NewFloat64Gauge(name string, options ...metric.Float64 // Check the length with: // // len(mockedProvider.NewFloat64GaugeCalls()) -func (mock *ProviderMock) NewFloat64GaugeCalls() []struct { +func (mock *MetricsProviderMock) NewFloat64GaugeCalls() []struct { Name string Options []metric.Float64GaugeOption } { @@ -267,9 +267,9 @@ func (mock *ProviderMock) NewFloat64GaugeCalls() []struct { } // NewFloat64Histogram calls NewFloat64HistogramFunc. -func (mock *ProviderMock) NewFloat64Histogram(name string, options ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { +func (mock *MetricsProviderMock) NewFloat64Histogram(name string, options ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { if mock.NewFloat64HistogramFunc == nil { - panic("ProviderMock.NewFloat64HistogramFunc: method is nil but Provider.NewFloat64Histogram was just called") + panic("MetricsProviderMock.NewFloat64HistogramFunc: method is nil but Provider.NewFloat64Histogram was just called") } callInfo := struct { Name string @@ -288,7 +288,7 @@ func (mock *ProviderMock) NewFloat64Histogram(name string, options ...metric.Flo // Check the length with: // // len(mockedProvider.NewFloat64HistogramCalls()) -func (mock *ProviderMock) NewFloat64HistogramCalls() []struct { +func (mock *MetricsProviderMock) NewFloat64HistogramCalls() []struct { Name string Options []metric.Float64HistogramOption } { @@ -303,9 +303,9 @@ func (mock *ProviderMock) NewFloat64HistogramCalls() []struct { } // NewFloat64UpDownCounter calls NewFloat64UpDownCounterFunc. -func (mock *ProviderMock) NewFloat64UpDownCounter(name string, options ...metric.Float64UpDownCounterOption) (metrics.Float64UpDownCounter, error) { +func (mock *MetricsProviderMock) NewFloat64UpDownCounter(name string, options ...metric.Float64UpDownCounterOption) (metrics.Float64UpDownCounter, error) { if mock.NewFloat64UpDownCounterFunc == nil { - panic("ProviderMock.NewFloat64UpDownCounterFunc: method is nil but Provider.NewFloat64UpDownCounter was just called") + panic("MetricsProviderMock.NewFloat64UpDownCounterFunc: method is nil but Provider.NewFloat64UpDownCounter was just called") } callInfo := struct { Name string @@ -324,7 +324,7 @@ func (mock *ProviderMock) NewFloat64UpDownCounter(name string, options ...metric // Check the length with: // // len(mockedProvider.NewFloat64UpDownCounterCalls()) -func (mock *ProviderMock) NewFloat64UpDownCounterCalls() []struct { +func (mock *MetricsProviderMock) NewFloat64UpDownCounterCalls() []struct { Name string Options []metric.Float64UpDownCounterOption } { @@ -339,9 +339,9 @@ func (mock *ProviderMock) NewFloat64UpDownCounterCalls() []struct { } // NewInt64Counter calls NewInt64CounterFunc. -func (mock *ProviderMock) NewInt64Counter(name string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) { +func (mock *MetricsProviderMock) NewInt64Counter(name string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) { if mock.NewInt64CounterFunc == nil { - panic("ProviderMock.NewInt64CounterFunc: method is nil but Provider.NewInt64Counter was just called") + panic("MetricsProviderMock.NewInt64CounterFunc: method is nil but Provider.NewInt64Counter was just called") } callInfo := struct { Name string @@ -360,7 +360,7 @@ func (mock *ProviderMock) NewInt64Counter(name string, options ...metric.Int64Co // Check the length with: // // len(mockedProvider.NewInt64CounterCalls()) -func (mock *ProviderMock) NewInt64CounterCalls() []struct { +func (mock *MetricsProviderMock) NewInt64CounterCalls() []struct { Name string Options []metric.Int64CounterOption } { @@ -375,9 +375,9 @@ func (mock *ProviderMock) NewInt64CounterCalls() []struct { } // NewInt64Gauge calls NewInt64GaugeFunc. -func (mock *ProviderMock) NewInt64Gauge(name string, options ...metric.Int64GaugeOption) (metrics.Int64Gauge, error) { +func (mock *MetricsProviderMock) NewInt64Gauge(name string, options ...metric.Int64GaugeOption) (metrics.Int64Gauge, error) { if mock.NewInt64GaugeFunc == nil { - panic("ProviderMock.NewInt64GaugeFunc: method is nil but Provider.NewInt64Gauge was just called") + panic("MetricsProviderMock.NewInt64GaugeFunc: method is nil but Provider.NewInt64Gauge was just called") } callInfo := struct { Name string @@ -396,7 +396,7 @@ func (mock *ProviderMock) NewInt64Gauge(name string, options ...metric.Int64Gaug // Check the length with: // // len(mockedProvider.NewInt64GaugeCalls()) -func (mock *ProviderMock) NewInt64GaugeCalls() []struct { +func (mock *MetricsProviderMock) NewInt64GaugeCalls() []struct { Name string Options []metric.Int64GaugeOption } { @@ -411,9 +411,9 @@ func (mock *ProviderMock) NewInt64GaugeCalls() []struct { } // NewInt64Histogram calls NewInt64HistogramFunc. -func (mock *ProviderMock) NewInt64Histogram(name string, options ...metric.Int64HistogramOption) (metrics.Int64Histogram, error) { +func (mock *MetricsProviderMock) NewInt64Histogram(name string, options ...metric.Int64HistogramOption) (metrics.Int64Histogram, error) { if mock.NewInt64HistogramFunc == nil { - panic("ProviderMock.NewInt64HistogramFunc: method is nil but Provider.NewInt64Histogram was just called") + panic("MetricsProviderMock.NewInt64HistogramFunc: method is nil but Provider.NewInt64Histogram was just called") } callInfo := struct { Name string @@ -432,7 +432,7 @@ func (mock *ProviderMock) NewInt64Histogram(name string, options ...metric.Int64 // Check the length with: // // len(mockedProvider.NewInt64HistogramCalls()) -func (mock *ProviderMock) NewInt64HistogramCalls() []struct { +func (mock *MetricsProviderMock) NewInt64HistogramCalls() []struct { Name string Options []metric.Int64HistogramOption } { @@ -447,9 +447,9 @@ func (mock *ProviderMock) NewInt64HistogramCalls() []struct { } // NewInt64UpDownCounter calls NewInt64UpDownCounterFunc. -func (mock *ProviderMock) NewInt64UpDownCounter(name string, options ...metric.Int64UpDownCounterOption) (metrics.Int64UpDownCounter, error) { +func (mock *MetricsProviderMock) NewInt64UpDownCounter(name string, options ...metric.Int64UpDownCounterOption) (metrics.Int64UpDownCounter, error) { if mock.NewInt64UpDownCounterFunc == nil { - panic("ProviderMock.NewInt64UpDownCounterFunc: method is nil but Provider.NewInt64UpDownCounter was just called") + panic("MetricsProviderMock.NewInt64UpDownCounterFunc: method is nil but Provider.NewInt64UpDownCounter was just called") } callInfo := struct { Name string @@ -468,7 +468,7 @@ func (mock *ProviderMock) NewInt64UpDownCounter(name string, options ...metric.I // Check the length with: // // len(mockedProvider.NewInt64UpDownCounterCalls()) -func (mock *ProviderMock) NewInt64UpDownCounterCalls() []struct { +func (mock *MetricsProviderMock) NewInt64UpDownCounterCalls() []struct { Name string Options []metric.Int64UpDownCounterOption } { @@ -483,9 +483,9 @@ func (mock *ProviderMock) NewInt64UpDownCounterCalls() []struct { } // Shutdown calls ShutdownFunc. -func (mock *ProviderMock) Shutdown(ctx context.Context) error { +func (mock *MetricsProviderMock) Shutdown(ctx context.Context) error { if mock.ShutdownFunc == nil { - panic("ProviderMock.ShutdownFunc: method is nil but Provider.Shutdown was just called") + panic("MetricsProviderMock.ShutdownFunc: method is nil but Provider.Shutdown was just called") } callInfo := struct { Ctx context.Context @@ -502,7 +502,7 @@ func (mock *ProviderMock) Shutdown(ctx context.Context) error { // Check the length with: // // len(mockedProvider.ShutdownCalls()) -func (mock *ProviderMock) ShutdownCalls() []struct { +func (mock *MetricsProviderMock) ShutdownCalls() []struct { Ctx context.Context } { var calls []struct { diff --git a/cache/config/mocks_gen_test.go b/cache/config/mocks_gen_test.go deleted file mode 100644 index 7ad25d3..0000000 --- a/cache/config/mocks_gen_test.go +++ /dev/null @@ -1,5 +0,0 @@ -package config - -// Regenerate mocks via `go generate ./cache/config/...`. - -//go:generate go tool github.com/matryer/moq -out metricsprovider_mock_test.go -pkg config -rm -fmt goimports ../../observability/metrics Provider diff --git a/cache/redis/circuitbreaker_mock_test.go b/cache/redis/mock/circuitbreaker_mock.go similarity index 99% rename from cache/redis/circuitbreaker_mock_test.go rename to cache/redis/mock/circuitbreaker_mock.go index 8b474ee..eeffec8 100644 --- a/cache/redis/circuitbreaker_mock_test.go +++ b/cache/redis/mock/circuitbreaker_mock.go @@ -1,7 +1,7 @@ // Code generated by moq; DO NOT EDIT. // github.com/matryer/moq -package redis +package mock import ( "sync" diff --git a/cache/redis/mock/doc.go b/cache/redis/mock/doc.go new file mode 100644 index 0000000..c4ebd1b --- /dev/null +++ b/cache/redis/mock/doc.go @@ -0,0 +1,10 @@ +// Package mock provides moq-generated mock implementations of the interfaces +// that cache/redis depends on for unit testing: the internal redisClient test +// seam plus metrics.Provider and circuitbreaking.CircuitBreaker. +package mock + +// Regenerate via `go generate ./cache/redis/mock/`. + +//go:generate go tool github.com/matryer/moq -out redisclient_mock.go -pkg mock -rm -skip-ensure -fmt goimports .. redisClient:RedisClientMock +//go:generate go tool github.com/matryer/moq -out metricsprovider_mock.go -pkg mock -rm -fmt goimports ../../../observability/metrics Provider:MetricsProviderMock +//go:generate go tool github.com/matryer/moq -out circuitbreaker_mock.go -pkg mock -rm -fmt goimports ../../../circuitbreaking CircuitBreaker:CircuitBreakerMock diff --git a/cache/redis/metricsprovider_mock_test.go b/cache/redis/mock/metricsprovider_mock.go similarity index 81% rename from cache/redis/metricsprovider_mock_test.go rename to cache/redis/mock/metricsprovider_mock.go index 247d6e8..b739ea9 100644 --- a/cache/redis/metricsprovider_mock_test.go +++ b/cache/redis/mock/metricsprovider_mock.go @@ -1,7 +1,7 @@ // Code generated by moq; DO NOT EDIT. // github.com/matryer/moq -package redis +package mock import ( "context" @@ -11,16 +11,16 @@ import ( "go.opentelemetry.io/otel/metric" ) -// Ensure, that ProviderMock does implement metrics.Provider. +// Ensure, that MetricsProviderMock does implement metrics.Provider. // If this is not the case, regenerate this file with moq. -var _ metrics.Provider = &ProviderMock{} +var _ metrics.Provider = &MetricsProviderMock{} -// ProviderMock is a mock implementation of metrics.Provider. +// MetricsProviderMock is a mock implementation of metrics.Provider. // // func TestSomethingThatUsesProvider(t *testing.T) { // // // make and configure a mocked metrics.Provider -// mockedProvider := &ProviderMock{ +// mockedProvider := &MetricsProviderMock{ // MeterProviderFunc: func() metric.MeterProvider { // panic("mock out the MeterProvider method") // }, @@ -57,7 +57,7 @@ var _ metrics.Provider = &ProviderMock{} // // and then make assertions. // // } -type ProviderMock struct { +type MetricsProviderMock struct { // MeterProviderFunc mocks the MeterProvider method. MeterProviderFunc func() metric.MeterProvider @@ -168,9 +168,9 @@ type ProviderMock struct { } // MeterProvider calls MeterProviderFunc. -func (mock *ProviderMock) MeterProvider() metric.MeterProvider { +func (mock *MetricsProviderMock) MeterProvider() metric.MeterProvider { if mock.MeterProviderFunc == nil { - panic("ProviderMock.MeterProviderFunc: method is nil but Provider.MeterProvider was just called") + panic("MetricsProviderMock.MeterProviderFunc: method is nil but Provider.MeterProvider was just called") } callInfo := struct { }{} @@ -184,7 +184,7 @@ func (mock *ProviderMock) MeterProvider() metric.MeterProvider { // Check the length with: // // len(mockedProvider.MeterProviderCalls()) -func (mock *ProviderMock) MeterProviderCalls() []struct { +func (mock *MetricsProviderMock) MeterProviderCalls() []struct { } { var calls []struct { } @@ -195,9 +195,9 @@ func (mock *ProviderMock) MeterProviderCalls() []struct { } // NewFloat64Counter calls NewFloat64CounterFunc. -func (mock *ProviderMock) NewFloat64Counter(name string, options ...metric.Float64CounterOption) (metrics.Float64Counter, error) { +func (mock *MetricsProviderMock) NewFloat64Counter(name string, options ...metric.Float64CounterOption) (metrics.Float64Counter, error) { if mock.NewFloat64CounterFunc == nil { - panic("ProviderMock.NewFloat64CounterFunc: method is nil but Provider.NewFloat64Counter was just called") + panic("MetricsProviderMock.NewFloat64CounterFunc: method is nil but Provider.NewFloat64Counter was just called") } callInfo := struct { Name string @@ -216,7 +216,7 @@ func (mock *ProviderMock) NewFloat64Counter(name string, options ...metric.Float // Check the length with: // // len(mockedProvider.NewFloat64CounterCalls()) -func (mock *ProviderMock) NewFloat64CounterCalls() []struct { +func (mock *MetricsProviderMock) NewFloat64CounterCalls() []struct { Name string Options []metric.Float64CounterOption } { @@ -231,9 +231,9 @@ func (mock *ProviderMock) NewFloat64CounterCalls() []struct { } // NewFloat64Gauge calls NewFloat64GaugeFunc. -func (mock *ProviderMock) NewFloat64Gauge(name string, options ...metric.Float64GaugeOption) (metrics.Float64Gauge, error) { +func (mock *MetricsProviderMock) NewFloat64Gauge(name string, options ...metric.Float64GaugeOption) (metrics.Float64Gauge, error) { if mock.NewFloat64GaugeFunc == nil { - panic("ProviderMock.NewFloat64GaugeFunc: method is nil but Provider.NewFloat64Gauge was just called") + panic("MetricsProviderMock.NewFloat64GaugeFunc: method is nil but Provider.NewFloat64Gauge was just called") } callInfo := struct { Name string @@ -252,7 +252,7 @@ func (mock *ProviderMock) NewFloat64Gauge(name string, options ...metric.Float64 // Check the length with: // // len(mockedProvider.NewFloat64GaugeCalls()) -func (mock *ProviderMock) NewFloat64GaugeCalls() []struct { +func (mock *MetricsProviderMock) NewFloat64GaugeCalls() []struct { Name string Options []metric.Float64GaugeOption } { @@ -267,9 +267,9 @@ func (mock *ProviderMock) NewFloat64GaugeCalls() []struct { } // NewFloat64Histogram calls NewFloat64HistogramFunc. -func (mock *ProviderMock) NewFloat64Histogram(name string, options ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { +func (mock *MetricsProviderMock) NewFloat64Histogram(name string, options ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { if mock.NewFloat64HistogramFunc == nil { - panic("ProviderMock.NewFloat64HistogramFunc: method is nil but Provider.NewFloat64Histogram was just called") + panic("MetricsProviderMock.NewFloat64HistogramFunc: method is nil but Provider.NewFloat64Histogram was just called") } callInfo := struct { Name string @@ -288,7 +288,7 @@ func (mock *ProviderMock) NewFloat64Histogram(name string, options ...metric.Flo // Check the length with: // // len(mockedProvider.NewFloat64HistogramCalls()) -func (mock *ProviderMock) NewFloat64HistogramCalls() []struct { +func (mock *MetricsProviderMock) NewFloat64HistogramCalls() []struct { Name string Options []metric.Float64HistogramOption } { @@ -303,9 +303,9 @@ func (mock *ProviderMock) NewFloat64HistogramCalls() []struct { } // NewFloat64UpDownCounter calls NewFloat64UpDownCounterFunc. -func (mock *ProviderMock) NewFloat64UpDownCounter(name string, options ...metric.Float64UpDownCounterOption) (metrics.Float64UpDownCounter, error) { +func (mock *MetricsProviderMock) NewFloat64UpDownCounter(name string, options ...metric.Float64UpDownCounterOption) (metrics.Float64UpDownCounter, error) { if mock.NewFloat64UpDownCounterFunc == nil { - panic("ProviderMock.NewFloat64UpDownCounterFunc: method is nil but Provider.NewFloat64UpDownCounter was just called") + panic("MetricsProviderMock.NewFloat64UpDownCounterFunc: method is nil but Provider.NewFloat64UpDownCounter was just called") } callInfo := struct { Name string @@ -324,7 +324,7 @@ func (mock *ProviderMock) NewFloat64UpDownCounter(name string, options ...metric // Check the length with: // // len(mockedProvider.NewFloat64UpDownCounterCalls()) -func (mock *ProviderMock) NewFloat64UpDownCounterCalls() []struct { +func (mock *MetricsProviderMock) NewFloat64UpDownCounterCalls() []struct { Name string Options []metric.Float64UpDownCounterOption } { @@ -339,9 +339,9 @@ func (mock *ProviderMock) NewFloat64UpDownCounterCalls() []struct { } // NewInt64Counter calls NewInt64CounterFunc. -func (mock *ProviderMock) NewInt64Counter(name string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) { +func (mock *MetricsProviderMock) NewInt64Counter(name string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) { if mock.NewInt64CounterFunc == nil { - panic("ProviderMock.NewInt64CounterFunc: method is nil but Provider.NewInt64Counter was just called") + panic("MetricsProviderMock.NewInt64CounterFunc: method is nil but Provider.NewInt64Counter was just called") } callInfo := struct { Name string @@ -360,7 +360,7 @@ func (mock *ProviderMock) NewInt64Counter(name string, options ...metric.Int64Co // Check the length with: // // len(mockedProvider.NewInt64CounterCalls()) -func (mock *ProviderMock) NewInt64CounterCalls() []struct { +func (mock *MetricsProviderMock) NewInt64CounterCalls() []struct { Name string Options []metric.Int64CounterOption } { @@ -375,9 +375,9 @@ func (mock *ProviderMock) NewInt64CounterCalls() []struct { } // NewInt64Gauge calls NewInt64GaugeFunc. -func (mock *ProviderMock) NewInt64Gauge(name string, options ...metric.Int64GaugeOption) (metrics.Int64Gauge, error) { +func (mock *MetricsProviderMock) NewInt64Gauge(name string, options ...metric.Int64GaugeOption) (metrics.Int64Gauge, error) { if mock.NewInt64GaugeFunc == nil { - panic("ProviderMock.NewInt64GaugeFunc: method is nil but Provider.NewInt64Gauge was just called") + panic("MetricsProviderMock.NewInt64GaugeFunc: method is nil but Provider.NewInt64Gauge was just called") } callInfo := struct { Name string @@ -396,7 +396,7 @@ func (mock *ProviderMock) NewInt64Gauge(name string, options ...metric.Int64Gaug // Check the length with: // // len(mockedProvider.NewInt64GaugeCalls()) -func (mock *ProviderMock) NewInt64GaugeCalls() []struct { +func (mock *MetricsProviderMock) NewInt64GaugeCalls() []struct { Name string Options []metric.Int64GaugeOption } { @@ -411,9 +411,9 @@ func (mock *ProviderMock) NewInt64GaugeCalls() []struct { } // NewInt64Histogram calls NewInt64HistogramFunc. -func (mock *ProviderMock) NewInt64Histogram(name string, options ...metric.Int64HistogramOption) (metrics.Int64Histogram, error) { +func (mock *MetricsProviderMock) NewInt64Histogram(name string, options ...metric.Int64HistogramOption) (metrics.Int64Histogram, error) { if mock.NewInt64HistogramFunc == nil { - panic("ProviderMock.NewInt64HistogramFunc: method is nil but Provider.NewInt64Histogram was just called") + panic("MetricsProviderMock.NewInt64HistogramFunc: method is nil but Provider.NewInt64Histogram was just called") } callInfo := struct { Name string @@ -432,7 +432,7 @@ func (mock *ProviderMock) NewInt64Histogram(name string, options ...metric.Int64 // Check the length with: // // len(mockedProvider.NewInt64HistogramCalls()) -func (mock *ProviderMock) NewInt64HistogramCalls() []struct { +func (mock *MetricsProviderMock) NewInt64HistogramCalls() []struct { Name string Options []metric.Int64HistogramOption } { @@ -447,9 +447,9 @@ func (mock *ProviderMock) NewInt64HistogramCalls() []struct { } // NewInt64UpDownCounter calls NewInt64UpDownCounterFunc. -func (mock *ProviderMock) NewInt64UpDownCounter(name string, options ...metric.Int64UpDownCounterOption) (metrics.Int64UpDownCounter, error) { +func (mock *MetricsProviderMock) NewInt64UpDownCounter(name string, options ...metric.Int64UpDownCounterOption) (metrics.Int64UpDownCounter, error) { if mock.NewInt64UpDownCounterFunc == nil { - panic("ProviderMock.NewInt64UpDownCounterFunc: method is nil but Provider.NewInt64UpDownCounter was just called") + panic("MetricsProviderMock.NewInt64UpDownCounterFunc: method is nil but Provider.NewInt64UpDownCounter was just called") } callInfo := struct { Name string @@ -468,7 +468,7 @@ func (mock *ProviderMock) NewInt64UpDownCounter(name string, options ...metric.I // Check the length with: // // len(mockedProvider.NewInt64UpDownCounterCalls()) -func (mock *ProviderMock) NewInt64UpDownCounterCalls() []struct { +func (mock *MetricsProviderMock) NewInt64UpDownCounterCalls() []struct { Name string Options []metric.Int64UpDownCounterOption } { @@ -483,9 +483,9 @@ func (mock *ProviderMock) NewInt64UpDownCounterCalls() []struct { } // Shutdown calls ShutdownFunc. -func (mock *ProviderMock) Shutdown(ctx context.Context) error { +func (mock *MetricsProviderMock) Shutdown(ctx context.Context) error { if mock.ShutdownFunc == nil { - panic("ProviderMock.ShutdownFunc: method is nil but Provider.Shutdown was just called") + panic("MetricsProviderMock.ShutdownFunc: method is nil but Provider.Shutdown was just called") } callInfo := struct { Ctx context.Context @@ -502,7 +502,7 @@ func (mock *ProviderMock) Shutdown(ctx context.Context) error { // Check the length with: // // len(mockedProvider.ShutdownCalls()) -func (mock *ProviderMock) ShutdownCalls() []struct { +func (mock *MetricsProviderMock) ShutdownCalls() []struct { Ctx context.Context } { var calls []struct { diff --git a/cache/redis/redisclient_mock_test.go b/cache/redis/mock/redisclient_mock.go similarity index 81% rename from cache/redis/redisclient_mock_test.go rename to cache/redis/mock/redisclient_mock.go index 005a6d8..4be8290 100644 --- a/cache/redis/redisclient_mock_test.go +++ b/cache/redis/mock/redisclient_mock.go @@ -1,7 +1,7 @@ // Code generated by moq; DO NOT EDIT. // github.com/matryer/moq -package redis +package mock import ( "context" @@ -11,16 +11,12 @@ import ( "github.com/go-redis/redis/v8" ) -// Ensure, that redisClientMock does implement redisClient. -// If this is not the case, regenerate this file with moq. -var _ redisClient = &redisClientMock{} - -// redisClientMock is a mock implementation of redisClient. +// RedisClientMock is a mock implementation of redis.redisClient. // // func TestSomethingThatUsesredisClient(t *testing.T) { // -// // make and configure a mocked redisClient -// mockedredisClient := &redisClientMock{ +// // make and configure a mocked redis.redisClient +// mockedredisClient := &RedisClientMock{ // DelFunc: func(ctx context.Context, keys ...string) *redis.IntCmd { // panic("mock out the Del method") // }, @@ -35,11 +31,11 @@ var _ redisClient = &redisClientMock{} // }, // } // -// // use mockedredisClient in code that requires redisClient +// // use mockedredisClient in code that requires redis.redisClient // // and then make assertions. // // } -type redisClientMock struct { +type RedisClientMock struct { // DelFunc mocks the Del method. DelFunc func(ctx context.Context, keys ...string) *redis.IntCmd @@ -92,9 +88,9 @@ type redisClientMock struct { } // Del calls DelFunc. -func (mock *redisClientMock) Del(ctx context.Context, keys ...string) *redis.IntCmd { +func (mock *RedisClientMock) Del(ctx context.Context, keys ...string) *redis.IntCmd { if mock.DelFunc == nil { - panic("redisClientMock.DelFunc: method is nil but redisClient.Del was just called") + panic("RedisClientMock.DelFunc: method is nil but redisClient.Del was just called") } callInfo := struct { Ctx context.Context @@ -113,7 +109,7 @@ func (mock *redisClientMock) Del(ctx context.Context, keys ...string) *redis.Int // Check the length with: // // len(mockedredisClient.DelCalls()) -func (mock *redisClientMock) DelCalls() []struct { +func (mock *RedisClientMock) DelCalls() []struct { Ctx context.Context Keys []string } { @@ -128,9 +124,9 @@ func (mock *redisClientMock) DelCalls() []struct { } // Get calls GetFunc. -func (mock *redisClientMock) Get(ctx context.Context, key string) *redis.StringCmd { +func (mock *RedisClientMock) Get(ctx context.Context, key string) *redis.StringCmd { if mock.GetFunc == nil { - panic("redisClientMock.GetFunc: method is nil but redisClient.Get was just called") + panic("RedisClientMock.GetFunc: method is nil but redisClient.Get was just called") } callInfo := struct { Ctx context.Context @@ -149,7 +145,7 @@ func (mock *redisClientMock) Get(ctx context.Context, key string) *redis.StringC // Check the length with: // // len(mockedredisClient.GetCalls()) -func (mock *redisClientMock) GetCalls() []struct { +func (mock *RedisClientMock) GetCalls() []struct { Ctx context.Context Key string } { @@ -164,9 +160,9 @@ func (mock *redisClientMock) GetCalls() []struct { } // Ping calls PingFunc. -func (mock *redisClientMock) Ping(ctx context.Context) *redis.StatusCmd { +func (mock *RedisClientMock) Ping(ctx context.Context) *redis.StatusCmd { if mock.PingFunc == nil { - panic("redisClientMock.PingFunc: method is nil but redisClient.Ping was just called") + panic("RedisClientMock.PingFunc: method is nil but redisClient.Ping was just called") } callInfo := struct { Ctx context.Context @@ -183,7 +179,7 @@ func (mock *redisClientMock) Ping(ctx context.Context) *redis.StatusCmd { // Check the length with: // // len(mockedredisClient.PingCalls()) -func (mock *redisClientMock) PingCalls() []struct { +func (mock *RedisClientMock) PingCalls() []struct { Ctx context.Context } { var calls []struct { @@ -196,9 +192,9 @@ func (mock *redisClientMock) PingCalls() []struct { } // Set calls SetFunc. -func (mock *redisClientMock) Set(ctx context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd { +func (mock *RedisClientMock) Set(ctx context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd { if mock.SetFunc == nil { - panic("redisClientMock.SetFunc: method is nil but redisClient.Set was just called") + panic("RedisClientMock.SetFunc: method is nil but redisClient.Set was just called") } callInfo := struct { Ctx context.Context @@ -221,7 +217,7 @@ func (mock *redisClientMock) Set(ctx context.Context, key string, value any, exp // Check the length with: // // len(mockedredisClient.SetCalls()) -func (mock *redisClientMock) SetCalls() []struct { +func (mock *RedisClientMock) SetCalls() []struct { Ctx context.Context Key string Value any diff --git a/cache/redis/mocks_gen_test.go b/cache/redis/mocks_gen_test.go deleted file mode 100644 index b0a8e9d..0000000 --- a/cache/redis/mocks_gen_test.go +++ /dev/null @@ -1,7 +0,0 @@ -package redis - -// Regenerate mocks via `go generate ./cache/redis/...`. - -//go:generate go tool github.com/matryer/moq -out redisclient_mock_test.go -pkg redis -rm -fmt goimports . redisClient -//go:generate go tool github.com/matryer/moq -out metricsprovider_mock_test.go -pkg redis -rm -fmt goimports ../../observability/metrics Provider -//go:generate go tool github.com/matryer/moq -out circuitbreaker_mock_test.go -pkg redis -rm -fmt goimports ../../circuitbreaking CircuitBreaker diff --git a/cache/redis/unit_test.go b/cache/redis/unit_test.go index e954292..5d2b9fd 100644 --- a/cache/redis/unit_test.go +++ b/cache/redis/unit_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/verygoodsoftwarenotvirus/platform/v5/cache" + "github.com/verygoodsoftwarenotvirus/platform/v5/cache/redis/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" @@ -28,7 +29,7 @@ func gobEncodeExample(t *testing.T, e *example) string { return buf.String() } -func buildTestImpl(t *testing.T) (*redisCacheImpl[example], *redisClientMock, *CircuitBreakerMock) { +func buildTestImpl(t *testing.T) (*redisCacheImpl[example], *mock.RedisClientMock, *mock.CircuitBreakerMock) { t.Helper() mp := metrics.NewNoopMetricsProvider() @@ -51,8 +52,8 @@ func buildTestImpl(t *testing.T) (*redisCacheImpl[example], *redisClientMock, *C latencyHist, err := mp.NewFloat64Histogram("test_latency") must.NoError(t, err) - client := &redisClientMock{} - cb := &CircuitBreakerMock{} + client := &mock.RedisClientMock{} + cb := &mock.CircuitBreakerMock{} return &redisCacheImpl[example]{ logger: logging.NewNoopLogger(), @@ -75,11 +76,12 @@ type counterResult struct { err error } -// newCounterProviderMock returns a ProviderMock whose NewInt64Counter implementation -// looks up the result keyed on the counter name. Unknown names fail the test. -func newCounterProviderMock(t *testing.T, results map[string]counterResult) *ProviderMock { +// newCounterProviderMock returns a MetricsProviderMock whose NewInt64Counter +// implementation looks up the result keyed on the counter name. Unknown names +// fail the test. +func newCounterProviderMock(t *testing.T, results map[string]counterResult) *mock.MetricsProviderMock { t.Helper() - return &ProviderMock{ + return &mock.MetricsProviderMock{ NewInt64CounterFunc: func(metricName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { res, ok := results[metricName] if !ok { @@ -247,7 +249,7 @@ func TestNewRedisCache(T *testing.T) { h, histErr := noopMP.NewFloat64Histogram("test") must.NoError(t, histErr) - mp := &ProviderMock{ + mp := &mock.MetricsProviderMock{ NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { return metrics.Int64CounterForTest(t, "x"), nil }, From e8d16647cf51506668a5329c4b0eab3837c22b82 Mon Sep 17 00:00:00 2001 From: Jeffrey D <1289344+verygoodsoftwarenotvirus@users.noreply.github.com> Date: Fri, 10 Apr 2026 17:49:14 -0500 Subject: [PATCH 03/12] Revert "test: minor refactor" This reverts commit e0ef5b4454be2e2df0616236b627380a9fab0291. --- cache/config/config_test.go | 3 +- .../metricsprovider_mock_test.go} | 72 +++++++++---------- cache/config/mock/doc.go | 7 -- cache/config/mocks_gen_test.go | 5 ++ ...er_mock.go => circuitbreaker_mock_test.go} | 2 +- .../metricsprovider_mock_test.go} | 72 +++++++++---------- cache/redis/mock/doc.go | 10 --- cache/redis/mocks_gen_test.go | 7 ++ ...lient_mock.go => redisclient_mock_test.go} | 40 ++++++----- cache/redis/unit_test.go | 18 +++-- 10 files changed, 116 insertions(+), 120 deletions(-) rename cache/{redis/mock/metricsprovider_mock.go => config/metricsprovider_mock_test.go} (81%) delete mode 100644 cache/config/mock/doc.go create mode 100644 cache/config/mocks_gen_test.go rename cache/redis/{mock/circuitbreaker_mock.go => circuitbreaker_mock_test.go} (99%) rename cache/{config/mock/metricsprovider_mock.go => redis/metricsprovider_mock_test.go} (81%) delete mode 100644 cache/redis/mock/doc.go create mode 100644 cache/redis/mocks_gen_test.go rename cache/redis/{mock/redisclient_mock.go => redisclient_mock_test.go} (81%) diff --git a/cache/config/config_test.go b/cache/config/config_test.go index be0192d..0c7668a 100644 --- a/cache/config/config_test.go +++ b/cache/config/config_test.go @@ -4,7 +4,6 @@ import ( "errors" "testing" - "github.com/verygoodsoftwarenotvirus/platform/v5/cache/config/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/cache/redis" circuitbreakingcfg "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking/config" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" @@ -129,7 +128,7 @@ func TestProvideCache(T *testing.T) { }, } - mp := &mock.MetricsProviderMock{ + mp := &ProviderMock{ NewInt64CounterFunc: func(name string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { test.EqOp(t, "redis-cache-breaker_circuit_breaker_tripped", name) return nil, errors.New("counter init failure") diff --git a/cache/redis/mock/metricsprovider_mock.go b/cache/config/metricsprovider_mock_test.go similarity index 81% rename from cache/redis/mock/metricsprovider_mock.go rename to cache/config/metricsprovider_mock_test.go index b739ea9..c225860 100644 --- a/cache/redis/mock/metricsprovider_mock.go +++ b/cache/config/metricsprovider_mock_test.go @@ -1,7 +1,7 @@ // Code generated by moq; DO NOT EDIT. // github.com/matryer/moq -package mock +package config import ( "context" @@ -11,16 +11,16 @@ import ( "go.opentelemetry.io/otel/metric" ) -// Ensure, that MetricsProviderMock does implement metrics.Provider. +// Ensure, that ProviderMock does implement metrics.Provider. // If this is not the case, regenerate this file with moq. -var _ metrics.Provider = &MetricsProviderMock{} +var _ metrics.Provider = &ProviderMock{} -// MetricsProviderMock is a mock implementation of metrics.Provider. +// ProviderMock is a mock implementation of metrics.Provider. // // func TestSomethingThatUsesProvider(t *testing.T) { // // // make and configure a mocked metrics.Provider -// mockedProvider := &MetricsProviderMock{ +// mockedProvider := &ProviderMock{ // MeterProviderFunc: func() metric.MeterProvider { // panic("mock out the MeterProvider method") // }, @@ -57,7 +57,7 @@ var _ metrics.Provider = &MetricsProviderMock{} // // and then make assertions. // // } -type MetricsProviderMock struct { +type ProviderMock struct { // MeterProviderFunc mocks the MeterProvider method. MeterProviderFunc func() metric.MeterProvider @@ -168,9 +168,9 @@ type MetricsProviderMock struct { } // MeterProvider calls MeterProviderFunc. -func (mock *MetricsProviderMock) MeterProvider() metric.MeterProvider { +func (mock *ProviderMock) MeterProvider() metric.MeterProvider { if mock.MeterProviderFunc == nil { - panic("MetricsProviderMock.MeterProviderFunc: method is nil but Provider.MeterProvider was just called") + panic("ProviderMock.MeterProviderFunc: method is nil but Provider.MeterProvider was just called") } callInfo := struct { }{} @@ -184,7 +184,7 @@ func (mock *MetricsProviderMock) MeterProvider() metric.MeterProvider { // Check the length with: // // len(mockedProvider.MeterProviderCalls()) -func (mock *MetricsProviderMock) MeterProviderCalls() []struct { +func (mock *ProviderMock) MeterProviderCalls() []struct { } { var calls []struct { } @@ -195,9 +195,9 @@ func (mock *MetricsProviderMock) MeterProviderCalls() []struct { } // NewFloat64Counter calls NewFloat64CounterFunc. -func (mock *MetricsProviderMock) NewFloat64Counter(name string, options ...metric.Float64CounterOption) (metrics.Float64Counter, error) { +func (mock *ProviderMock) NewFloat64Counter(name string, options ...metric.Float64CounterOption) (metrics.Float64Counter, error) { if mock.NewFloat64CounterFunc == nil { - panic("MetricsProviderMock.NewFloat64CounterFunc: method is nil but Provider.NewFloat64Counter was just called") + panic("ProviderMock.NewFloat64CounterFunc: method is nil but Provider.NewFloat64Counter was just called") } callInfo := struct { Name string @@ -216,7 +216,7 @@ func (mock *MetricsProviderMock) NewFloat64Counter(name string, options ...metri // Check the length with: // // len(mockedProvider.NewFloat64CounterCalls()) -func (mock *MetricsProviderMock) NewFloat64CounterCalls() []struct { +func (mock *ProviderMock) NewFloat64CounterCalls() []struct { Name string Options []metric.Float64CounterOption } { @@ -231,9 +231,9 @@ func (mock *MetricsProviderMock) NewFloat64CounterCalls() []struct { } // NewFloat64Gauge calls NewFloat64GaugeFunc. -func (mock *MetricsProviderMock) NewFloat64Gauge(name string, options ...metric.Float64GaugeOption) (metrics.Float64Gauge, error) { +func (mock *ProviderMock) NewFloat64Gauge(name string, options ...metric.Float64GaugeOption) (metrics.Float64Gauge, error) { if mock.NewFloat64GaugeFunc == nil { - panic("MetricsProviderMock.NewFloat64GaugeFunc: method is nil but Provider.NewFloat64Gauge was just called") + panic("ProviderMock.NewFloat64GaugeFunc: method is nil but Provider.NewFloat64Gauge was just called") } callInfo := struct { Name string @@ -252,7 +252,7 @@ func (mock *MetricsProviderMock) NewFloat64Gauge(name string, options ...metric. // Check the length with: // // len(mockedProvider.NewFloat64GaugeCalls()) -func (mock *MetricsProviderMock) NewFloat64GaugeCalls() []struct { +func (mock *ProviderMock) NewFloat64GaugeCalls() []struct { Name string Options []metric.Float64GaugeOption } { @@ -267,9 +267,9 @@ func (mock *MetricsProviderMock) NewFloat64GaugeCalls() []struct { } // NewFloat64Histogram calls NewFloat64HistogramFunc. -func (mock *MetricsProviderMock) NewFloat64Histogram(name string, options ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { +func (mock *ProviderMock) NewFloat64Histogram(name string, options ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { if mock.NewFloat64HistogramFunc == nil { - panic("MetricsProviderMock.NewFloat64HistogramFunc: method is nil but Provider.NewFloat64Histogram was just called") + panic("ProviderMock.NewFloat64HistogramFunc: method is nil but Provider.NewFloat64Histogram was just called") } callInfo := struct { Name string @@ -288,7 +288,7 @@ func (mock *MetricsProviderMock) NewFloat64Histogram(name string, options ...met // Check the length with: // // len(mockedProvider.NewFloat64HistogramCalls()) -func (mock *MetricsProviderMock) NewFloat64HistogramCalls() []struct { +func (mock *ProviderMock) NewFloat64HistogramCalls() []struct { Name string Options []metric.Float64HistogramOption } { @@ -303,9 +303,9 @@ func (mock *MetricsProviderMock) NewFloat64HistogramCalls() []struct { } // NewFloat64UpDownCounter calls NewFloat64UpDownCounterFunc. -func (mock *MetricsProviderMock) NewFloat64UpDownCounter(name string, options ...metric.Float64UpDownCounterOption) (metrics.Float64UpDownCounter, error) { +func (mock *ProviderMock) NewFloat64UpDownCounter(name string, options ...metric.Float64UpDownCounterOption) (metrics.Float64UpDownCounter, error) { if mock.NewFloat64UpDownCounterFunc == nil { - panic("MetricsProviderMock.NewFloat64UpDownCounterFunc: method is nil but Provider.NewFloat64UpDownCounter was just called") + panic("ProviderMock.NewFloat64UpDownCounterFunc: method is nil but Provider.NewFloat64UpDownCounter was just called") } callInfo := struct { Name string @@ -324,7 +324,7 @@ func (mock *MetricsProviderMock) NewFloat64UpDownCounter(name string, options .. // Check the length with: // // len(mockedProvider.NewFloat64UpDownCounterCalls()) -func (mock *MetricsProviderMock) NewFloat64UpDownCounterCalls() []struct { +func (mock *ProviderMock) NewFloat64UpDownCounterCalls() []struct { Name string Options []metric.Float64UpDownCounterOption } { @@ -339,9 +339,9 @@ func (mock *MetricsProviderMock) NewFloat64UpDownCounterCalls() []struct { } // NewInt64Counter calls NewInt64CounterFunc. -func (mock *MetricsProviderMock) NewInt64Counter(name string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) { +func (mock *ProviderMock) NewInt64Counter(name string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) { if mock.NewInt64CounterFunc == nil { - panic("MetricsProviderMock.NewInt64CounterFunc: method is nil but Provider.NewInt64Counter was just called") + panic("ProviderMock.NewInt64CounterFunc: method is nil but Provider.NewInt64Counter was just called") } callInfo := struct { Name string @@ -360,7 +360,7 @@ func (mock *MetricsProviderMock) NewInt64Counter(name string, options ...metric. // Check the length with: // // len(mockedProvider.NewInt64CounterCalls()) -func (mock *MetricsProviderMock) NewInt64CounterCalls() []struct { +func (mock *ProviderMock) NewInt64CounterCalls() []struct { Name string Options []metric.Int64CounterOption } { @@ -375,9 +375,9 @@ func (mock *MetricsProviderMock) NewInt64CounterCalls() []struct { } // NewInt64Gauge calls NewInt64GaugeFunc. -func (mock *MetricsProviderMock) NewInt64Gauge(name string, options ...metric.Int64GaugeOption) (metrics.Int64Gauge, error) { +func (mock *ProviderMock) NewInt64Gauge(name string, options ...metric.Int64GaugeOption) (metrics.Int64Gauge, error) { if mock.NewInt64GaugeFunc == nil { - panic("MetricsProviderMock.NewInt64GaugeFunc: method is nil but Provider.NewInt64Gauge was just called") + panic("ProviderMock.NewInt64GaugeFunc: method is nil but Provider.NewInt64Gauge was just called") } callInfo := struct { Name string @@ -396,7 +396,7 @@ func (mock *MetricsProviderMock) NewInt64Gauge(name string, options ...metric.In // Check the length with: // // len(mockedProvider.NewInt64GaugeCalls()) -func (mock *MetricsProviderMock) NewInt64GaugeCalls() []struct { +func (mock *ProviderMock) NewInt64GaugeCalls() []struct { Name string Options []metric.Int64GaugeOption } { @@ -411,9 +411,9 @@ func (mock *MetricsProviderMock) NewInt64GaugeCalls() []struct { } // NewInt64Histogram calls NewInt64HistogramFunc. -func (mock *MetricsProviderMock) NewInt64Histogram(name string, options ...metric.Int64HistogramOption) (metrics.Int64Histogram, error) { +func (mock *ProviderMock) NewInt64Histogram(name string, options ...metric.Int64HistogramOption) (metrics.Int64Histogram, error) { if mock.NewInt64HistogramFunc == nil { - panic("MetricsProviderMock.NewInt64HistogramFunc: method is nil but Provider.NewInt64Histogram was just called") + panic("ProviderMock.NewInt64HistogramFunc: method is nil but Provider.NewInt64Histogram was just called") } callInfo := struct { Name string @@ -432,7 +432,7 @@ func (mock *MetricsProviderMock) NewInt64Histogram(name string, options ...metri // Check the length with: // // len(mockedProvider.NewInt64HistogramCalls()) -func (mock *MetricsProviderMock) NewInt64HistogramCalls() []struct { +func (mock *ProviderMock) NewInt64HistogramCalls() []struct { Name string Options []metric.Int64HistogramOption } { @@ -447,9 +447,9 @@ func (mock *MetricsProviderMock) NewInt64HistogramCalls() []struct { } // NewInt64UpDownCounter calls NewInt64UpDownCounterFunc. -func (mock *MetricsProviderMock) NewInt64UpDownCounter(name string, options ...metric.Int64UpDownCounterOption) (metrics.Int64UpDownCounter, error) { +func (mock *ProviderMock) NewInt64UpDownCounter(name string, options ...metric.Int64UpDownCounterOption) (metrics.Int64UpDownCounter, error) { if mock.NewInt64UpDownCounterFunc == nil { - panic("MetricsProviderMock.NewInt64UpDownCounterFunc: method is nil but Provider.NewInt64UpDownCounter was just called") + panic("ProviderMock.NewInt64UpDownCounterFunc: method is nil but Provider.NewInt64UpDownCounter was just called") } callInfo := struct { Name string @@ -468,7 +468,7 @@ func (mock *MetricsProviderMock) NewInt64UpDownCounter(name string, options ...m // Check the length with: // // len(mockedProvider.NewInt64UpDownCounterCalls()) -func (mock *MetricsProviderMock) NewInt64UpDownCounterCalls() []struct { +func (mock *ProviderMock) NewInt64UpDownCounterCalls() []struct { Name string Options []metric.Int64UpDownCounterOption } { @@ -483,9 +483,9 @@ func (mock *MetricsProviderMock) NewInt64UpDownCounterCalls() []struct { } // Shutdown calls ShutdownFunc. -func (mock *MetricsProviderMock) Shutdown(ctx context.Context) error { +func (mock *ProviderMock) Shutdown(ctx context.Context) error { if mock.ShutdownFunc == nil { - panic("MetricsProviderMock.ShutdownFunc: method is nil but Provider.Shutdown was just called") + panic("ProviderMock.ShutdownFunc: method is nil but Provider.Shutdown was just called") } callInfo := struct { Ctx context.Context @@ -502,7 +502,7 @@ func (mock *MetricsProviderMock) Shutdown(ctx context.Context) error { // Check the length with: // // len(mockedProvider.ShutdownCalls()) -func (mock *MetricsProviderMock) ShutdownCalls() []struct { +func (mock *ProviderMock) ShutdownCalls() []struct { Ctx context.Context } { var calls []struct { diff --git a/cache/config/mock/doc.go b/cache/config/mock/doc.go deleted file mode 100644 index 3dd676f..0000000 --- a/cache/config/mock/doc.go +++ /dev/null @@ -1,7 +0,0 @@ -// Package mock provides moq-generated mock implementations of the interfaces -// that cache/config depends on for unit testing. -package mock - -// Regenerate via `go generate ./cache/config/mock/`. - -//go:generate go tool github.com/matryer/moq -out metricsprovider_mock.go -pkg mock -rm -fmt goimports ../../../observability/metrics Provider:MetricsProviderMock diff --git a/cache/config/mocks_gen_test.go b/cache/config/mocks_gen_test.go new file mode 100644 index 0000000..7ad25d3 --- /dev/null +++ b/cache/config/mocks_gen_test.go @@ -0,0 +1,5 @@ +package config + +// Regenerate mocks via `go generate ./cache/config/...`. + +//go:generate go tool github.com/matryer/moq -out metricsprovider_mock_test.go -pkg config -rm -fmt goimports ../../observability/metrics Provider diff --git a/cache/redis/mock/circuitbreaker_mock.go b/cache/redis/circuitbreaker_mock_test.go similarity index 99% rename from cache/redis/mock/circuitbreaker_mock.go rename to cache/redis/circuitbreaker_mock_test.go index eeffec8..8b474ee 100644 --- a/cache/redis/mock/circuitbreaker_mock.go +++ b/cache/redis/circuitbreaker_mock_test.go @@ -1,7 +1,7 @@ // Code generated by moq; DO NOT EDIT. // github.com/matryer/moq -package mock +package redis import ( "sync" diff --git a/cache/config/mock/metricsprovider_mock.go b/cache/redis/metricsprovider_mock_test.go similarity index 81% rename from cache/config/mock/metricsprovider_mock.go rename to cache/redis/metricsprovider_mock_test.go index b739ea9..247d6e8 100644 --- a/cache/config/mock/metricsprovider_mock.go +++ b/cache/redis/metricsprovider_mock_test.go @@ -1,7 +1,7 @@ // Code generated by moq; DO NOT EDIT. // github.com/matryer/moq -package mock +package redis import ( "context" @@ -11,16 +11,16 @@ import ( "go.opentelemetry.io/otel/metric" ) -// Ensure, that MetricsProviderMock does implement metrics.Provider. +// Ensure, that ProviderMock does implement metrics.Provider. // If this is not the case, regenerate this file with moq. -var _ metrics.Provider = &MetricsProviderMock{} +var _ metrics.Provider = &ProviderMock{} -// MetricsProviderMock is a mock implementation of metrics.Provider. +// ProviderMock is a mock implementation of metrics.Provider. // // func TestSomethingThatUsesProvider(t *testing.T) { // // // make and configure a mocked metrics.Provider -// mockedProvider := &MetricsProviderMock{ +// mockedProvider := &ProviderMock{ // MeterProviderFunc: func() metric.MeterProvider { // panic("mock out the MeterProvider method") // }, @@ -57,7 +57,7 @@ var _ metrics.Provider = &MetricsProviderMock{} // // and then make assertions. // // } -type MetricsProviderMock struct { +type ProviderMock struct { // MeterProviderFunc mocks the MeterProvider method. MeterProviderFunc func() metric.MeterProvider @@ -168,9 +168,9 @@ type MetricsProviderMock struct { } // MeterProvider calls MeterProviderFunc. -func (mock *MetricsProviderMock) MeterProvider() metric.MeterProvider { +func (mock *ProviderMock) MeterProvider() metric.MeterProvider { if mock.MeterProviderFunc == nil { - panic("MetricsProviderMock.MeterProviderFunc: method is nil but Provider.MeterProvider was just called") + panic("ProviderMock.MeterProviderFunc: method is nil but Provider.MeterProvider was just called") } callInfo := struct { }{} @@ -184,7 +184,7 @@ func (mock *MetricsProviderMock) MeterProvider() metric.MeterProvider { // Check the length with: // // len(mockedProvider.MeterProviderCalls()) -func (mock *MetricsProviderMock) MeterProviderCalls() []struct { +func (mock *ProviderMock) MeterProviderCalls() []struct { } { var calls []struct { } @@ -195,9 +195,9 @@ func (mock *MetricsProviderMock) MeterProviderCalls() []struct { } // NewFloat64Counter calls NewFloat64CounterFunc. -func (mock *MetricsProviderMock) NewFloat64Counter(name string, options ...metric.Float64CounterOption) (metrics.Float64Counter, error) { +func (mock *ProviderMock) NewFloat64Counter(name string, options ...metric.Float64CounterOption) (metrics.Float64Counter, error) { if mock.NewFloat64CounterFunc == nil { - panic("MetricsProviderMock.NewFloat64CounterFunc: method is nil but Provider.NewFloat64Counter was just called") + panic("ProviderMock.NewFloat64CounterFunc: method is nil but Provider.NewFloat64Counter was just called") } callInfo := struct { Name string @@ -216,7 +216,7 @@ func (mock *MetricsProviderMock) NewFloat64Counter(name string, options ...metri // Check the length with: // // len(mockedProvider.NewFloat64CounterCalls()) -func (mock *MetricsProviderMock) NewFloat64CounterCalls() []struct { +func (mock *ProviderMock) NewFloat64CounterCalls() []struct { Name string Options []metric.Float64CounterOption } { @@ -231,9 +231,9 @@ func (mock *MetricsProviderMock) NewFloat64CounterCalls() []struct { } // NewFloat64Gauge calls NewFloat64GaugeFunc. -func (mock *MetricsProviderMock) NewFloat64Gauge(name string, options ...metric.Float64GaugeOption) (metrics.Float64Gauge, error) { +func (mock *ProviderMock) NewFloat64Gauge(name string, options ...metric.Float64GaugeOption) (metrics.Float64Gauge, error) { if mock.NewFloat64GaugeFunc == nil { - panic("MetricsProviderMock.NewFloat64GaugeFunc: method is nil but Provider.NewFloat64Gauge was just called") + panic("ProviderMock.NewFloat64GaugeFunc: method is nil but Provider.NewFloat64Gauge was just called") } callInfo := struct { Name string @@ -252,7 +252,7 @@ func (mock *MetricsProviderMock) NewFloat64Gauge(name string, options ...metric. // Check the length with: // // len(mockedProvider.NewFloat64GaugeCalls()) -func (mock *MetricsProviderMock) NewFloat64GaugeCalls() []struct { +func (mock *ProviderMock) NewFloat64GaugeCalls() []struct { Name string Options []metric.Float64GaugeOption } { @@ -267,9 +267,9 @@ func (mock *MetricsProviderMock) NewFloat64GaugeCalls() []struct { } // NewFloat64Histogram calls NewFloat64HistogramFunc. -func (mock *MetricsProviderMock) NewFloat64Histogram(name string, options ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { +func (mock *ProviderMock) NewFloat64Histogram(name string, options ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { if mock.NewFloat64HistogramFunc == nil { - panic("MetricsProviderMock.NewFloat64HistogramFunc: method is nil but Provider.NewFloat64Histogram was just called") + panic("ProviderMock.NewFloat64HistogramFunc: method is nil but Provider.NewFloat64Histogram was just called") } callInfo := struct { Name string @@ -288,7 +288,7 @@ func (mock *MetricsProviderMock) NewFloat64Histogram(name string, options ...met // Check the length with: // // len(mockedProvider.NewFloat64HistogramCalls()) -func (mock *MetricsProviderMock) NewFloat64HistogramCalls() []struct { +func (mock *ProviderMock) NewFloat64HistogramCalls() []struct { Name string Options []metric.Float64HistogramOption } { @@ -303,9 +303,9 @@ func (mock *MetricsProviderMock) NewFloat64HistogramCalls() []struct { } // NewFloat64UpDownCounter calls NewFloat64UpDownCounterFunc. -func (mock *MetricsProviderMock) NewFloat64UpDownCounter(name string, options ...metric.Float64UpDownCounterOption) (metrics.Float64UpDownCounter, error) { +func (mock *ProviderMock) NewFloat64UpDownCounter(name string, options ...metric.Float64UpDownCounterOption) (metrics.Float64UpDownCounter, error) { if mock.NewFloat64UpDownCounterFunc == nil { - panic("MetricsProviderMock.NewFloat64UpDownCounterFunc: method is nil but Provider.NewFloat64UpDownCounter was just called") + panic("ProviderMock.NewFloat64UpDownCounterFunc: method is nil but Provider.NewFloat64UpDownCounter was just called") } callInfo := struct { Name string @@ -324,7 +324,7 @@ func (mock *MetricsProviderMock) NewFloat64UpDownCounter(name string, options .. // Check the length with: // // len(mockedProvider.NewFloat64UpDownCounterCalls()) -func (mock *MetricsProviderMock) NewFloat64UpDownCounterCalls() []struct { +func (mock *ProviderMock) NewFloat64UpDownCounterCalls() []struct { Name string Options []metric.Float64UpDownCounterOption } { @@ -339,9 +339,9 @@ func (mock *MetricsProviderMock) NewFloat64UpDownCounterCalls() []struct { } // NewInt64Counter calls NewInt64CounterFunc. -func (mock *MetricsProviderMock) NewInt64Counter(name string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) { +func (mock *ProviderMock) NewInt64Counter(name string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) { if mock.NewInt64CounterFunc == nil { - panic("MetricsProviderMock.NewInt64CounterFunc: method is nil but Provider.NewInt64Counter was just called") + panic("ProviderMock.NewInt64CounterFunc: method is nil but Provider.NewInt64Counter was just called") } callInfo := struct { Name string @@ -360,7 +360,7 @@ func (mock *MetricsProviderMock) NewInt64Counter(name string, options ...metric. // Check the length with: // // len(mockedProvider.NewInt64CounterCalls()) -func (mock *MetricsProviderMock) NewInt64CounterCalls() []struct { +func (mock *ProviderMock) NewInt64CounterCalls() []struct { Name string Options []metric.Int64CounterOption } { @@ -375,9 +375,9 @@ func (mock *MetricsProviderMock) NewInt64CounterCalls() []struct { } // NewInt64Gauge calls NewInt64GaugeFunc. -func (mock *MetricsProviderMock) NewInt64Gauge(name string, options ...metric.Int64GaugeOption) (metrics.Int64Gauge, error) { +func (mock *ProviderMock) NewInt64Gauge(name string, options ...metric.Int64GaugeOption) (metrics.Int64Gauge, error) { if mock.NewInt64GaugeFunc == nil { - panic("MetricsProviderMock.NewInt64GaugeFunc: method is nil but Provider.NewInt64Gauge was just called") + panic("ProviderMock.NewInt64GaugeFunc: method is nil but Provider.NewInt64Gauge was just called") } callInfo := struct { Name string @@ -396,7 +396,7 @@ func (mock *MetricsProviderMock) NewInt64Gauge(name string, options ...metric.In // Check the length with: // // len(mockedProvider.NewInt64GaugeCalls()) -func (mock *MetricsProviderMock) NewInt64GaugeCalls() []struct { +func (mock *ProviderMock) NewInt64GaugeCalls() []struct { Name string Options []metric.Int64GaugeOption } { @@ -411,9 +411,9 @@ func (mock *MetricsProviderMock) NewInt64GaugeCalls() []struct { } // NewInt64Histogram calls NewInt64HistogramFunc. -func (mock *MetricsProviderMock) NewInt64Histogram(name string, options ...metric.Int64HistogramOption) (metrics.Int64Histogram, error) { +func (mock *ProviderMock) NewInt64Histogram(name string, options ...metric.Int64HistogramOption) (metrics.Int64Histogram, error) { if mock.NewInt64HistogramFunc == nil { - panic("MetricsProviderMock.NewInt64HistogramFunc: method is nil but Provider.NewInt64Histogram was just called") + panic("ProviderMock.NewInt64HistogramFunc: method is nil but Provider.NewInt64Histogram was just called") } callInfo := struct { Name string @@ -432,7 +432,7 @@ func (mock *MetricsProviderMock) NewInt64Histogram(name string, options ...metri // Check the length with: // // len(mockedProvider.NewInt64HistogramCalls()) -func (mock *MetricsProviderMock) NewInt64HistogramCalls() []struct { +func (mock *ProviderMock) NewInt64HistogramCalls() []struct { Name string Options []metric.Int64HistogramOption } { @@ -447,9 +447,9 @@ func (mock *MetricsProviderMock) NewInt64HistogramCalls() []struct { } // NewInt64UpDownCounter calls NewInt64UpDownCounterFunc. -func (mock *MetricsProviderMock) NewInt64UpDownCounter(name string, options ...metric.Int64UpDownCounterOption) (metrics.Int64UpDownCounter, error) { +func (mock *ProviderMock) NewInt64UpDownCounter(name string, options ...metric.Int64UpDownCounterOption) (metrics.Int64UpDownCounter, error) { if mock.NewInt64UpDownCounterFunc == nil { - panic("MetricsProviderMock.NewInt64UpDownCounterFunc: method is nil but Provider.NewInt64UpDownCounter was just called") + panic("ProviderMock.NewInt64UpDownCounterFunc: method is nil but Provider.NewInt64UpDownCounter was just called") } callInfo := struct { Name string @@ -468,7 +468,7 @@ func (mock *MetricsProviderMock) NewInt64UpDownCounter(name string, options ...m // Check the length with: // // len(mockedProvider.NewInt64UpDownCounterCalls()) -func (mock *MetricsProviderMock) NewInt64UpDownCounterCalls() []struct { +func (mock *ProviderMock) NewInt64UpDownCounterCalls() []struct { Name string Options []metric.Int64UpDownCounterOption } { @@ -483,9 +483,9 @@ func (mock *MetricsProviderMock) NewInt64UpDownCounterCalls() []struct { } // Shutdown calls ShutdownFunc. -func (mock *MetricsProviderMock) Shutdown(ctx context.Context) error { +func (mock *ProviderMock) Shutdown(ctx context.Context) error { if mock.ShutdownFunc == nil { - panic("MetricsProviderMock.ShutdownFunc: method is nil but Provider.Shutdown was just called") + panic("ProviderMock.ShutdownFunc: method is nil but Provider.Shutdown was just called") } callInfo := struct { Ctx context.Context @@ -502,7 +502,7 @@ func (mock *MetricsProviderMock) Shutdown(ctx context.Context) error { // Check the length with: // // len(mockedProvider.ShutdownCalls()) -func (mock *MetricsProviderMock) ShutdownCalls() []struct { +func (mock *ProviderMock) ShutdownCalls() []struct { Ctx context.Context } { var calls []struct { diff --git a/cache/redis/mock/doc.go b/cache/redis/mock/doc.go deleted file mode 100644 index c4ebd1b..0000000 --- a/cache/redis/mock/doc.go +++ /dev/null @@ -1,10 +0,0 @@ -// Package mock provides moq-generated mock implementations of the interfaces -// that cache/redis depends on for unit testing: the internal redisClient test -// seam plus metrics.Provider and circuitbreaking.CircuitBreaker. -package mock - -// Regenerate via `go generate ./cache/redis/mock/`. - -//go:generate go tool github.com/matryer/moq -out redisclient_mock.go -pkg mock -rm -skip-ensure -fmt goimports .. redisClient:RedisClientMock -//go:generate go tool github.com/matryer/moq -out metricsprovider_mock.go -pkg mock -rm -fmt goimports ../../../observability/metrics Provider:MetricsProviderMock -//go:generate go tool github.com/matryer/moq -out circuitbreaker_mock.go -pkg mock -rm -fmt goimports ../../../circuitbreaking CircuitBreaker:CircuitBreakerMock diff --git a/cache/redis/mocks_gen_test.go b/cache/redis/mocks_gen_test.go new file mode 100644 index 0000000..b0a8e9d --- /dev/null +++ b/cache/redis/mocks_gen_test.go @@ -0,0 +1,7 @@ +package redis + +// Regenerate mocks via `go generate ./cache/redis/...`. + +//go:generate go tool github.com/matryer/moq -out redisclient_mock_test.go -pkg redis -rm -fmt goimports . redisClient +//go:generate go tool github.com/matryer/moq -out metricsprovider_mock_test.go -pkg redis -rm -fmt goimports ../../observability/metrics Provider +//go:generate go tool github.com/matryer/moq -out circuitbreaker_mock_test.go -pkg redis -rm -fmt goimports ../../circuitbreaking CircuitBreaker diff --git a/cache/redis/mock/redisclient_mock.go b/cache/redis/redisclient_mock_test.go similarity index 81% rename from cache/redis/mock/redisclient_mock.go rename to cache/redis/redisclient_mock_test.go index 4be8290..005a6d8 100644 --- a/cache/redis/mock/redisclient_mock.go +++ b/cache/redis/redisclient_mock_test.go @@ -1,7 +1,7 @@ // Code generated by moq; DO NOT EDIT. // github.com/matryer/moq -package mock +package redis import ( "context" @@ -11,12 +11,16 @@ import ( "github.com/go-redis/redis/v8" ) -// RedisClientMock is a mock implementation of redis.redisClient. +// Ensure, that redisClientMock does implement redisClient. +// If this is not the case, regenerate this file with moq. +var _ redisClient = &redisClientMock{} + +// redisClientMock is a mock implementation of redisClient. // // func TestSomethingThatUsesredisClient(t *testing.T) { // -// // make and configure a mocked redis.redisClient -// mockedredisClient := &RedisClientMock{ +// // make and configure a mocked redisClient +// mockedredisClient := &redisClientMock{ // DelFunc: func(ctx context.Context, keys ...string) *redis.IntCmd { // panic("mock out the Del method") // }, @@ -31,11 +35,11 @@ import ( // }, // } // -// // use mockedredisClient in code that requires redis.redisClient +// // use mockedredisClient in code that requires redisClient // // and then make assertions. // // } -type RedisClientMock struct { +type redisClientMock struct { // DelFunc mocks the Del method. DelFunc func(ctx context.Context, keys ...string) *redis.IntCmd @@ -88,9 +92,9 @@ type RedisClientMock struct { } // Del calls DelFunc. -func (mock *RedisClientMock) Del(ctx context.Context, keys ...string) *redis.IntCmd { +func (mock *redisClientMock) Del(ctx context.Context, keys ...string) *redis.IntCmd { if mock.DelFunc == nil { - panic("RedisClientMock.DelFunc: method is nil but redisClient.Del was just called") + panic("redisClientMock.DelFunc: method is nil but redisClient.Del was just called") } callInfo := struct { Ctx context.Context @@ -109,7 +113,7 @@ func (mock *RedisClientMock) Del(ctx context.Context, keys ...string) *redis.Int // Check the length with: // // len(mockedredisClient.DelCalls()) -func (mock *RedisClientMock) DelCalls() []struct { +func (mock *redisClientMock) DelCalls() []struct { Ctx context.Context Keys []string } { @@ -124,9 +128,9 @@ func (mock *RedisClientMock) DelCalls() []struct { } // Get calls GetFunc. -func (mock *RedisClientMock) Get(ctx context.Context, key string) *redis.StringCmd { +func (mock *redisClientMock) Get(ctx context.Context, key string) *redis.StringCmd { if mock.GetFunc == nil { - panic("RedisClientMock.GetFunc: method is nil but redisClient.Get was just called") + panic("redisClientMock.GetFunc: method is nil but redisClient.Get was just called") } callInfo := struct { Ctx context.Context @@ -145,7 +149,7 @@ func (mock *RedisClientMock) Get(ctx context.Context, key string) *redis.StringC // Check the length with: // // len(mockedredisClient.GetCalls()) -func (mock *RedisClientMock) GetCalls() []struct { +func (mock *redisClientMock) GetCalls() []struct { Ctx context.Context Key string } { @@ -160,9 +164,9 @@ func (mock *RedisClientMock) GetCalls() []struct { } // Ping calls PingFunc. -func (mock *RedisClientMock) Ping(ctx context.Context) *redis.StatusCmd { +func (mock *redisClientMock) Ping(ctx context.Context) *redis.StatusCmd { if mock.PingFunc == nil { - panic("RedisClientMock.PingFunc: method is nil but redisClient.Ping was just called") + panic("redisClientMock.PingFunc: method is nil but redisClient.Ping was just called") } callInfo := struct { Ctx context.Context @@ -179,7 +183,7 @@ func (mock *RedisClientMock) Ping(ctx context.Context) *redis.StatusCmd { // Check the length with: // // len(mockedredisClient.PingCalls()) -func (mock *RedisClientMock) PingCalls() []struct { +func (mock *redisClientMock) PingCalls() []struct { Ctx context.Context } { var calls []struct { @@ -192,9 +196,9 @@ func (mock *RedisClientMock) PingCalls() []struct { } // Set calls SetFunc. -func (mock *RedisClientMock) Set(ctx context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd { +func (mock *redisClientMock) Set(ctx context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd { if mock.SetFunc == nil { - panic("RedisClientMock.SetFunc: method is nil but redisClient.Set was just called") + panic("redisClientMock.SetFunc: method is nil but redisClient.Set was just called") } callInfo := struct { Ctx context.Context @@ -217,7 +221,7 @@ func (mock *RedisClientMock) Set(ctx context.Context, key string, value any, exp // Check the length with: // // len(mockedredisClient.SetCalls()) -func (mock *RedisClientMock) SetCalls() []struct { +func (mock *redisClientMock) SetCalls() []struct { Ctx context.Context Key string Value any diff --git a/cache/redis/unit_test.go b/cache/redis/unit_test.go index 5d2b9fd..e954292 100644 --- a/cache/redis/unit_test.go +++ b/cache/redis/unit_test.go @@ -9,7 +9,6 @@ import ( "time" "github.com/verygoodsoftwarenotvirus/platform/v5/cache" - "github.com/verygoodsoftwarenotvirus/platform/v5/cache/redis/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" @@ -29,7 +28,7 @@ func gobEncodeExample(t *testing.T, e *example) string { return buf.String() } -func buildTestImpl(t *testing.T) (*redisCacheImpl[example], *mock.RedisClientMock, *mock.CircuitBreakerMock) { +func buildTestImpl(t *testing.T) (*redisCacheImpl[example], *redisClientMock, *CircuitBreakerMock) { t.Helper() mp := metrics.NewNoopMetricsProvider() @@ -52,8 +51,8 @@ func buildTestImpl(t *testing.T) (*redisCacheImpl[example], *mock.RedisClientMoc latencyHist, err := mp.NewFloat64Histogram("test_latency") must.NoError(t, err) - client := &mock.RedisClientMock{} - cb := &mock.CircuitBreakerMock{} + client := &redisClientMock{} + cb := &CircuitBreakerMock{} return &redisCacheImpl[example]{ logger: logging.NewNoopLogger(), @@ -76,12 +75,11 @@ type counterResult struct { err error } -// newCounterProviderMock returns a MetricsProviderMock whose NewInt64Counter -// implementation looks up the result keyed on the counter name. Unknown names -// fail the test. -func newCounterProviderMock(t *testing.T, results map[string]counterResult) *mock.MetricsProviderMock { +// newCounterProviderMock returns a ProviderMock whose NewInt64Counter implementation +// looks up the result keyed on the counter name. Unknown names fail the test. +func newCounterProviderMock(t *testing.T, results map[string]counterResult) *ProviderMock { t.Helper() - return &mock.MetricsProviderMock{ + return &ProviderMock{ NewInt64CounterFunc: func(metricName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { res, ok := results[metricName] if !ok { @@ -249,7 +247,7 @@ func TestNewRedisCache(T *testing.T) { h, histErr := noopMP.NewFloat64Histogram("test") must.NoError(t, histErr) - mp := &mock.MetricsProviderMock{ + mp := &ProviderMock{ NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { return metrics.Int64CounterForTest(t, "x"), nil }, From 60d2d392dcf0cc3e47f134b8a2580c9d68a71d88 Mon Sep 17 00:00:00 2001 From: Jeffrey D <1289344+verygoodsoftwarenotvirus@users.noreply.github.com> Date: Fri, 10 Apr 2026 18:10:28 -0500 Subject: [PATCH 04/12] test: another goaround --- cache/config/config_test.go | 3 +- cache/config/metricsprovider_mock_test.go | 515 ------------------ cache/config/mocks_gen_test.go | 5 - cache/redis/mocks_gen_test.go | 9 +- cache/redis/unit_test.go | 17 +- .../mock2/circuitbreaker_mock.go | 2 +- circuitbreaking/mock2/doc.go | 10 + observability/metrics/mock2/doc.go | 10 + .../metrics/mock2/provider_mock.go | 2 +- 9 files changed, 40 insertions(+), 533 deletions(-) delete mode 100644 cache/config/metricsprovider_mock_test.go delete mode 100644 cache/config/mocks_gen_test.go rename cache/redis/circuitbreaker_mock_test.go => circuitbreaking/mock2/circuitbreaker_mock.go (99%) create mode 100644 circuitbreaking/mock2/doc.go create mode 100644 observability/metrics/mock2/doc.go rename cache/redis/metricsprovider_mock_test.go => observability/metrics/mock2/provider_mock.go (99%) diff --git a/cache/config/config_test.go b/cache/config/config_test.go index 0c7668a..4ef6685 100644 --- a/cache/config/config_test.go +++ b/cache/config/config_test.go @@ -8,6 +8,7 @@ import ( circuitbreakingcfg "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking/config" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" + metricsmock2 "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock2" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/shoenig/test" @@ -128,7 +129,7 @@ func TestProvideCache(T *testing.T) { }, } - mp := &ProviderMock{ + mp := &metricsmock2.ProviderMock{ NewInt64CounterFunc: func(name string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { test.EqOp(t, "redis-cache-breaker_circuit_breaker_tripped", name) return nil, errors.New("counter init failure") diff --git a/cache/config/metricsprovider_mock_test.go b/cache/config/metricsprovider_mock_test.go deleted file mode 100644 index c225860..0000000 --- a/cache/config/metricsprovider_mock_test.go +++ /dev/null @@ -1,515 +0,0 @@ -// Code generated by moq; DO NOT EDIT. -// github.com/matryer/moq - -package config - -import ( - "context" - "sync" - - "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" - "go.opentelemetry.io/otel/metric" -) - -// Ensure, that ProviderMock does implement metrics.Provider. -// If this is not the case, regenerate this file with moq. -var _ metrics.Provider = &ProviderMock{} - -// ProviderMock is a mock implementation of metrics.Provider. -// -// func TestSomethingThatUsesProvider(t *testing.T) { -// -// // make and configure a mocked metrics.Provider -// mockedProvider := &ProviderMock{ -// MeterProviderFunc: func() metric.MeterProvider { -// panic("mock out the MeterProvider method") -// }, -// NewFloat64CounterFunc: func(name string, options ...metric.Float64CounterOption) (metrics.Float64Counter, error) { -// panic("mock out the NewFloat64Counter method") -// }, -// NewFloat64GaugeFunc: func(name string, options ...metric.Float64GaugeOption) (metrics.Float64Gauge, error) { -// panic("mock out the NewFloat64Gauge method") -// }, -// NewFloat64HistogramFunc: func(name string, options ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { -// panic("mock out the NewFloat64Histogram method") -// }, -// NewFloat64UpDownCounterFunc: func(name string, options ...metric.Float64UpDownCounterOption) (metrics.Float64UpDownCounter, error) { -// panic("mock out the NewFloat64UpDownCounter method") -// }, -// NewInt64CounterFunc: func(name string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) { -// panic("mock out the NewInt64Counter method") -// }, -// NewInt64GaugeFunc: func(name string, options ...metric.Int64GaugeOption) (metrics.Int64Gauge, error) { -// panic("mock out the NewInt64Gauge method") -// }, -// NewInt64HistogramFunc: func(name string, options ...metric.Int64HistogramOption) (metrics.Int64Histogram, error) { -// panic("mock out the NewInt64Histogram method") -// }, -// NewInt64UpDownCounterFunc: func(name string, options ...metric.Int64UpDownCounterOption) (metrics.Int64UpDownCounter, error) { -// panic("mock out the NewInt64UpDownCounter method") -// }, -// ShutdownFunc: func(ctx context.Context) error { -// panic("mock out the Shutdown method") -// }, -// } -// -// // use mockedProvider in code that requires metrics.Provider -// // and then make assertions. -// -// } -type ProviderMock struct { - // MeterProviderFunc mocks the MeterProvider method. - MeterProviderFunc func() metric.MeterProvider - - // NewFloat64CounterFunc mocks the NewFloat64Counter method. - NewFloat64CounterFunc func(name string, options ...metric.Float64CounterOption) (metrics.Float64Counter, error) - - // NewFloat64GaugeFunc mocks the NewFloat64Gauge method. - NewFloat64GaugeFunc func(name string, options ...metric.Float64GaugeOption) (metrics.Float64Gauge, error) - - // NewFloat64HistogramFunc mocks the NewFloat64Histogram method. - NewFloat64HistogramFunc func(name string, options ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) - - // NewFloat64UpDownCounterFunc mocks the NewFloat64UpDownCounter method. - NewFloat64UpDownCounterFunc func(name string, options ...metric.Float64UpDownCounterOption) (metrics.Float64UpDownCounter, error) - - // NewInt64CounterFunc mocks the NewInt64Counter method. - NewInt64CounterFunc func(name string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) - - // NewInt64GaugeFunc mocks the NewInt64Gauge method. - NewInt64GaugeFunc func(name string, options ...metric.Int64GaugeOption) (metrics.Int64Gauge, error) - - // NewInt64HistogramFunc mocks the NewInt64Histogram method. - NewInt64HistogramFunc func(name string, options ...metric.Int64HistogramOption) (metrics.Int64Histogram, error) - - // NewInt64UpDownCounterFunc mocks the NewInt64UpDownCounter method. - NewInt64UpDownCounterFunc func(name string, options ...metric.Int64UpDownCounterOption) (metrics.Int64UpDownCounter, error) - - // ShutdownFunc mocks the Shutdown method. - ShutdownFunc func(ctx context.Context) error - - // calls tracks calls to the methods. - calls struct { - // MeterProvider holds details about calls to the MeterProvider method. - MeterProvider []struct { - } - // NewFloat64Counter holds details about calls to the NewFloat64Counter method. - NewFloat64Counter []struct { - // Name is the name argument value. - Name string - // Options is the options argument value. - Options []metric.Float64CounterOption - } - // NewFloat64Gauge holds details about calls to the NewFloat64Gauge method. - NewFloat64Gauge []struct { - // Name is the name argument value. - Name string - // Options is the options argument value. - Options []metric.Float64GaugeOption - } - // NewFloat64Histogram holds details about calls to the NewFloat64Histogram method. - NewFloat64Histogram []struct { - // Name is the name argument value. - Name string - // Options is the options argument value. - Options []metric.Float64HistogramOption - } - // NewFloat64UpDownCounter holds details about calls to the NewFloat64UpDownCounter method. - NewFloat64UpDownCounter []struct { - // Name is the name argument value. - Name string - // Options is the options argument value. - Options []metric.Float64UpDownCounterOption - } - // NewInt64Counter holds details about calls to the NewInt64Counter method. - NewInt64Counter []struct { - // Name is the name argument value. - Name string - // Options is the options argument value. - Options []metric.Int64CounterOption - } - // NewInt64Gauge holds details about calls to the NewInt64Gauge method. - NewInt64Gauge []struct { - // Name is the name argument value. - Name string - // Options is the options argument value. - Options []metric.Int64GaugeOption - } - // NewInt64Histogram holds details about calls to the NewInt64Histogram method. - NewInt64Histogram []struct { - // Name is the name argument value. - Name string - // Options is the options argument value. - Options []metric.Int64HistogramOption - } - // NewInt64UpDownCounter holds details about calls to the NewInt64UpDownCounter method. - NewInt64UpDownCounter []struct { - // Name is the name argument value. - Name string - // Options is the options argument value. - Options []metric.Int64UpDownCounterOption - } - // Shutdown holds details about calls to the Shutdown method. - Shutdown []struct { - // Ctx is the ctx argument value. - Ctx context.Context - } - } - lockMeterProvider sync.RWMutex - lockNewFloat64Counter sync.RWMutex - lockNewFloat64Gauge sync.RWMutex - lockNewFloat64Histogram sync.RWMutex - lockNewFloat64UpDownCounter sync.RWMutex - lockNewInt64Counter sync.RWMutex - lockNewInt64Gauge sync.RWMutex - lockNewInt64Histogram sync.RWMutex - lockNewInt64UpDownCounter sync.RWMutex - lockShutdown sync.RWMutex -} - -// MeterProvider calls MeterProviderFunc. -func (mock *ProviderMock) MeterProvider() metric.MeterProvider { - if mock.MeterProviderFunc == nil { - panic("ProviderMock.MeterProviderFunc: method is nil but Provider.MeterProvider was just called") - } - callInfo := struct { - }{} - mock.lockMeterProvider.Lock() - mock.calls.MeterProvider = append(mock.calls.MeterProvider, callInfo) - mock.lockMeterProvider.Unlock() - return mock.MeterProviderFunc() -} - -// MeterProviderCalls gets all the calls that were made to MeterProvider. -// Check the length with: -// -// len(mockedProvider.MeterProviderCalls()) -func (mock *ProviderMock) MeterProviderCalls() []struct { -} { - var calls []struct { - } - mock.lockMeterProvider.RLock() - calls = mock.calls.MeterProvider - mock.lockMeterProvider.RUnlock() - return calls -} - -// NewFloat64Counter calls NewFloat64CounterFunc. -func (mock *ProviderMock) NewFloat64Counter(name string, options ...metric.Float64CounterOption) (metrics.Float64Counter, error) { - if mock.NewFloat64CounterFunc == nil { - panic("ProviderMock.NewFloat64CounterFunc: method is nil but Provider.NewFloat64Counter was just called") - } - callInfo := struct { - Name string - Options []metric.Float64CounterOption - }{ - Name: name, - Options: options, - } - mock.lockNewFloat64Counter.Lock() - mock.calls.NewFloat64Counter = append(mock.calls.NewFloat64Counter, callInfo) - mock.lockNewFloat64Counter.Unlock() - return mock.NewFloat64CounterFunc(name, options...) -} - -// NewFloat64CounterCalls gets all the calls that were made to NewFloat64Counter. -// Check the length with: -// -// len(mockedProvider.NewFloat64CounterCalls()) -func (mock *ProviderMock) NewFloat64CounterCalls() []struct { - Name string - Options []metric.Float64CounterOption -} { - var calls []struct { - Name string - Options []metric.Float64CounterOption - } - mock.lockNewFloat64Counter.RLock() - calls = mock.calls.NewFloat64Counter - mock.lockNewFloat64Counter.RUnlock() - return calls -} - -// NewFloat64Gauge calls NewFloat64GaugeFunc. -func (mock *ProviderMock) NewFloat64Gauge(name string, options ...metric.Float64GaugeOption) (metrics.Float64Gauge, error) { - if mock.NewFloat64GaugeFunc == nil { - panic("ProviderMock.NewFloat64GaugeFunc: method is nil but Provider.NewFloat64Gauge was just called") - } - callInfo := struct { - Name string - Options []metric.Float64GaugeOption - }{ - Name: name, - Options: options, - } - mock.lockNewFloat64Gauge.Lock() - mock.calls.NewFloat64Gauge = append(mock.calls.NewFloat64Gauge, callInfo) - mock.lockNewFloat64Gauge.Unlock() - return mock.NewFloat64GaugeFunc(name, options...) -} - -// NewFloat64GaugeCalls gets all the calls that were made to NewFloat64Gauge. -// Check the length with: -// -// len(mockedProvider.NewFloat64GaugeCalls()) -func (mock *ProviderMock) NewFloat64GaugeCalls() []struct { - Name string - Options []metric.Float64GaugeOption -} { - var calls []struct { - Name string - Options []metric.Float64GaugeOption - } - mock.lockNewFloat64Gauge.RLock() - calls = mock.calls.NewFloat64Gauge - mock.lockNewFloat64Gauge.RUnlock() - return calls -} - -// NewFloat64Histogram calls NewFloat64HistogramFunc. -func (mock *ProviderMock) NewFloat64Histogram(name string, options ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { - if mock.NewFloat64HistogramFunc == nil { - panic("ProviderMock.NewFloat64HistogramFunc: method is nil but Provider.NewFloat64Histogram was just called") - } - callInfo := struct { - Name string - Options []metric.Float64HistogramOption - }{ - Name: name, - Options: options, - } - mock.lockNewFloat64Histogram.Lock() - mock.calls.NewFloat64Histogram = append(mock.calls.NewFloat64Histogram, callInfo) - mock.lockNewFloat64Histogram.Unlock() - return mock.NewFloat64HistogramFunc(name, options...) -} - -// NewFloat64HistogramCalls gets all the calls that were made to NewFloat64Histogram. -// Check the length with: -// -// len(mockedProvider.NewFloat64HistogramCalls()) -func (mock *ProviderMock) NewFloat64HistogramCalls() []struct { - Name string - Options []metric.Float64HistogramOption -} { - var calls []struct { - Name string - Options []metric.Float64HistogramOption - } - mock.lockNewFloat64Histogram.RLock() - calls = mock.calls.NewFloat64Histogram - mock.lockNewFloat64Histogram.RUnlock() - return calls -} - -// NewFloat64UpDownCounter calls NewFloat64UpDownCounterFunc. -func (mock *ProviderMock) NewFloat64UpDownCounter(name string, options ...metric.Float64UpDownCounterOption) (metrics.Float64UpDownCounter, error) { - if mock.NewFloat64UpDownCounterFunc == nil { - panic("ProviderMock.NewFloat64UpDownCounterFunc: method is nil but Provider.NewFloat64UpDownCounter was just called") - } - callInfo := struct { - Name string - Options []metric.Float64UpDownCounterOption - }{ - Name: name, - Options: options, - } - mock.lockNewFloat64UpDownCounter.Lock() - mock.calls.NewFloat64UpDownCounter = append(mock.calls.NewFloat64UpDownCounter, callInfo) - mock.lockNewFloat64UpDownCounter.Unlock() - return mock.NewFloat64UpDownCounterFunc(name, options...) -} - -// NewFloat64UpDownCounterCalls gets all the calls that were made to NewFloat64UpDownCounter. -// Check the length with: -// -// len(mockedProvider.NewFloat64UpDownCounterCalls()) -func (mock *ProviderMock) NewFloat64UpDownCounterCalls() []struct { - Name string - Options []metric.Float64UpDownCounterOption -} { - var calls []struct { - Name string - Options []metric.Float64UpDownCounterOption - } - mock.lockNewFloat64UpDownCounter.RLock() - calls = mock.calls.NewFloat64UpDownCounter - mock.lockNewFloat64UpDownCounter.RUnlock() - return calls -} - -// NewInt64Counter calls NewInt64CounterFunc. -func (mock *ProviderMock) NewInt64Counter(name string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) { - if mock.NewInt64CounterFunc == nil { - panic("ProviderMock.NewInt64CounterFunc: method is nil but Provider.NewInt64Counter was just called") - } - callInfo := struct { - Name string - Options []metric.Int64CounterOption - }{ - Name: name, - Options: options, - } - mock.lockNewInt64Counter.Lock() - mock.calls.NewInt64Counter = append(mock.calls.NewInt64Counter, callInfo) - mock.lockNewInt64Counter.Unlock() - return mock.NewInt64CounterFunc(name, options...) -} - -// NewInt64CounterCalls gets all the calls that were made to NewInt64Counter. -// Check the length with: -// -// len(mockedProvider.NewInt64CounterCalls()) -func (mock *ProviderMock) NewInt64CounterCalls() []struct { - Name string - Options []metric.Int64CounterOption -} { - var calls []struct { - Name string - Options []metric.Int64CounterOption - } - mock.lockNewInt64Counter.RLock() - calls = mock.calls.NewInt64Counter - mock.lockNewInt64Counter.RUnlock() - return calls -} - -// NewInt64Gauge calls NewInt64GaugeFunc. -func (mock *ProviderMock) NewInt64Gauge(name string, options ...metric.Int64GaugeOption) (metrics.Int64Gauge, error) { - if mock.NewInt64GaugeFunc == nil { - panic("ProviderMock.NewInt64GaugeFunc: method is nil but Provider.NewInt64Gauge was just called") - } - callInfo := struct { - Name string - Options []metric.Int64GaugeOption - }{ - Name: name, - Options: options, - } - mock.lockNewInt64Gauge.Lock() - mock.calls.NewInt64Gauge = append(mock.calls.NewInt64Gauge, callInfo) - mock.lockNewInt64Gauge.Unlock() - return mock.NewInt64GaugeFunc(name, options...) -} - -// NewInt64GaugeCalls gets all the calls that were made to NewInt64Gauge. -// Check the length with: -// -// len(mockedProvider.NewInt64GaugeCalls()) -func (mock *ProviderMock) NewInt64GaugeCalls() []struct { - Name string - Options []metric.Int64GaugeOption -} { - var calls []struct { - Name string - Options []metric.Int64GaugeOption - } - mock.lockNewInt64Gauge.RLock() - calls = mock.calls.NewInt64Gauge - mock.lockNewInt64Gauge.RUnlock() - return calls -} - -// NewInt64Histogram calls NewInt64HistogramFunc. -func (mock *ProviderMock) NewInt64Histogram(name string, options ...metric.Int64HistogramOption) (metrics.Int64Histogram, error) { - if mock.NewInt64HistogramFunc == nil { - panic("ProviderMock.NewInt64HistogramFunc: method is nil but Provider.NewInt64Histogram was just called") - } - callInfo := struct { - Name string - Options []metric.Int64HistogramOption - }{ - Name: name, - Options: options, - } - mock.lockNewInt64Histogram.Lock() - mock.calls.NewInt64Histogram = append(mock.calls.NewInt64Histogram, callInfo) - mock.lockNewInt64Histogram.Unlock() - return mock.NewInt64HistogramFunc(name, options...) -} - -// NewInt64HistogramCalls gets all the calls that were made to NewInt64Histogram. -// Check the length with: -// -// len(mockedProvider.NewInt64HistogramCalls()) -func (mock *ProviderMock) NewInt64HistogramCalls() []struct { - Name string - Options []metric.Int64HistogramOption -} { - var calls []struct { - Name string - Options []metric.Int64HistogramOption - } - mock.lockNewInt64Histogram.RLock() - calls = mock.calls.NewInt64Histogram - mock.lockNewInt64Histogram.RUnlock() - return calls -} - -// NewInt64UpDownCounter calls NewInt64UpDownCounterFunc. -func (mock *ProviderMock) NewInt64UpDownCounter(name string, options ...metric.Int64UpDownCounterOption) (metrics.Int64UpDownCounter, error) { - if mock.NewInt64UpDownCounterFunc == nil { - panic("ProviderMock.NewInt64UpDownCounterFunc: method is nil but Provider.NewInt64UpDownCounter was just called") - } - callInfo := struct { - Name string - Options []metric.Int64UpDownCounterOption - }{ - Name: name, - Options: options, - } - mock.lockNewInt64UpDownCounter.Lock() - mock.calls.NewInt64UpDownCounter = append(mock.calls.NewInt64UpDownCounter, callInfo) - mock.lockNewInt64UpDownCounter.Unlock() - return mock.NewInt64UpDownCounterFunc(name, options...) -} - -// NewInt64UpDownCounterCalls gets all the calls that were made to NewInt64UpDownCounter. -// Check the length with: -// -// len(mockedProvider.NewInt64UpDownCounterCalls()) -func (mock *ProviderMock) NewInt64UpDownCounterCalls() []struct { - Name string - Options []metric.Int64UpDownCounterOption -} { - var calls []struct { - Name string - Options []metric.Int64UpDownCounterOption - } - mock.lockNewInt64UpDownCounter.RLock() - calls = mock.calls.NewInt64UpDownCounter - mock.lockNewInt64UpDownCounter.RUnlock() - return calls -} - -// Shutdown calls ShutdownFunc. -func (mock *ProviderMock) Shutdown(ctx context.Context) error { - if mock.ShutdownFunc == nil { - panic("ProviderMock.ShutdownFunc: method is nil but Provider.Shutdown was just called") - } - callInfo := struct { - Ctx context.Context - }{ - Ctx: ctx, - } - mock.lockShutdown.Lock() - mock.calls.Shutdown = append(mock.calls.Shutdown, callInfo) - mock.lockShutdown.Unlock() - return mock.ShutdownFunc(ctx) -} - -// ShutdownCalls gets all the calls that were made to Shutdown. -// Check the length with: -// -// len(mockedProvider.ShutdownCalls()) -func (mock *ProviderMock) ShutdownCalls() []struct { - Ctx context.Context -} { - var calls []struct { - Ctx context.Context - } - mock.lockShutdown.RLock() - calls = mock.calls.Shutdown - mock.lockShutdown.RUnlock() - return calls -} diff --git a/cache/config/mocks_gen_test.go b/cache/config/mocks_gen_test.go deleted file mode 100644 index 7ad25d3..0000000 --- a/cache/config/mocks_gen_test.go +++ /dev/null @@ -1,5 +0,0 @@ -package config - -// Regenerate mocks via `go generate ./cache/config/...`. - -//go:generate go tool github.com/matryer/moq -out metricsprovider_mock_test.go -pkg config -rm -fmt goimports ../../observability/metrics Provider diff --git a/cache/redis/mocks_gen_test.go b/cache/redis/mocks_gen_test.go index b0a8e9d..8748962 100644 --- a/cache/redis/mocks_gen_test.go +++ b/cache/redis/mocks_gen_test.go @@ -1,7 +1,10 @@ package redis -// Regenerate mocks via `go generate ./cache/redis/...`. +// Regenerate the redisClient mock via `go generate ./cache/redis/...`. The +// redisClient interface is unexported (it's a test seam), so its mock lives +// in-package as a *_test.go file rather than under a sibling mock package. +// Mocks for the external interfaces (metrics.Provider, circuitbreaking.CircuitBreaker) +// live alongside those interfaces in observability/metrics/mock2 and +// circuitbreaking/mock2. //go:generate go tool github.com/matryer/moq -out redisclient_mock_test.go -pkg redis -rm -fmt goimports . redisClient -//go:generate go tool github.com/matryer/moq -out metricsprovider_mock_test.go -pkg redis -rm -fmt goimports ../../observability/metrics Provider -//go:generate go tool github.com/matryer/moq -out circuitbreaker_mock_test.go -pkg redis -rm -fmt goimports ../../circuitbreaking CircuitBreaker diff --git a/cache/redis/unit_test.go b/cache/redis/unit_test.go index e954292..7ce1cc4 100644 --- a/cache/redis/unit_test.go +++ b/cache/redis/unit_test.go @@ -9,8 +9,10 @@ import ( "time" "github.com/verygoodsoftwarenotvirus/platform/v5/cache" + circuitbreakingmock2 "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking/mock2" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" + metricsmock2 "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock2" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/go-redis/redis/v8" @@ -28,7 +30,7 @@ func gobEncodeExample(t *testing.T, e *example) string { return buf.String() } -func buildTestImpl(t *testing.T) (*redisCacheImpl[example], *redisClientMock, *CircuitBreakerMock) { +func buildTestImpl(t *testing.T) (*redisCacheImpl[example], *redisClientMock, *circuitbreakingmock2.CircuitBreakerMock) { t.Helper() mp := metrics.NewNoopMetricsProvider() @@ -52,7 +54,7 @@ func buildTestImpl(t *testing.T) (*redisCacheImpl[example], *redisClientMock, *C must.NoError(t, err) client := &redisClientMock{} - cb := &CircuitBreakerMock{} + cb := &circuitbreakingmock2.CircuitBreakerMock{} return &redisCacheImpl[example]{ logger: logging.NewNoopLogger(), @@ -75,11 +77,12 @@ type counterResult struct { err error } -// newCounterProviderMock returns a ProviderMock whose NewInt64Counter implementation -// looks up the result keyed on the counter name. Unknown names fail the test. -func newCounterProviderMock(t *testing.T, results map[string]counterResult) *ProviderMock { +// newCounterProviderMock returns a metrics.Provider mock whose NewInt64Counter +// implementation looks up the result keyed on the counter name. Unknown names +// fail the test. +func newCounterProviderMock(t *testing.T, results map[string]counterResult) *metricsmock2.ProviderMock { t.Helper() - return &ProviderMock{ + return &metricsmock2.ProviderMock{ NewInt64CounterFunc: func(metricName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { res, ok := results[metricName] if !ok { @@ -247,7 +250,7 @@ func TestNewRedisCache(T *testing.T) { h, histErr := noopMP.NewFloat64Histogram("test") must.NoError(t, histErr) - mp := &ProviderMock{ + mp := &metricsmock2.ProviderMock{ NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { return metrics.Int64CounterForTest(t, "x"), nil }, diff --git a/cache/redis/circuitbreaker_mock_test.go b/circuitbreaking/mock2/circuitbreaker_mock.go similarity index 99% rename from cache/redis/circuitbreaker_mock_test.go rename to circuitbreaking/mock2/circuitbreaker_mock.go index 8b474ee..23de9f7 100644 --- a/cache/redis/circuitbreaker_mock_test.go +++ b/circuitbreaking/mock2/circuitbreaker_mock.go @@ -1,7 +1,7 @@ // Code generated by moq; DO NOT EDIT. // github.com/matryer/moq -package redis +package mock2 import ( "sync" diff --git a/circuitbreaking/mock2/doc.go b/circuitbreaking/mock2/doc.go new file mode 100644 index 0000000..3912fcf --- /dev/null +++ b/circuitbreaking/mock2/doc.go @@ -0,0 +1,10 @@ +// Package mock2 provides moq-generated mock implementations of interfaces in +// the circuitbreaking package. It exists alongside the hand-written +// testify-based package circuitbreaking/mock and is a pilot of the +// matryer/moq workflow; consumers that want the moq style should import this +// package instead of circuitbreaking/mock. +package mock2 + +// Regenerate via `go generate ./circuitbreaking/mock2/`. + +//go:generate go tool github.com/matryer/moq -out circuitbreaker_mock.go -pkg mock2 -rm -fmt goimports .. CircuitBreaker:CircuitBreakerMock diff --git a/observability/metrics/mock2/doc.go b/observability/metrics/mock2/doc.go new file mode 100644 index 0000000..f3d3315 --- /dev/null +++ b/observability/metrics/mock2/doc.go @@ -0,0 +1,10 @@ +// Package mock2 provides moq-generated mock implementations of interfaces in +// the observability/metrics package. It exists alongside the hand-written +// testify-based package observability/metrics/mock and is a pilot of the +// matryer/moq workflow; consumers that want the moq style should import this +// package instead of observability/metrics/mock. +package mock2 + +// Regenerate via `go generate ./observability/metrics/mock2/`. + +//go:generate go tool github.com/matryer/moq -out provider_mock.go -pkg mock2 -rm -fmt goimports .. Provider:ProviderMock diff --git a/cache/redis/metricsprovider_mock_test.go b/observability/metrics/mock2/provider_mock.go similarity index 99% rename from cache/redis/metricsprovider_mock_test.go rename to observability/metrics/mock2/provider_mock.go index 247d6e8..2272295 100644 --- a/cache/redis/metricsprovider_mock_test.go +++ b/observability/metrics/mock2/provider_mock.go @@ -1,7 +1,7 @@ // Code generated by moq; DO NOT EDIT. // github.com/matryer/moq -package redis +package mock2 import ( "context" From 23d8ffd9c3b8a368bdd8e9306bfdafdb970c1959 Mon Sep 17 00:00:00 2001 From: Jeffrey D <1289344+verygoodsoftwarenotvirus@users.noreply.github.com> Date: Fri, 10 Apr 2026 18:49:28 -0500 Subject: [PATCH 05/12] test: more rejiggering --- .github/workflows/generated_files.yaml | 43 ++ CLAUDE.md | 1 + Makefile | 6 + cache/mock/cache_mock.go | 233 ++++++ cache/mock/doc.go | 8 + cache/redis/config_test.go | 45 ++ cache/redis/redis_test.go | 521 +++++++++++++ cache/redis/unit_test.go | 569 -------------- .../text/elasticsearch/elasticsearch_test.go | 300 ++++++++ search/text/elasticsearch/index_test.go | 425 +++++++++++ search/text/elasticsearch/unit_test.go | 720 ------------------ 11 files changed, 1582 insertions(+), 1289 deletions(-) create mode 100644 .github/workflows/generated_files.yaml create mode 100644 cache/mock/cache_mock.go create mode 100644 cache/mock/doc.go create mode 100644 cache/redis/config_test.go delete mode 100644 cache/redis/unit_test.go delete mode 100644 search/text/elasticsearch/unit_test.go diff --git a/.github/workflows/generated_files.yaml b/.github/workflows/generated_files.yaml new file mode 100644 index 0000000..e8e532b --- /dev/null +++ b/.github/workflows/generated_files.yaml @@ -0,0 +1,43 @@ +--- +on: + workflow_dispatch: + pull_request: + paths: + - '**/*.go' + - go.mod + - go.sum + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +name: generated files +jobs: + golang: + timeout-minutes: 10 + runs-on: ubuntu-latest + strategy: + matrix: + go-version: ['1.26.x'] + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Install Go + uses: actions/setup-go@v6 + with: + go-version: ${{ matrix.go-version }} + cache: true + cache-dependency-path: go.sum + + - name: Regenerate + run: make generate + + - name: Verify no drift + run: | + if [ -n "$(git status --porcelain)" ]; then + echo "go generate produced a diff. Run 'make generate' locally and commit the result." + git status --short + git diff + exit 1 + fi diff --git a/CLAUDE.md b/CLAUDE.md index 4f0081c..3ef2ce6 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -14,6 +14,7 @@ make lint # Run golangci-lint (Docker) + shellcheck make format lint # Typical workflow: format then lint make test # Run tests (race detector, shuffle, failfast) make build # Build all packages +make generate # Regenerate moq mocks after changing any mocked interface make setup # Install dev tools + vendor deps make revendor # Clean and re-vendor dependencies ``` diff --git a/Makefile b/Makefile index e236d4e..caf7e05 100644 --- a/Makefile +++ b/Makefile @@ -89,6 +89,12 @@ shellcheck: .PHONY: lint lint: golang_lint shellcheck +## GENERATION + +.PHONY: generate +generate: + go generate ./... + ## EXECUTION .PHONY: build diff --git a/cache/mock/cache_mock.go b/cache/mock/cache_mock.go new file mode 100644 index 0000000..ae29efd --- /dev/null +++ b/cache/mock/cache_mock.go @@ -0,0 +1,233 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package mock + +import ( + "context" + "sync" + + "github.com/verygoodsoftwarenotvirus/platform/v5/cache" +) + +// Ensure, that CacheMock does implement cache.Cache. +// If this is not the case, regenerate this file with moq. +var _ cache.Cache[any] = &CacheMock[any]{} + +// CacheMock is a mock implementation of cache.Cache. +// +// func TestSomethingThatUsesCache(t *testing.T) { +// +// // make and configure a mocked cache.Cache +// mockedCache := &CacheMock{ +// DeleteFunc: func(ctx context.Context, key string) error { +// panic("mock out the Delete method") +// }, +// GetFunc: func(ctx context.Context, key string) (*T, error) { +// panic("mock out the Get method") +// }, +// PingFunc: func(ctx context.Context) error { +// panic("mock out the Ping method") +// }, +// SetFunc: func(ctx context.Context, key string, value *T) error { +// panic("mock out the Set method") +// }, +// } +// +// // use mockedCache in code that requires cache.Cache +// // and then make assertions. +// +// } +type CacheMock[T any] struct { + // DeleteFunc mocks the Delete method. + DeleteFunc func(ctx context.Context, key string) error + + // GetFunc mocks the Get method. + GetFunc func(ctx context.Context, key string) (*T, error) + + // PingFunc mocks the Ping method. + PingFunc func(ctx context.Context) error + + // SetFunc mocks the Set method. + SetFunc func(ctx context.Context, key string, value *T) error + + // calls tracks calls to the methods. + calls struct { + // Delete holds details about calls to the Delete method. + Delete []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Key is the key argument value. + Key string + } + // Get holds details about calls to the Get method. + Get []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Key is the key argument value. + Key string + } + // Ping holds details about calls to the Ping method. + Ping []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // Set holds details about calls to the Set method. + Set []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Key is the key argument value. + Key string + // Value is the value argument value. + Value *T + } + } + lockDelete sync.RWMutex + lockGet sync.RWMutex + lockPing sync.RWMutex + lockSet sync.RWMutex +} + +// Delete calls DeleteFunc. +func (mock *CacheMock[T]) Delete(ctx context.Context, key string) error { + if mock.DeleteFunc == nil { + panic("CacheMock.DeleteFunc: method is nil but Cache.Delete was just called") + } + callInfo := struct { + Ctx context.Context + Key string + }{ + Ctx: ctx, + Key: key, + } + mock.lockDelete.Lock() + mock.calls.Delete = append(mock.calls.Delete, callInfo) + mock.lockDelete.Unlock() + return mock.DeleteFunc(ctx, key) +} + +// DeleteCalls gets all the calls that were made to Delete. +// Check the length with: +// +// len(mockedCache.DeleteCalls()) +func (mock *CacheMock[T]) DeleteCalls() []struct { + Ctx context.Context + Key string +} { + var calls []struct { + Ctx context.Context + Key string + } + mock.lockDelete.RLock() + calls = mock.calls.Delete + mock.lockDelete.RUnlock() + return calls +} + +// Get calls GetFunc. +func (mock *CacheMock[T]) Get(ctx context.Context, key string) (*T, error) { + if mock.GetFunc == nil { + panic("CacheMock.GetFunc: method is nil but Cache.Get was just called") + } + callInfo := struct { + Ctx context.Context + Key string + }{ + Ctx: ctx, + Key: key, + } + mock.lockGet.Lock() + mock.calls.Get = append(mock.calls.Get, callInfo) + mock.lockGet.Unlock() + return mock.GetFunc(ctx, key) +} + +// GetCalls gets all the calls that were made to Get. +// Check the length with: +// +// len(mockedCache.GetCalls()) +func (mock *CacheMock[T]) GetCalls() []struct { + Ctx context.Context + Key string +} { + var calls []struct { + Ctx context.Context + Key string + } + mock.lockGet.RLock() + calls = mock.calls.Get + mock.lockGet.RUnlock() + return calls +} + +// Ping calls PingFunc. +func (mock *CacheMock[T]) Ping(ctx context.Context) error { + if mock.PingFunc == nil { + panic("CacheMock.PingFunc: method is nil but Cache.Ping was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockPing.Lock() + mock.calls.Ping = append(mock.calls.Ping, callInfo) + mock.lockPing.Unlock() + return mock.PingFunc(ctx) +} + +// PingCalls gets all the calls that were made to Ping. +// Check the length with: +// +// len(mockedCache.PingCalls()) +func (mock *CacheMock[T]) PingCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockPing.RLock() + calls = mock.calls.Ping + mock.lockPing.RUnlock() + return calls +} + +// Set calls SetFunc. +func (mock *CacheMock[T]) Set(ctx context.Context, key string, value *T) error { + if mock.SetFunc == nil { + panic("CacheMock.SetFunc: method is nil but Cache.Set was just called") + } + callInfo := struct { + Ctx context.Context + Key string + Value *T + }{ + Ctx: ctx, + Key: key, + Value: value, + } + mock.lockSet.Lock() + mock.calls.Set = append(mock.calls.Set, callInfo) + mock.lockSet.Unlock() + return mock.SetFunc(ctx, key, value) +} + +// SetCalls gets all the calls that were made to Set. +// Check the length with: +// +// len(mockedCache.SetCalls()) +func (mock *CacheMock[T]) SetCalls() []struct { + Ctx context.Context + Key string + Value *T +} { + var calls []struct { + Ctx context.Context + Key string + Value *T + } + mock.lockSet.RLock() + calls = mock.calls.Set + mock.lockSet.RUnlock() + return calls +} diff --git a/cache/mock/doc.go b/cache/mock/doc.go new file mode 100644 index 0000000..c87ddd0 --- /dev/null +++ b/cache/mock/doc.go @@ -0,0 +1,8 @@ +// Package mock provides moq-generated mock implementations of interfaces in +// the cache package. The primary consumer is external tests that need to mock +// cache.Cache[T] — cache's own tests do not depend on this package. +package mock + +// Regenerate via `go generate ./cache/mock/`. + +//go:generate go tool github.com/matryer/moq -out cache_mock.go -pkg mock -rm -fmt goimports .. Cache:CacheMock diff --git a/cache/redis/config_test.go b/cache/redis/config_test.go new file mode 100644 index 0000000..5b853d6 --- /dev/null +++ b/cache/redis/config_test.go @@ -0,0 +1,45 @@ +package redis + +import ( + "testing" + + "github.com/shoenig/test" +) + +func TestConfig_ValidateWithContext(T *testing.T) { + T.Parallel() + + T.Run("standard", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + cfg := &Config{ + QueueAddresses: []string{"localhost:6379"}, + } + + test.NoError(t, cfg.ValidateWithContext(ctx)) + }) + + T.Run("with empty addresses", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + cfg := &Config{ + QueueAddresses: []string{}, + } + + test.Error(t, cfg.ValidateWithContext(ctx)) + }) + + T.Run("with nil addresses", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + + cfg := &Config{} + + test.Error(t, cfg.ValidateWithContext(ctx)) + }) +} diff --git a/cache/redis/redis_test.go b/cache/redis/redis_test.go index 53299d6..c2eda40 100644 --- a/cache/redis/redis_test.go +++ b/cache/redis/redis_test.go @@ -1,14 +1,26 @@ package redis import ( + "bytes" "context" + "encoding/gob" + "errors" "strings" "testing" "time" + "github.com/verygoodsoftwarenotvirus/platform/v5/cache" + circuitbreakingmock2 "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking/mock2" + "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" + "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" + metricsmock2 "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock2" + "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" + + "github.com/go-redis/redis/v8" "github.com/shoenig/test" "github.com/shoenig/test/must" rediscontainers "github.com/testcontainers/testcontainers-go/modules/redis" + "go.opentelemetry.io/otel/metric" ) const ( @@ -20,6 +32,78 @@ type example struct { Name string `json:"name"` } +func gobEncodeExample(t *testing.T, e *example) string { + t.Helper() + + var buf bytes.Buffer + must.NoError(t, gob.NewEncoder(&buf).Encode(e)) + + return buf.String() +} + +func buildTestImpl(t *testing.T) (*redisCacheImpl[example], *redisClientMock, *circuitbreakingmock2.CircuitBreakerMock) { + t.Helper() + + mp := metrics.NewNoopMetricsProvider() + + hitCounter, err := mp.NewInt64Counter("test_hits") + must.NoError(t, err) + + missCounter, err := mp.NewInt64Counter("test_misses") + must.NoError(t, err) + + setCounter, err := mp.NewInt64Counter("test_sets") + must.NoError(t, err) + + delCounter, err := mp.NewInt64Counter("test_deletes") + must.NoError(t, err) + + errCounter, err := mp.NewInt64Counter("test_errors") + must.NoError(t, err) + + latencyHist, err := mp.NewFloat64Histogram("test_latency") + must.NoError(t, err) + + client := &redisClientMock{} + cb := &circuitbreakingmock2.CircuitBreakerMock{} + + return &redisCacheImpl[example]{ + logger: logging.NewNoopLogger(), + tracer: tracing.NewNamedTracer(nil, "test"), + cacheHitCounter: hitCounter, + cacheMissCounter: missCounter, + cacheSetCounter: setCounter, + cacheDelCounter: delCounter, + cacheErrCounter: errCounter, + latencyHist: latencyHist, + client: client, + circuitBreaker: cb, + expiration: time.Minute, + }, client, cb +} + +// counterResult bundles the values a mocked NewInt64Counter call returns. +type counterResult struct { + counter metrics.Int64Counter + err error +} + +// newCounterProviderMock returns a metrics.Provider mock whose NewInt64Counter +// implementation looks up the result keyed on the counter name. Unknown names +// fail the test. +func newCounterProviderMock(t *testing.T, results map[string]counterResult) *metricsmock2.ProviderMock { + t.Helper() + return &metricsmock2.ProviderMock{ + NewInt64CounterFunc: func(metricName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + res, ok := results[metricName] + if !ok { + t.Fatalf("unexpected NewInt64Counter call: %q", metricName) + } + return res.counter, res.err + }, + } +} + func buildContainerBackedRedisConfig(t *testing.T) (config *Config, shutdownFunction func(context.Context) error) { t.Helper() @@ -53,6 +137,143 @@ func buildContainerBackedRedisConfig(t *testing.T) (config *Config, shutdownFunc return cfg, shutdownFunc } +func TestNewRedisCache(T *testing.T) { + T.Parallel() + + okCounter := func() metrics.Int64Counter { return metrics.Int64CounterForTest(T, "x") } + + T.Run("with single address", func(t *testing.T) { + t.Parallel() + + cfg := &Config{QueueAddresses: []string{"localhost:6379"}} + + c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, nil, nil) + must.NoError(t, err) + test.NotNil(t, c) + }) + + T.Run("with multiple addresses", func(t *testing.T) { + t.Parallel() + + cfg := &Config{QueueAddresses: []string{"localhost:6379", "localhost:6380"}} + + c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, nil, nil) + must.NoError(t, err) + test.NotNil(t, c) + }) + + T.Run("with error creating cache hit counter", func(t *testing.T) { + t.Parallel() + + cfg := &Config{QueueAddresses: []string{"localhost:6379"}} + + mp := newCounterProviderMock(t, map[string]counterResult{ + name + "_cache_hits": {counter: okCounter(), err: errors.New("counter error")}, + }) + + c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, mp, nil) + test.Error(t, err) + test.Nil(t, c) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) + }) + + T.Run("with error creating cache miss counter", func(t *testing.T) { + t.Parallel() + + cfg := &Config{QueueAddresses: []string{"localhost:6379"}} + + mp := newCounterProviderMock(t, map[string]counterResult{ + name + "_cache_hits": {counter: okCounter()}, + name + "_cache_misses": {counter: okCounter(), err: errors.New("counter error")}, + }) + + c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, mp, nil) + test.Error(t, err) + test.Nil(t, c) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) + }) + + T.Run("with error creating cache set counter", func(t *testing.T) { + t.Parallel() + + cfg := &Config{QueueAddresses: []string{"localhost:6379"}} + + mp := newCounterProviderMock(t, map[string]counterResult{ + name + "_cache_hits": {counter: okCounter()}, + name + "_cache_misses": {counter: okCounter()}, + name + "_cache_sets": {counter: okCounter(), err: errors.New("counter error")}, + }) + + c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, mp, nil) + test.Error(t, err) + test.Nil(t, c) + test.SliceLen(t, 3, mp.NewInt64CounterCalls()) + }) + + T.Run("with error creating cache delete counter", func(t *testing.T) { + t.Parallel() + + cfg := &Config{QueueAddresses: []string{"localhost:6379"}} + + mp := newCounterProviderMock(t, map[string]counterResult{ + name + "_cache_hits": {counter: okCounter()}, + name + "_cache_misses": {counter: okCounter()}, + name + "_cache_sets": {counter: okCounter()}, + name + "_cache_deletes": {counter: okCounter(), err: errors.New("counter error")}, + }) + + c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, mp, nil) + test.Error(t, err) + test.Nil(t, c) + test.SliceLen(t, 4, mp.NewInt64CounterCalls()) + }) + + T.Run("with error creating cache error counter", func(t *testing.T) { + t.Parallel() + + cfg := &Config{QueueAddresses: []string{"localhost:6379"}} + + mp := newCounterProviderMock(t, map[string]counterResult{ + name + "_cache_hits": {counter: okCounter()}, + name + "_cache_misses": {counter: okCounter()}, + name + "_cache_sets": {counter: okCounter()}, + name + "_cache_deletes": {counter: okCounter()}, + name + "_cache_errors": {counter: okCounter(), err: errors.New("counter error")}, + }) + + c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, mp, nil) + test.Error(t, err) + test.Nil(t, c) + test.SliceLen(t, 5, mp.NewInt64CounterCalls()) + }) + + T.Run("with error creating latency histogram", func(t *testing.T) { + t.Parallel() + + cfg := &Config{QueueAddresses: []string{"localhost:6379"}} + + noopMP := metrics.NewNoopMetricsProvider() + h, histErr := noopMP.NewFloat64Histogram("test") + must.NoError(t, histErr) + + mp := &metricsmock2.ProviderMock{ + NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metrics.Int64CounterForTest(t, "x"), nil + }, + NewFloat64HistogramFunc: func(metricName string, _ ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { + test.EqOp(t, name+"_cache_latency_ms", metricName) + return h, errors.New("histogram error") + }, + } + + c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, mp, nil) + test.Error(t, err) + test.Nil(t, c) + test.SliceLen(t, 5, mp.NewInt64CounterCalls()) + test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) + }) +} + func Test_redisCacheImpl_Get(T *testing.T) { T.Parallel() @@ -77,6 +298,99 @@ func Test_redisCacheImpl_Get(T *testing.T) { }) } +func Test_redisCacheImpl_Get_Unit(T *testing.T) { + T.Parallel() + + T.Run("standard", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + impl, client, cb := buildTestImpl(t) + + expected := &example{Name: t.Name()} + encoded := gobEncodeExample(t, expected) + + cb.CannotProceedFunc = func() bool { return false } + cb.SucceededFunc = func() {} + + client.GetFunc = func(_ context.Context, key string) *redis.StringCmd { + test.EqOp(t, exampleKey, key) + cmd := redis.NewStringCmd(ctx) + cmd.SetVal(encoded) + return cmd + } + + actual, err := impl.Get(ctx, exampleKey) + test.NoError(t, err) + test.Eq(t, expected, actual) + + test.SliceLen(t, 1, client.GetCalls()) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) + }) + + T.Run("when circuit breaker cannot proceed", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + impl, _, cb := buildTestImpl(t) + + cb.CannotProceedFunc = func() bool { return true } + + actual, err := impl.Get(ctx, exampleKey) + test.ErrorIs(t, err, cache.ErrNotFound) + test.Nil(t, actual) + + test.SliceLen(t, 1, cb.CannotProceedCalls()) + }) + + T.Run("with redis error", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + impl, client, cb := buildTestImpl(t) + + cb.CannotProceedFunc = func() bool { return false } + cb.FailedFunc = func() {} + + client.GetFunc = func(_ context.Context, key string) *redis.StringCmd { + test.EqOp(t, exampleKey, key) + cmd := redis.NewStringCmd(ctx) + cmd.SetErr(errors.New("redis error")) + return cmd + } + + actual, err := impl.Get(ctx, exampleKey) + test.Error(t, err) + test.Nil(t, actual) + + test.SliceLen(t, 1, client.GetCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) + }) + + T.Run("with decode error", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + impl, client, cb := buildTestImpl(t) + + cb.CannotProceedFunc = func() bool { return false } + + client.GetFunc = func(_ context.Context, key string) *redis.StringCmd { + test.EqOp(t, exampleKey, key) + cmd := redis.NewStringCmd(ctx) + cmd.SetVal("not valid gob data") + return cmd + } + + actual, err := impl.Get(ctx, exampleKey) + test.Error(t, err) + test.Nil(t, actual) + + test.SliceLen(t, 1, client.GetCalls()) + }) +} + func Test_redisCacheImpl_Set(T *testing.T) { T.Parallel() @@ -97,6 +411,73 @@ func Test_redisCacheImpl_Set(T *testing.T) { }) } +func Test_redisCacheImpl_Set_Unit(T *testing.T) { + T.Parallel() + + T.Run("standard", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + impl, client, cb := buildTestImpl(t) + + cb.CannotProceedFunc = func() bool { return false } + cb.SucceededFunc = func() {} + + client.SetFunc = func(_ context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd { + test.EqOp(t, exampleKey, key) + test.EqOp(t, time.Minute, expiration) + _, isString := value.(string) + test.True(t, isString) + cmd := redis.NewStatusCmd(ctx) + cmd.SetVal("OK") + return cmd + } + + err := impl.Set(ctx, exampleKey, &example{Name: t.Name()}) + test.NoError(t, err) + + test.SliceLen(t, 1, client.SetCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) + }) + + T.Run("when circuit breaker cannot proceed", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + impl, _, cb := buildTestImpl(t) + + cb.CannotProceedFunc = func() bool { return true } + + err := impl.Set(ctx, exampleKey, &example{Name: t.Name()}) + test.NoError(t, err) + + test.SliceLen(t, 1, cb.CannotProceedCalls()) + }) + + T.Run("with redis error", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + impl, client, cb := buildTestImpl(t) + + cb.CannotProceedFunc = func() bool { return false } + cb.FailedFunc = func() {} + + client.SetFunc = func(_ context.Context, key string, _ any, _ time.Duration) *redis.StatusCmd { + test.EqOp(t, exampleKey, key) + cmd := redis.NewStatusCmd(ctx) + cmd.SetErr(errors.New("redis error")) + return cmd + } + + err := impl.Set(ctx, exampleKey, &example{Name: t.Name()}) + test.Error(t, err) + + test.SliceLen(t, 1, client.SetCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) + }) +} + func Test_redisCacheImpl_Delete(T *testing.T) { T.Parallel() @@ -118,3 +499,143 @@ func Test_redisCacheImpl_Delete(T *testing.T) { test.NoError(t, c.Delete(ctx, exampleKey)) }) } + +func Test_redisCacheImpl_Delete_Unit(T *testing.T) { + T.Parallel() + + T.Run("standard", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + impl, client, cb := buildTestImpl(t) + + cb.CannotProceedFunc = func() bool { return false } + cb.SucceededFunc = func() {} + + client.DelFunc = func(_ context.Context, keys ...string) *redis.IntCmd { + test.Eq(t, []string{exampleKey}, keys) + cmd := redis.NewIntCmd(ctx) + cmd.SetVal(1) + return cmd + } + + err := impl.Delete(ctx, exampleKey) + test.NoError(t, err) + + test.SliceLen(t, 1, client.DelCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) + }) + + T.Run("when circuit breaker cannot proceed", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + impl, _, cb := buildTestImpl(t) + + cb.CannotProceedFunc = func() bool { return true } + + err := impl.Delete(ctx, exampleKey) + test.NoError(t, err) + + test.SliceLen(t, 1, cb.CannotProceedCalls()) + }) + + T.Run("with redis error", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + impl, client, cb := buildTestImpl(t) + + cb.CannotProceedFunc = func() bool { return false } + cb.FailedFunc = func() {} + + client.DelFunc = func(_ context.Context, _ ...string) *redis.IntCmd { + cmd := redis.NewIntCmd(ctx) + cmd.SetErr(errors.New("redis error")) + return cmd + } + + err := impl.Delete(ctx, exampleKey) + test.Error(t, err) + + test.SliceLen(t, 1, client.DelCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) + }) +} + +func Test_redisCacheImpl_Ping_Unit(T *testing.T) { + T.Parallel() + + T.Run("standard", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + impl, client, _ := buildTestImpl(t) + + client.PingFunc = func(_ context.Context) *redis.StatusCmd { + cmd := redis.NewStatusCmd(ctx) + cmd.SetVal("PONG") + return cmd + } + + test.NoError(t, impl.Ping(ctx)) + test.SliceLen(t, 1, client.PingCalls()) + }) + + T.Run("with error", func(t *testing.T) { + t.Parallel() + + ctx := t.Context() + impl, client, _ := buildTestImpl(t) + + client.PingFunc = func(_ context.Context) *redis.StatusCmd { + cmd := redis.NewStatusCmd(ctx) + cmd.SetErr(errors.New("connection refused")) + return cmd + } + + test.Error(t, impl.Ping(ctx)) + test.SliceLen(t, 1, client.PingCalls()) + }) +} + +func Test_buildRedisClient(T *testing.T) { + T.Parallel() + + T.Run("with single address", func(t *testing.T) { + t.Parallel() + + cfg := &Config{ + QueueAddresses: []string{"localhost:6379"}, + Username: "user", + Password: "pass", + } + + c := buildRedisClient(cfg) + test.NotNil(t, c) + }) + + T.Run("with multiple addresses", func(t *testing.T) { + t.Parallel() + + cfg := &Config{ + QueueAddresses: []string{"localhost:6379", "localhost:6380"}, + Username: "user", + Password: "pass", + } + + c := buildRedisClient(cfg) + test.NotNil(t, c) + }) + + T.Run("with no addresses", func(t *testing.T) { + t.Parallel() + + cfg := &Config{ + QueueAddresses: []string{}, + } + + c := buildRedisClient(cfg) + test.Nil(t, c) + }) +} diff --git a/cache/redis/unit_test.go b/cache/redis/unit_test.go deleted file mode 100644 index 7ce1cc4..0000000 --- a/cache/redis/unit_test.go +++ /dev/null @@ -1,569 +0,0 @@ -package redis - -import ( - "bytes" - "context" - "encoding/gob" - "errors" - "testing" - "time" - - "github.com/verygoodsoftwarenotvirus/platform/v5/cache" - circuitbreakingmock2 "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking/mock2" - "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" - "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" - metricsmock2 "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock2" - "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - - "github.com/go-redis/redis/v8" - "github.com/shoenig/test" - "github.com/shoenig/test/must" - "go.opentelemetry.io/otel/metric" -) - -func gobEncodeExample(t *testing.T, e *example) string { - t.Helper() - - var buf bytes.Buffer - must.NoError(t, gob.NewEncoder(&buf).Encode(e)) - - return buf.String() -} - -func buildTestImpl(t *testing.T) (*redisCacheImpl[example], *redisClientMock, *circuitbreakingmock2.CircuitBreakerMock) { - t.Helper() - - mp := metrics.NewNoopMetricsProvider() - - hitCounter, err := mp.NewInt64Counter("test_hits") - must.NoError(t, err) - - missCounter, err := mp.NewInt64Counter("test_misses") - must.NoError(t, err) - - setCounter, err := mp.NewInt64Counter("test_sets") - must.NoError(t, err) - - delCounter, err := mp.NewInt64Counter("test_deletes") - must.NoError(t, err) - - errCounter, err := mp.NewInt64Counter("test_errors") - must.NoError(t, err) - - latencyHist, err := mp.NewFloat64Histogram("test_latency") - must.NoError(t, err) - - client := &redisClientMock{} - cb := &circuitbreakingmock2.CircuitBreakerMock{} - - return &redisCacheImpl[example]{ - logger: logging.NewNoopLogger(), - tracer: tracing.NewNamedTracer(nil, "test"), - cacheHitCounter: hitCounter, - cacheMissCounter: missCounter, - cacheSetCounter: setCounter, - cacheDelCounter: delCounter, - cacheErrCounter: errCounter, - latencyHist: latencyHist, - client: client, - circuitBreaker: cb, - expiration: time.Minute, - }, client, cb -} - -// counterResult bundles the values a mocked NewInt64Counter call returns. -type counterResult struct { - counter metrics.Int64Counter - err error -} - -// newCounterProviderMock returns a metrics.Provider mock whose NewInt64Counter -// implementation looks up the result keyed on the counter name. Unknown names -// fail the test. -func newCounterProviderMock(t *testing.T, results map[string]counterResult) *metricsmock2.ProviderMock { - t.Helper() - return &metricsmock2.ProviderMock{ - NewInt64CounterFunc: func(metricName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { - res, ok := results[metricName] - if !ok { - t.Fatalf("unexpected NewInt64Counter call: %q", metricName) - } - return res.counter, res.err - }, - } -} - -func TestConfig_ValidateWithContext(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - ctx := t.Context() - - cfg := &Config{ - QueueAddresses: []string{"localhost:6379"}, - } - - test.NoError(t, cfg.ValidateWithContext(ctx)) - }) - - T.Run("with empty addresses", func(t *testing.T) { - t.Parallel() - - ctx := t.Context() - - cfg := &Config{ - QueueAddresses: []string{}, - } - - test.Error(t, cfg.ValidateWithContext(ctx)) - }) - - T.Run("with nil addresses", func(t *testing.T) { - t.Parallel() - - ctx := t.Context() - - cfg := &Config{} - - test.Error(t, cfg.ValidateWithContext(ctx)) - }) -} - -func TestNewRedisCache(T *testing.T) { - T.Parallel() - - okCounter := func() metrics.Int64Counter { return metrics.Int64CounterForTest(T, "x") } - - T.Run("with single address", func(t *testing.T) { - t.Parallel() - - cfg := &Config{QueueAddresses: []string{"localhost:6379"}} - - c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, nil, nil) - must.NoError(t, err) - test.NotNil(t, c) - }) - - T.Run("with multiple addresses", func(t *testing.T) { - t.Parallel() - - cfg := &Config{QueueAddresses: []string{"localhost:6379", "localhost:6380"}} - - c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, nil, nil) - must.NoError(t, err) - test.NotNil(t, c) - }) - - T.Run("with error creating cache hit counter", func(t *testing.T) { - t.Parallel() - - cfg := &Config{QueueAddresses: []string{"localhost:6379"}} - - mp := newCounterProviderMock(t, map[string]counterResult{ - name + "_cache_hits": {counter: okCounter(), err: errors.New("counter error")}, - }) - - c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, mp, nil) - test.Error(t, err) - test.Nil(t, c) - test.SliceLen(t, 1, mp.NewInt64CounterCalls()) - }) - - T.Run("with error creating cache miss counter", func(t *testing.T) { - t.Parallel() - - cfg := &Config{QueueAddresses: []string{"localhost:6379"}} - - mp := newCounterProviderMock(t, map[string]counterResult{ - name + "_cache_hits": {counter: okCounter()}, - name + "_cache_misses": {counter: okCounter(), err: errors.New("counter error")}, - }) - - c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, mp, nil) - test.Error(t, err) - test.Nil(t, c) - test.SliceLen(t, 2, mp.NewInt64CounterCalls()) - }) - - T.Run("with error creating cache set counter", func(t *testing.T) { - t.Parallel() - - cfg := &Config{QueueAddresses: []string{"localhost:6379"}} - - mp := newCounterProviderMock(t, map[string]counterResult{ - name + "_cache_hits": {counter: okCounter()}, - name + "_cache_misses": {counter: okCounter()}, - name + "_cache_sets": {counter: okCounter(), err: errors.New("counter error")}, - }) - - c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, mp, nil) - test.Error(t, err) - test.Nil(t, c) - test.SliceLen(t, 3, mp.NewInt64CounterCalls()) - }) - - T.Run("with error creating cache delete counter", func(t *testing.T) { - t.Parallel() - - cfg := &Config{QueueAddresses: []string{"localhost:6379"}} - - mp := newCounterProviderMock(t, map[string]counterResult{ - name + "_cache_hits": {counter: okCounter()}, - name + "_cache_misses": {counter: okCounter()}, - name + "_cache_sets": {counter: okCounter()}, - name + "_cache_deletes": {counter: okCounter(), err: errors.New("counter error")}, - }) - - c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, mp, nil) - test.Error(t, err) - test.Nil(t, c) - test.SliceLen(t, 4, mp.NewInt64CounterCalls()) - }) - - T.Run("with error creating cache error counter", func(t *testing.T) { - t.Parallel() - - cfg := &Config{QueueAddresses: []string{"localhost:6379"}} - - mp := newCounterProviderMock(t, map[string]counterResult{ - name + "_cache_hits": {counter: okCounter()}, - name + "_cache_misses": {counter: okCounter()}, - name + "_cache_sets": {counter: okCounter()}, - name + "_cache_deletes": {counter: okCounter()}, - name + "_cache_errors": {counter: okCounter(), err: errors.New("counter error")}, - }) - - c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, mp, nil) - test.Error(t, err) - test.Nil(t, c) - test.SliceLen(t, 5, mp.NewInt64CounterCalls()) - }) - - T.Run("with error creating latency histogram", func(t *testing.T) { - t.Parallel() - - cfg := &Config{QueueAddresses: []string{"localhost:6379"}} - - noopMP := metrics.NewNoopMetricsProvider() - h, histErr := noopMP.NewFloat64Histogram("test") - must.NoError(t, histErr) - - mp := &metricsmock2.ProviderMock{ - NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { - return metrics.Int64CounterForTest(t, "x"), nil - }, - NewFloat64HistogramFunc: func(metricName string, _ ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { - test.EqOp(t, name+"_cache_latency_ms", metricName) - return h, errors.New("histogram error") - }, - } - - c, err := NewRedisCache[example](cfg, time.Minute, nil, nil, mp, nil) - test.Error(t, err) - test.Nil(t, c) - test.SliceLen(t, 5, mp.NewInt64CounterCalls()) - test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) - }) -} - -func Test_redisCacheImpl_Get_Unit(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - ctx := t.Context() - impl, client, cb := buildTestImpl(t) - - expected := &example{Name: t.Name()} - encoded := gobEncodeExample(t, expected) - - cb.CannotProceedFunc = func() bool { return false } - cb.SucceededFunc = func() {} - - client.GetFunc = func(_ context.Context, key string) *redis.StringCmd { - test.EqOp(t, exampleKey, key) - cmd := redis.NewStringCmd(ctx) - cmd.SetVal(encoded) - return cmd - } - - actual, err := impl.Get(ctx, exampleKey) - test.NoError(t, err) - test.Eq(t, expected, actual) - - test.SliceLen(t, 1, client.GetCalls()) - test.SliceLen(t, 1, cb.CannotProceedCalls()) - test.SliceLen(t, 1, cb.SucceededCalls()) - }) - - T.Run("when circuit breaker cannot proceed", func(t *testing.T) { - t.Parallel() - - ctx := t.Context() - impl, _, cb := buildTestImpl(t) - - cb.CannotProceedFunc = func() bool { return true } - - actual, err := impl.Get(ctx, exampleKey) - test.ErrorIs(t, err, cache.ErrNotFound) - test.Nil(t, actual) - - test.SliceLen(t, 1, cb.CannotProceedCalls()) - }) - - T.Run("with redis error", func(t *testing.T) { - t.Parallel() - - ctx := t.Context() - impl, client, cb := buildTestImpl(t) - - cb.CannotProceedFunc = func() bool { return false } - cb.FailedFunc = func() {} - - client.GetFunc = func(_ context.Context, key string) *redis.StringCmd { - test.EqOp(t, exampleKey, key) - cmd := redis.NewStringCmd(ctx) - cmd.SetErr(errors.New("redis error")) - return cmd - } - - actual, err := impl.Get(ctx, exampleKey) - test.Error(t, err) - test.Nil(t, actual) - - test.SliceLen(t, 1, client.GetCalls()) - test.SliceLen(t, 1, cb.FailedCalls()) - }) - - T.Run("with decode error", func(t *testing.T) { - t.Parallel() - - ctx := t.Context() - impl, client, cb := buildTestImpl(t) - - cb.CannotProceedFunc = func() bool { return false } - - client.GetFunc = func(_ context.Context, key string) *redis.StringCmd { - test.EqOp(t, exampleKey, key) - cmd := redis.NewStringCmd(ctx) - cmd.SetVal("not valid gob data") - return cmd - } - - actual, err := impl.Get(ctx, exampleKey) - test.Error(t, err) - test.Nil(t, actual) - - test.SliceLen(t, 1, client.GetCalls()) - }) -} - -func Test_redisCacheImpl_Set_Unit(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - ctx := t.Context() - impl, client, cb := buildTestImpl(t) - - cb.CannotProceedFunc = func() bool { return false } - cb.SucceededFunc = func() {} - - client.SetFunc = func(_ context.Context, key string, value any, expiration time.Duration) *redis.StatusCmd { - test.EqOp(t, exampleKey, key) - test.EqOp(t, time.Minute, expiration) - _, isString := value.(string) - test.True(t, isString) - cmd := redis.NewStatusCmd(ctx) - cmd.SetVal("OK") - return cmd - } - - err := impl.Set(ctx, exampleKey, &example{Name: t.Name()}) - test.NoError(t, err) - - test.SliceLen(t, 1, client.SetCalls()) - test.SliceLen(t, 1, cb.SucceededCalls()) - }) - - T.Run("when circuit breaker cannot proceed", func(t *testing.T) { - t.Parallel() - - ctx := t.Context() - impl, _, cb := buildTestImpl(t) - - cb.CannotProceedFunc = func() bool { return true } - - err := impl.Set(ctx, exampleKey, &example{Name: t.Name()}) - test.NoError(t, err) - - test.SliceLen(t, 1, cb.CannotProceedCalls()) - }) - - T.Run("with redis error", func(t *testing.T) { - t.Parallel() - - ctx := t.Context() - impl, client, cb := buildTestImpl(t) - - cb.CannotProceedFunc = func() bool { return false } - cb.FailedFunc = func() {} - - client.SetFunc = func(_ context.Context, key string, _ any, _ time.Duration) *redis.StatusCmd { - test.EqOp(t, exampleKey, key) - cmd := redis.NewStatusCmd(ctx) - cmd.SetErr(errors.New("redis error")) - return cmd - } - - err := impl.Set(ctx, exampleKey, &example{Name: t.Name()}) - test.Error(t, err) - - test.SliceLen(t, 1, client.SetCalls()) - test.SliceLen(t, 1, cb.FailedCalls()) - }) -} - -func Test_redisCacheImpl_Delete_Unit(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - ctx := t.Context() - impl, client, cb := buildTestImpl(t) - - cb.CannotProceedFunc = func() bool { return false } - cb.SucceededFunc = func() {} - - client.DelFunc = func(_ context.Context, keys ...string) *redis.IntCmd { - test.Eq(t, []string{exampleKey}, keys) - cmd := redis.NewIntCmd(ctx) - cmd.SetVal(1) - return cmd - } - - err := impl.Delete(ctx, exampleKey) - test.NoError(t, err) - - test.SliceLen(t, 1, client.DelCalls()) - test.SliceLen(t, 1, cb.SucceededCalls()) - }) - - T.Run("when circuit breaker cannot proceed", func(t *testing.T) { - t.Parallel() - - ctx := t.Context() - impl, _, cb := buildTestImpl(t) - - cb.CannotProceedFunc = func() bool { return true } - - err := impl.Delete(ctx, exampleKey) - test.NoError(t, err) - - test.SliceLen(t, 1, cb.CannotProceedCalls()) - }) - - T.Run("with redis error", func(t *testing.T) { - t.Parallel() - - ctx := t.Context() - impl, client, cb := buildTestImpl(t) - - cb.CannotProceedFunc = func() bool { return false } - cb.FailedFunc = func() {} - - client.DelFunc = func(_ context.Context, _ ...string) *redis.IntCmd { - cmd := redis.NewIntCmd(ctx) - cmd.SetErr(errors.New("redis error")) - return cmd - } - - err := impl.Delete(ctx, exampleKey) - test.Error(t, err) - - test.SliceLen(t, 1, client.DelCalls()) - test.SliceLen(t, 1, cb.FailedCalls()) - }) -} - -func Test_redisCacheImpl_Ping_Unit(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - ctx := t.Context() - impl, client, _ := buildTestImpl(t) - - client.PingFunc = func(_ context.Context) *redis.StatusCmd { - cmd := redis.NewStatusCmd(ctx) - cmd.SetVal("PONG") - return cmd - } - - test.NoError(t, impl.Ping(ctx)) - test.SliceLen(t, 1, client.PingCalls()) - }) - - T.Run("with error", func(t *testing.T) { - t.Parallel() - - ctx := t.Context() - impl, client, _ := buildTestImpl(t) - - client.PingFunc = func(_ context.Context) *redis.StatusCmd { - cmd := redis.NewStatusCmd(ctx) - cmd.SetErr(errors.New("connection refused")) - return cmd - } - - test.Error(t, impl.Ping(ctx)) - test.SliceLen(t, 1, client.PingCalls()) - }) -} - -func Test_buildRedisClient(T *testing.T) { - T.Parallel() - - T.Run("with single address", func(t *testing.T) { - t.Parallel() - - cfg := &Config{ - QueueAddresses: []string{"localhost:6379"}, - Username: "user", - Password: "pass", - } - - c := buildRedisClient(cfg) - test.NotNil(t, c) - }) - - T.Run("with multiple addresses", func(t *testing.T) { - t.Parallel() - - cfg := &Config{ - QueueAddresses: []string{"localhost:6379", "localhost:6380"}, - Username: "user", - Password: "pass", - } - - c := buildRedisClient(cfg) - test.NotNil(t, c) - }) - - T.Run("with no addresses", func(t *testing.T) { - t.Parallel() - - cfg := &Config{ - QueueAddresses: []string{}, - } - - c := buildRedisClient(cfg) - test.Nil(t, c) - }) -} diff --git a/search/text/elasticsearch/elasticsearch_test.go b/search/text/elasticsearch/elasticsearch_test.go index 41dd633..a3a5473 100644 --- a/search/text/elasticsearch/elasticsearch_test.go +++ b/search/text/elasticsearch/elasticsearch_test.go @@ -2,17 +2,23 @@ package elasticsearch import ( "context" + "fmt" + "net/http" + "net/http/httptest" "os" "strings" "testing" "time" + "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking" + mockcircuitbreaking "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking/mock" cbnoop "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking/noop" "github.com/verygoodsoftwarenotvirus/platform/v5/identifiers" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" elasticsearchcontainers "github.com/testcontainers/testcontainers-go/modules/elasticsearch" ) @@ -315,3 +321,297 @@ func TestElasticsearch_Container(T *testing.T) { assert.Equal(t, "unimplemented", im.Wipe(t.Context()).Error()) }) } + +func TestIndexManager_ensureIndices_CircuitBroken(T *testing.T) { + T.Parallel() + + T.Run("with broken circuit breaker", func(t *testing.T) { + t.Parallel() + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(true) + + im := buildTestIndexManagerForUnit(t, cb) + + err := im.ensureIndices(context.Background()) + assert.Error(t, err) + assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) + + mock.AssertExpectationsForObjects(t, cb) + }) + + T.Run("with unreachable server", func(t *testing.T) { + t.Parallel() + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + cb.On("Failed").Return() + + im := buildTestIndexManagerForUnit(t, cb) + + err := im.ensureIndices(context.Background()) + assert.Error(t, err) + + mock.AssertExpectationsForObjects(t, cb) + }) +} + +func TestIndexManager_ensureIndices_Unit(T *testing.T) { + T.Parallel() + + T.Run("index exists", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Elastic-Product", "Elasticsearch") + if r.Method == http.MethodHead && r.URL.Path == "/test" { + w.WriteHeader(http.StatusOK) + return + } + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(server.Close) + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + cb.On("Succeeded").Return() + + im := buildTestIndexManagerWithServer(t, server, cb) + + err := im.ensureIndices(context.Background()) + assert.NoError(t, err) + + mock.AssertExpectationsForObjects(t, cb) + }) + + T.Run("index does not exist and create succeeds", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Elastic-Product", "Elasticsearch") + if r.Method == http.MethodHead && r.URL.Path == "/test" { + w.WriteHeader(http.StatusNotFound) + return + } + if r.Method == http.MethodPut && r.URL.Path == "/test" { + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprint(w, `{"acknowledged":true}`) + return + } + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(server.Close) + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + cb.On("Succeeded").Return() + + im := buildTestIndexManagerWithServer(t, server, cb) + + err := im.ensureIndices(context.Background()) + assert.NoError(t, err) + + mock.AssertExpectationsForObjects(t, cb) + }) + + T.Run("index does not exist and create fails", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Elastic-Product", "Elasticsearch") + if r.Method == http.MethodHead && r.URL.Path == "/test" { + w.WriteHeader(http.StatusNotFound) + return + } + if r.Method == http.MethodPut && r.URL.Path == "/test" { + // close connection to cause an error + hj, ok := w.(http.Hijacker) + if ok { + conn, _, _ := hj.Hijack() + conn.Close() + } + return + } + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(server.Close) + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + cb.On("Failed").Return() + + im := buildTestIndexManagerWithServer(t, server, cb) + + err := im.ensureIndices(context.Background()) + assert.Error(t, err) + + mock.AssertExpectationsForObjects(t, cb) + }) +} + +func Test_provideElasticsearchClient_Unit(T *testing.T) { + T.Parallel() + + T.Run("standard", func(t *testing.T) { + t.Parallel() + + cfg := &Config{ + Address: "http://localhost:9200", + } + + client, err := provideElasticsearchClient(cfg) + assert.NoError(t, err) + assert.NotNil(t, client) + }) + + T.Run("with credentials", func(t *testing.T) { + t.Parallel() + + cfg := &Config{ + Address: "http://localhost:9200", + Username: "elastic", + Password: "password", + } + + client, err := provideElasticsearchClient(cfg) + assert.NoError(t, err) + assert.NotNil(t, client) + }) +} + +func Test_elasticsearchIsReadyToInit_Unit(T *testing.T) { + T.Parallel() + + T.Run("returns false with unreachable server", func(t *testing.T) { + t.Parallel() + + cfg := &Config{ + Address: "http://localhost:19291", + } + + logger := logging.NewNoopLogger() + ready := elasticsearchIsReadyToInit(context.Background(), cfg, logger, 1) + // This will either return true (if the info request returns non-error) or false + // With unreachable server, the error path is taken but the condition is + // err != nil && res != nil && !res.IsError() which won't match when res is nil, + // so it falls through to the else branch and returns true. + // This is actually a bug in the code but we test the actual behavior. + assert.True(t, ready) + }) + + T.Run("returns true with reachable server", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Elastic-Product", "Elasticsearch") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprint(w, `{"name":"node","cluster_name":"test","version":{"number":"8.10.2"}}`) + })) + t.Cleanup(server.Close) + + cfg := &Config{ + Address: server.URL, + } + + logger := logging.NewNoopLogger() + ready := elasticsearchIsReadyToInit(context.Background(), cfg, logger, 3) + assert.True(t, ready) + }) +} + +func TestProvideIndexManager_Unit(T *testing.T) { + T.Parallel() + + T.Run("succeeds with mock server", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Elastic-Product", "Elasticsearch") + + // Info request from elasticsearchIsReadyToInit + if r.Method == http.MethodGet && r.URL.Path == "/" { + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprint(w, `{"name":"node","cluster_name":"test","version":{"number":"8.10.2"}}`) + return + } + + // Index exists check from ensureIndices + if r.Method == http.MethodHead && r.URL.Path == "/test" { + w.WriteHeader(http.StatusOK) + return + } + + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(server.Close) + + cfg := &Config{ + Address: server.URL, + } + + logger := logging.NewNoopLogger() + tracerProvider := tracing.NewNoopTracerProvider() + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + cb.On("Succeeded").Return() + + im, err := ProvideIndexManager[example](context.Background(), logger, tracerProvider, cfg, "test", cb) + assert.NoError(t, err) + assert.NotNil(t, im) + + mock.AssertExpectationsForObjects(t, cb) + }) + + T.Run("fails when ensureIndices fails", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Elastic-Product", "Elasticsearch") + + // Info request succeeds + if r.Method == http.MethodGet && r.URL.Path == "/" { + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprint(w, `{"name":"node","cluster_name":"test","version":{"number":"8.10.2"}}`) + return + } + + // Index existence check returns 404 + if r.Method == http.MethodHead && r.URL.Path == "/test" { + w.WriteHeader(http.StatusNotFound) + return + } + + // Index creation: close connection to trigger error + if r.Method == http.MethodPut && r.URL.Path == "/test" { + hj, ok := w.(http.Hijacker) + if ok { + conn, _, _ := hj.Hijack() + conn.Close() + } + return + } + + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(server.Close) + + cfg := &Config{ + Address: server.URL, + } + + logger := logging.NewNoopLogger() + tracerProvider := tracing.NewNoopTracerProvider() + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + cb.On("Failed").Return() + + im, err := ProvideIndexManager[example](context.Background(), logger, tracerProvider, cfg, "test", cb) + assert.Error(t, err) + assert.Nil(t, im) + + mock.AssertExpectationsForObjects(t, cb) + }) +} diff --git a/search/text/elasticsearch/index_test.go b/search/text/elasticsearch/index_test.go index 568a2de..316589c 100644 --- a/search/text/elasticsearch/index_test.go +++ b/search/text/elasticsearch/index_test.go @@ -1,5 +1,23 @@ package elasticsearch +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking" + mockcircuitbreaking "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking/mock" + "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" + "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" + + "github.com/elastic/go-elasticsearch/v8" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + type example struct { ID string `json:"id"` Name string `json:"name"` @@ -8,3 +26,410 @@ type example struct { type invalidJSON struct { Channel chan int `json:"channel"` } + +func buildTestIndexManagerForUnit(t *testing.T, cb circuitbreaking.CircuitBreaker) *indexManager[example] { + t.Helper() + + client, err := elasticsearch.NewClient(elasticsearch.Config{ + Addresses: []string{"http://localhost:19291"}, // intentionally wrong + }) + if err != nil { + t.Fatal(err) + } + + return &indexManager[example]{ + logger: logging.NewNoopLogger(), + tracer: tracing.NewTracerForTest("test"), + circuitBreaker: cb, + esClient: client, + indexName: "test", + } +} + +func buildTestIndexManagerWithServer(t *testing.T, server *httptest.Server, cb circuitbreaking.CircuitBreaker) *indexManager[example] { + t.Helper() + + client, err := elasticsearch.NewClient(elasticsearch.Config{ + Addresses: []string{server.URL}, + }) + if err != nil { + t.Fatal(err) + } + + return &indexManager[example]{ + logger: logging.NewNoopLogger(), + tracer: tracing.NewTracerForTest("test"), + circuitBreaker: cb, + esClient: client, + indexName: "test", + } +} + +func TestIndexManager_Index_CircuitBroken(T *testing.T) { + T.Parallel() + + T.Run("with broken circuit breaker", func(t *testing.T) { + t.Parallel() + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(true) + + im := buildTestIndexManagerForUnit(t, cb) + + err := im.Index(context.Background(), "id", map[string]string{"id": "test"}) + assert.Error(t, err) + assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) + + mock.AssertExpectationsForObjects(t, cb) + }) + + T.Run("with unmarshalable value", func(t *testing.T) { + t.Parallel() + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + + im := buildTestIndexManagerForUnit(t, cb) + + err := im.Index(context.Background(), "id", make(chan int)) + assert.Error(t, err) + + mock.AssertExpectationsForObjects(t, cb) + }) + + T.Run("with unreachable server", func(t *testing.T) { + t.Parallel() + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + cb.On("Failed").Return() + + im := buildTestIndexManagerForUnit(t, cb) + + err := im.Index(context.Background(), "id", map[string]string{"id": "test"}) + assert.Error(t, err) + + mock.AssertExpectationsForObjects(t, cb) + }) +} + +func TestIndexManager_Index_Unit(T *testing.T) { + T.Parallel() + + T.Run("with successful index", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Elastic-Product", "Elasticsearch") + w.WriteHeader(http.StatusCreated) + _, _ = fmt.Fprint(w, `{"_index":"test","_id":"123","result":"created"}`) + })) + t.Cleanup(server.Close) + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + cb.On("Succeeded").Return() + + im := buildTestIndexManagerWithServer(t, server, cb) + + err := im.Index(context.Background(), "123", &example{ID: "123", Name: "test"}) + assert.NoError(t, err) + + mock.AssertExpectationsForObjects(t, cb) + }) + + T.Run("with non-success status code", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Elastic-Product", "Elasticsearch") + w.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprint(w, `{"error":{"type":"mapper_parsing_exception","reason":"failed to parse"}}`) + })) + t.Cleanup(server.Close) + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + cb.On("Failed").Return() + + im := buildTestIndexManagerWithServer(t, server, cb) + + err := im.Index(context.Background(), "123", &example{ID: "123", Name: "test"}) + assert.Error(t, err) + + mock.AssertExpectationsForObjects(t, cb) + }) +} + +func TestIndexManager_Search_CircuitBroken(T *testing.T) { + T.Parallel() + + T.Run("with broken circuit breaker", func(t *testing.T) { + t.Parallel() + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(true) + + im := buildTestIndexManagerForUnit(t, cb) + + results, err := im.Search(context.Background(), "query") + assert.Error(t, err) + assert.Nil(t, results) + assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) + + mock.AssertExpectationsForObjects(t, cb) + }) + + T.Run("with empty query", func(t *testing.T) { + t.Parallel() + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + + im := buildTestIndexManagerForUnit(t, cb) + + results, err := im.Search(context.Background(), "") + assert.Error(t, err) + assert.Nil(t, results) + assert.Equal(t, ErrEmptyQueryProvided, err) + + mock.AssertExpectationsForObjects(t, cb) + }) + + T.Run("with unreachable server", func(t *testing.T) { + t.Parallel() + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + cb.On("Failed").Return() + + im := buildTestIndexManagerForUnit(t, cb) + + results, err := im.Search(context.Background(), "test query") + assert.Error(t, err) + assert.Nil(t, results) + + mock.AssertExpectationsForObjects(t, cb) + }) +} + +func TestIndexManager_Search_Unit(T *testing.T) { + T.Parallel() + + T.Run("with successful search results", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Elastic-Product", "Elasticsearch") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprint(w, `{"hits":{"total":{"value":1},"hits":[{"_id":"123","_source":{"id":"123","name":"test"}}]}}`) + })) + t.Cleanup(server.Close) + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + cb.On("Succeeded").Return() + + im := buildTestIndexManagerWithServer(t, server, cb) + + results, err := im.Search(context.Background(), "test") + assert.NoError(t, err) + require.Len(t, results, 1) + assert.Equal(t, "123", results[0].ID) + assert.Equal(t, "test", results[0].Name) + + mock.AssertExpectationsForObjects(t, cb) + }) + + T.Run("with error response", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Elastic-Product", "Elasticsearch") + w.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprint(w, `{"error":{"type":"search_phase_execution_exception","reason":"all shards failed"}}`) + })) + t.Cleanup(server.Close) + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + cb.On("Failed").Return() + + im := buildTestIndexManagerWithServer(t, server, cb) + + // NOTE: the search function has a named return 'err' that is overwritten + // by the deferred res.Body.Close() call, so the error is lost. The code + // does exercise the IsError() branch and calls circuitBreaker.Failed(), + // but ultimately returns nil error due to the defer clobbering it. + results, err := im.Search(context.Background(), "test") + assert.NoError(t, err) + assert.Nil(t, results) + + mock.AssertExpectationsForObjects(t, cb) + }) + + T.Run("with invalid JSON in success response", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Elastic-Product", "Elasticsearch") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprint(w, `not valid json`) + })) + t.Cleanup(server.Close) + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + cb.On("Failed").Return() + + im := buildTestIndexManagerWithServer(t, server, cb) + + // NOTE: same issue as error response test - the deferred res.Body.Close() + // overwrites the named return 'err' with nil. + results, err := im.Search(context.Background(), "test") + assert.NoError(t, err) + assert.Nil(t, results) + + mock.AssertExpectationsForObjects(t, cb) + }) +} + +func TestIndexManager_Search_ErrorResponseDecodeFailure_Unit(T *testing.T) { + T.Parallel() + + T.Run("with invalid JSON in error response body", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Elastic-Product", "Elasticsearch") + w.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprint(w, `this is not valid json`) + })) + t.Cleanup(server.Close) + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + cb.On("Failed").Return() + + im := buildTestIndexManagerWithServer(t, server, cb) + + // NOTE: the named return 'err' from the deferred res.Body.Close() clobbers + // the error, so this returns nil error despite the decode failure. + results, err := im.Search(context.Background(), "test") + assert.NoError(t, err) + assert.Nil(t, results) + + mock.AssertExpectationsForObjects(t, cb) + }) +} + +func TestIndexManager_Search_SourceUnmarshalError_Unit(T *testing.T) { + T.Parallel() + + T.Run("with invalid source in hit", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Elastic-Product", "Elasticsearch") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprint(w, `{"hits":{"total":{"value":1},"hits":[{"_id":"123","_source":"not a valid object"}]}}`) + })) + t.Cleanup(server.Close) + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + cb.On("Failed").Return() + + im := buildTestIndexManagerWithServer(t, server, cb) + + // NOTE: the named return 'err' from the deferred res.Body.Close() clobbers + // the error, so this returns nil error despite the unmarshal failure. + results, err := im.Search(context.Background(), "test") + assert.NoError(t, err) + assert.Nil(t, results) + + mock.AssertExpectationsForObjects(t, cb) + }) +} + +func TestIndexManager_Delete_CircuitBroken(T *testing.T) { + T.Parallel() + + T.Run("with broken circuit breaker", func(t *testing.T) { + t.Parallel() + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(true) + + im := buildTestIndexManagerForUnit(t, cb) + + err := im.Delete(context.Background(), "id") + assert.Error(t, err) + assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) + + mock.AssertExpectationsForObjects(t, cb) + }) + + T.Run("with unreachable server", func(t *testing.T) { + t.Parallel() + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + cb.On("Failed").Return() + + im := buildTestIndexManagerForUnit(t, cb) + + err := im.Delete(context.Background(), "some-id") + assert.Error(t, err) + + mock.AssertExpectationsForObjects(t, cb) + }) +} + +func TestIndexManager_Delete_Unit(T *testing.T) { + T.Parallel() + + T.Run("with successful delete", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Elastic-Product", "Elasticsearch") + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprint(w, `{"_index":"test","_id":"123","result":"deleted"}`) + })) + t.Cleanup(server.Close) + + cb := &mockcircuitbreaking.MockCircuitBreaker{} + cb.On("CannotProceed").Return(false) + cb.On("Succeeded").Return() + + im := buildTestIndexManagerWithServer(t, server, cb) + + err := im.Delete(context.Background(), "123") + assert.NoError(t, err) + + mock.AssertExpectationsForObjects(t, cb) + }) +} + +func TestIndexManager_Wipe_Unit(T *testing.T) { + T.Parallel() + + T.Run("returns unimplemented error", func(t *testing.T) { + t.Parallel() + + im := &indexManager[example]{} + + err := im.Wipe(context.Background()) + assert.Error(t, err) + assert.Equal(t, "unimplemented", err.Error()) + }) +} diff --git a/search/text/elasticsearch/unit_test.go b/search/text/elasticsearch/unit_test.go deleted file mode 100644 index d06aa33..0000000 --- a/search/text/elasticsearch/unit_test.go +++ /dev/null @@ -1,720 +0,0 @@ -package elasticsearch - -import ( - "context" - "fmt" - "net/http" - "net/http/httptest" - "testing" - - "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking" - mockcircuitbreaking "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking/mock" - "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" - "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - - "github.com/elastic/go-elasticsearch/v8" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -func buildTestIndexManagerForUnit(t *testing.T, cb circuitbreaking.CircuitBreaker) *indexManager[example] { - t.Helper() - - client, err := elasticsearch.NewClient(elasticsearch.Config{ - Addresses: []string{"http://localhost:19291"}, // intentionally wrong - }) - if err != nil { - t.Fatal(err) - } - - return &indexManager[example]{ - logger: logging.NewNoopLogger(), - tracer: tracing.NewTracerForTest("test"), - circuitBreaker: cb, - esClient: client, - indexName: "test", - } -} - -func buildTestIndexManagerWithServer(t *testing.T, server *httptest.Server, cb circuitbreaking.CircuitBreaker) *indexManager[example] { - t.Helper() - - client, err := elasticsearch.NewClient(elasticsearch.Config{ - Addresses: []string{server.URL}, - }) - if err != nil { - t.Fatal(err) - } - - return &indexManager[example]{ - logger: logging.NewNoopLogger(), - tracer: tracing.NewTracerForTest("test"), - circuitBreaker: cb, - esClient: client, - indexName: "test", - } -} - -func TestIndexManager_Index_CircuitBroken(T *testing.T) { - T.Parallel() - - T.Run("with broken circuit breaker", func(t *testing.T) { - t.Parallel() - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true) - - im := buildTestIndexManagerForUnit(t, cb) - - err := im.Index(context.Background(), "id", map[string]string{"id": "test"}) - assert.Error(t, err) - assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) - - mock.AssertExpectationsForObjects(t, cb) - }) - - T.Run("with unmarshalable value", func(t *testing.T) { - t.Parallel() - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - - im := buildTestIndexManagerForUnit(t, cb) - - err := im.Index(context.Background(), "id", make(chan int)) - assert.Error(t, err) - - mock.AssertExpectationsForObjects(t, cb) - }) - - T.Run("with unreachable server", func(t *testing.T) { - t.Parallel() - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() - - im := buildTestIndexManagerForUnit(t, cb) - - err := im.Index(context.Background(), "id", map[string]string{"id": "test"}) - assert.Error(t, err) - - mock.AssertExpectationsForObjects(t, cb) - }) -} - -func TestIndexManager_Search_CircuitBroken(T *testing.T) { - T.Parallel() - - T.Run("with broken circuit breaker", func(t *testing.T) { - t.Parallel() - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true) - - im := buildTestIndexManagerForUnit(t, cb) - - results, err := im.Search(context.Background(), "query") - assert.Error(t, err) - assert.Nil(t, results) - assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) - - mock.AssertExpectationsForObjects(t, cb) - }) - - T.Run("with empty query", func(t *testing.T) { - t.Parallel() - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - - im := buildTestIndexManagerForUnit(t, cb) - - results, err := im.Search(context.Background(), "") - assert.Error(t, err) - assert.Nil(t, results) - assert.Equal(t, ErrEmptyQueryProvided, err) - - mock.AssertExpectationsForObjects(t, cb) - }) - - T.Run("with unreachable server", func(t *testing.T) { - t.Parallel() - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() - - im := buildTestIndexManagerForUnit(t, cb) - - results, err := im.Search(context.Background(), "test query") - assert.Error(t, err) - assert.Nil(t, results) - - mock.AssertExpectationsForObjects(t, cb) - }) -} - -func TestIndexManager_Delete_CircuitBroken(T *testing.T) { - T.Parallel() - - T.Run("with broken circuit breaker", func(t *testing.T) { - t.Parallel() - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true) - - im := buildTestIndexManagerForUnit(t, cb) - - err := im.Delete(context.Background(), "id") - assert.Error(t, err) - assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) - - mock.AssertExpectationsForObjects(t, cb) - }) - - T.Run("with unreachable server", func(t *testing.T) { - t.Parallel() - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() - - im := buildTestIndexManagerForUnit(t, cb) - - err := im.Delete(context.Background(), "some-id") - assert.Error(t, err) - - mock.AssertExpectationsForObjects(t, cb) - }) -} - -func TestIndexManager_Wipe_Unit(T *testing.T) { - T.Parallel() - - T.Run("returns unimplemented error", func(t *testing.T) { - t.Parallel() - - im := &indexManager[example]{} - - err := im.Wipe(context.Background()) - assert.Error(t, err) - assert.Equal(t, "unimplemented", err.Error()) - }) -} - -func TestIndexManager_ensureIndices_CircuitBroken(T *testing.T) { - T.Parallel() - - T.Run("with broken circuit breaker", func(t *testing.T) { - t.Parallel() - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true) - - im := buildTestIndexManagerForUnit(t, cb) - - err := im.ensureIndices(context.Background()) - assert.Error(t, err) - assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) - - mock.AssertExpectationsForObjects(t, cb) - }) - - T.Run("with unreachable server", func(t *testing.T) { - t.Parallel() - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() - - im := buildTestIndexManagerForUnit(t, cb) - - err := im.ensureIndices(context.Background()) - assert.Error(t, err) - - mock.AssertExpectationsForObjects(t, cb) - }) -} - -func TestIndexManager_ensureIndices_Unit(T *testing.T) { - T.Parallel() - - T.Run("index exists", func(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Elastic-Product", "Elasticsearch") - if r.Method == http.MethodHead && r.URL.Path == "/test" { - w.WriteHeader(http.StatusOK) - return - } - w.WriteHeader(http.StatusOK) - })) - t.Cleanup(server.Close) - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() - - im := buildTestIndexManagerWithServer(t, server, cb) - - err := im.ensureIndices(context.Background()) - assert.NoError(t, err) - - mock.AssertExpectationsForObjects(t, cb) - }) - - T.Run("index does not exist and create succeeds", func(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Elastic-Product", "Elasticsearch") - if r.Method == http.MethodHead && r.URL.Path == "/test" { - w.WriteHeader(http.StatusNotFound) - return - } - if r.Method == http.MethodPut && r.URL.Path == "/test" { - w.WriteHeader(http.StatusOK) - _, _ = fmt.Fprint(w, `{"acknowledged":true}`) - return - } - w.WriteHeader(http.StatusOK) - })) - t.Cleanup(server.Close) - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() - - im := buildTestIndexManagerWithServer(t, server, cb) - - err := im.ensureIndices(context.Background()) - assert.NoError(t, err) - - mock.AssertExpectationsForObjects(t, cb) - }) - - T.Run("index does not exist and create fails", func(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Elastic-Product", "Elasticsearch") - if r.Method == http.MethodHead && r.URL.Path == "/test" { - w.WriteHeader(http.StatusNotFound) - return - } - if r.Method == http.MethodPut && r.URL.Path == "/test" { - // close connection to cause an error - hj, ok := w.(http.Hijacker) - if ok { - conn, _, _ := hj.Hijack() - conn.Close() - } - return - } - w.WriteHeader(http.StatusOK) - })) - t.Cleanup(server.Close) - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() - - im := buildTestIndexManagerWithServer(t, server, cb) - - err := im.ensureIndices(context.Background()) - assert.Error(t, err) - - mock.AssertExpectationsForObjects(t, cb) - }) -} - -func Test_provideElasticsearchClient_Unit(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - cfg := &Config{ - Address: "http://localhost:9200", - } - - client, err := provideElasticsearchClient(cfg) - assert.NoError(t, err) - assert.NotNil(t, client) - }) - - T.Run("with credentials", func(t *testing.T) { - t.Parallel() - - cfg := &Config{ - Address: "http://localhost:9200", - Username: "elastic", - Password: "password", - } - - client, err := provideElasticsearchClient(cfg) - assert.NoError(t, err) - assert.NotNil(t, client) - }) -} - -func Test_elasticsearchIsReadyToInit_Unit(T *testing.T) { - T.Parallel() - - T.Run("returns false with unreachable server", func(t *testing.T) { - t.Parallel() - - cfg := &Config{ - Address: "http://localhost:19291", - } - - logger := logging.NewNoopLogger() - ready := elasticsearchIsReadyToInit(context.Background(), cfg, logger, 1) - // This will either return true (if the info request returns non-error) or false - // With unreachable server, the error path is taken but the condition is - // err != nil && res != nil && !res.IsError() which won't match when res is nil, - // so it falls through to the else branch and returns true. - // This is actually a bug in the code but we test the actual behavior. - assert.True(t, ready) - }) - - T.Run("returns true with reachable server", func(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Elastic-Product", "Elasticsearch") - w.WriteHeader(http.StatusOK) - _, _ = fmt.Fprint(w, `{"name":"node","cluster_name":"test","version":{"number":"8.10.2"}}`) - })) - t.Cleanup(server.Close) - - cfg := &Config{ - Address: server.URL, - } - - logger := logging.NewNoopLogger() - ready := elasticsearchIsReadyToInit(context.Background(), cfg, logger, 3) - assert.True(t, ready) - }) -} - -func TestProvideIndexManager_Unit(T *testing.T) { - T.Parallel() - - T.Run("succeeds with mock server", func(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Elastic-Product", "Elasticsearch") - - // Info request from elasticsearchIsReadyToInit - if r.Method == http.MethodGet && r.URL.Path == "/" { - w.WriteHeader(http.StatusOK) - _, _ = fmt.Fprint(w, `{"name":"node","cluster_name":"test","version":{"number":"8.10.2"}}`) - return - } - - // Index exists check from ensureIndices - if r.Method == http.MethodHead && r.URL.Path == "/test" { - w.WriteHeader(http.StatusOK) - return - } - - w.WriteHeader(http.StatusOK) - })) - t.Cleanup(server.Close) - - cfg := &Config{ - Address: server.URL, - } - - logger := logging.NewNoopLogger() - tracerProvider := tracing.NewNoopTracerProvider() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() - - im, err := ProvideIndexManager[example](context.Background(), logger, tracerProvider, cfg, "test", cb) - assert.NoError(t, err) - assert.NotNil(t, im) - - mock.AssertExpectationsForObjects(t, cb) - }) - - T.Run("fails when ensureIndices fails", func(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Elastic-Product", "Elasticsearch") - - // Info request succeeds - if r.Method == http.MethodGet && r.URL.Path == "/" { - w.WriteHeader(http.StatusOK) - _, _ = fmt.Fprint(w, `{"name":"node","cluster_name":"test","version":{"number":"8.10.2"}}`) - return - } - - // Index existence check returns 404 - if r.Method == http.MethodHead && r.URL.Path == "/test" { - w.WriteHeader(http.StatusNotFound) - return - } - - // Index creation: close connection to trigger error - if r.Method == http.MethodPut && r.URL.Path == "/test" { - hj, ok := w.(http.Hijacker) - if ok { - conn, _, _ := hj.Hijack() - conn.Close() - } - return - } - - w.WriteHeader(http.StatusOK) - })) - t.Cleanup(server.Close) - - cfg := &Config{ - Address: server.URL, - } - - logger := logging.NewNoopLogger() - tracerProvider := tracing.NewNoopTracerProvider() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() - - im, err := ProvideIndexManager[example](context.Background(), logger, tracerProvider, cfg, "test", cb) - assert.Error(t, err) - assert.Nil(t, im) - - mock.AssertExpectationsForObjects(t, cb) - }) -} - -func TestIndexManager_Index_Unit(T *testing.T) { - T.Parallel() - - T.Run("with successful index", func(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Elastic-Product", "Elasticsearch") - w.WriteHeader(http.StatusCreated) - _, _ = fmt.Fprint(w, `{"_index":"test","_id":"123","result":"created"}`) - })) - t.Cleanup(server.Close) - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() - - im := buildTestIndexManagerWithServer(t, server, cb) - - err := im.Index(context.Background(), "123", &example{ID: "123", Name: "test"}) - assert.NoError(t, err) - - mock.AssertExpectationsForObjects(t, cb) - }) - - T.Run("with non-success status code", func(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Elastic-Product", "Elasticsearch") - w.WriteHeader(http.StatusBadRequest) - _, _ = fmt.Fprint(w, `{"error":{"type":"mapper_parsing_exception","reason":"failed to parse"}}`) - })) - t.Cleanup(server.Close) - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() - - im := buildTestIndexManagerWithServer(t, server, cb) - - err := im.Index(context.Background(), "123", &example{ID: "123", Name: "test"}) - assert.Error(t, err) - - mock.AssertExpectationsForObjects(t, cb) - }) -} - -func TestIndexManager_Search_Unit(T *testing.T) { - T.Parallel() - - T.Run("with successful search results", func(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Elastic-Product", "Elasticsearch") - w.WriteHeader(http.StatusOK) - _, _ = fmt.Fprint(w, `{"hits":{"total":{"value":1},"hits":[{"_id":"123","_source":{"id":"123","name":"test"}}]}}`) - })) - t.Cleanup(server.Close) - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() - - im := buildTestIndexManagerWithServer(t, server, cb) - - results, err := im.Search(context.Background(), "test") - assert.NoError(t, err) - require.Len(t, results, 1) - assert.Equal(t, "123", results[0].ID) - assert.Equal(t, "test", results[0].Name) - - mock.AssertExpectationsForObjects(t, cb) - }) - - T.Run("with error response", func(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Elastic-Product", "Elasticsearch") - w.WriteHeader(http.StatusBadRequest) - _, _ = fmt.Fprint(w, `{"error":{"type":"search_phase_execution_exception","reason":"all shards failed"}}`) - })) - t.Cleanup(server.Close) - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() - - im := buildTestIndexManagerWithServer(t, server, cb) - - // NOTE: the search function has a named return 'err' that is overwritten - // by the deferred res.Body.Close() call, so the error is lost. The code - // does exercise the IsError() branch and calls circuitBreaker.Failed(), - // but ultimately returns nil error due to the defer clobbering it. - results, err := im.Search(context.Background(), "test") - assert.NoError(t, err) - assert.Nil(t, results) - - mock.AssertExpectationsForObjects(t, cb) - }) - - T.Run("with invalid JSON in success response", func(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Elastic-Product", "Elasticsearch") - w.WriteHeader(http.StatusOK) - _, _ = fmt.Fprint(w, `not valid json`) - })) - t.Cleanup(server.Close) - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() - - im := buildTestIndexManagerWithServer(t, server, cb) - - // NOTE: same issue as error response test - the deferred res.Body.Close() - // overwrites the named return 'err' with nil. - results, err := im.Search(context.Background(), "test") - assert.NoError(t, err) - assert.Nil(t, results) - - mock.AssertExpectationsForObjects(t, cb) - }) -} - -func TestIndexManager_Search_ErrorResponseDecodeFailure_Unit(T *testing.T) { - T.Parallel() - - T.Run("with invalid JSON in error response body", func(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Elastic-Product", "Elasticsearch") - w.WriteHeader(http.StatusBadRequest) - _, _ = fmt.Fprint(w, `this is not valid json`) - })) - t.Cleanup(server.Close) - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() - - im := buildTestIndexManagerWithServer(t, server, cb) - - // NOTE: the named return 'err' from the deferred res.Body.Close() clobbers - // the error, so this returns nil error despite the decode failure. - results, err := im.Search(context.Background(), "test") - assert.NoError(t, err) - assert.Nil(t, results) - - mock.AssertExpectationsForObjects(t, cb) - }) -} - -func TestIndexManager_Search_SourceUnmarshalError_Unit(T *testing.T) { - T.Parallel() - - T.Run("with invalid source in hit", func(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Elastic-Product", "Elasticsearch") - w.WriteHeader(http.StatusOK) - _, _ = fmt.Fprint(w, `{"hits":{"total":{"value":1},"hits":[{"_id":"123","_source":"not a valid object"}]}}`) - })) - t.Cleanup(server.Close) - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() - - im := buildTestIndexManagerWithServer(t, server, cb) - - // NOTE: the named return 'err' from the deferred res.Body.Close() clobbers - // the error, so this returns nil error despite the unmarshal failure. - results, err := im.Search(context.Background(), "test") - assert.NoError(t, err) - assert.Nil(t, results) - - mock.AssertExpectationsForObjects(t, cb) - }) -} - -func TestIndexManager_Delete_Unit(T *testing.T) { - T.Parallel() - - T.Run("with successful delete", func(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Elastic-Product", "Elasticsearch") - w.WriteHeader(http.StatusOK) - _, _ = fmt.Fprint(w, `{"_index":"test","_id":"123","result":"deleted"}`) - })) - t.Cleanup(server.Close) - - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() - - im := buildTestIndexManagerWithServer(t, server, cb) - - err := im.Delete(context.Background(), "123") - assert.NoError(t, err) - - mock.AssertExpectationsForObjects(t, cb) - }) -} From 2251200d8344aabd3e8dfa79b121713dfce71700 Mon Sep 17 00:00:00 2001 From: Jeffrey D <1289344+verygoodsoftwarenotvirus@users.noreply.github.com> Date: Sat, 11 Apr 2026 11:22:58 -0500 Subject: [PATCH 06/12] test: replace stretchr/testify/mock with matryer/moq --- .claude/skills/convert-assertions/SKILL.md | 208 ++++++ .claude/skills/convert-mocks/SKILL.md | 285 ++++++++ analytics/config/config_test.go | 18 +- analytics/mock/doc.go | 8 + analytics/mock/event_reporter_mock.go | 250 +++++++ analytics/mock/mock.go | 38 - analytics/multisource/reporter_test.go | 31 +- analytics/posthog/posthog_test.go | 29 +- analytics/rudderstack/rudderstack_test.go | 29 +- analytics/segment/segment_test.go | 29 +- cache/config/config_test.go | 4 +- cache/redis/mocks_gen_test.go | 4 +- cache/redis/redis_test.go | 14 +- capitalism/mock/doc.go | 9 + capitalism/mock/mock_payment_manager.go | 26 - capitalism/mock/payment_manager_mock.go | 77 +++ capitalism/stripe/mock_backend_test.go | 77 --- capitalism/stripe/stripe_test.go | 21 +- circuitbreaking/config/config_test.go | 75 +- .../{mock2 => mock}/circuitbreaker_mock.go | 2 +- circuitbreaking/mock/doc.go | 9 + circuitbreaking/mock/mock.go | 23 - circuitbreaking/mock2/doc.go | 10 - cryptography/encryption/mock/doc.go | 9 +- .../mock/encryptor_decryptor_mock.go | 133 ++++ cryptography/encryption/mock/mock.go | 28 - database/database_mock.go | 155 ----- database/database_mock_test.go | 322 --------- database/mock/database_mock.go | 650 ++++++++++++++++++ database/mock/doc.go | 8 + database/mysql/mysql_test.go | 13 +- database/postgres/postgres_test.go | 13 +- database/sqlite/sqlite_test.go | 13 +- distributedlock/config/config_test.go | 13 +- distributedlock/mock/doc.go | 9 + distributedlock/mock/locker.go | 68 -- distributedlock/mock/locker_mock.go | 361 ++++++++++ distributedlock/postgres/postgres_test.go | 49 +- distributedlock/redis/redis_test.go | 88 ++- email/config/config_test.go | 13 +- email/mock/doc.go | 9 + email/mock/emailer_mock.go | 83 +++ email/mock/mock_emailer.go | 23 - embeddings/mock/doc.go | 9 + embeddings/mock/embedder_mock.go | 83 +++ embeddings/mock/mock.go | 25 - embeddings/mock/mock_test.go | 51 -- encoding/client_encoder_test.go | 9 +- encoding/mock/doc.go | 7 +- encoding/mock/encoder_decoder_mock.go | 530 ++++++++++++++ encoding/mock/mock_client_encoder.go | 35 - encoding/mock/mock_encoding.go | 83 --- encoding/mock_io_writer_test.go | 9 +- featureflags/config/config_test.go | 13 +- .../launchdarkly/feature_flag_manager_test.go | 70 +- featureflags/mock/doc.go | 9 + .../mock/feature_flag_manager_mock.go | 374 ++++++++++ .../mock/mock_feature_flag_manager.go | 50 -- .../posthog/feature_flag_manager_test.go | 25 +- llm/anthropic/anthropic_test.go | 45 +- llm/config/config_test.go | 20 +- llm/mock/doc.go | 9 + llm/mock/mock.go | 25 - llm/mock/mock_test.go | 33 - llm/mock/provider_mock.go | 83 +++ llm/openai/openai_test.go | 45 +- messagequeue/kafka/consumer_test.go | 144 ++-- messagequeue/kafka/publisher_test.go | 95 ++- messagequeue/mock/consumer.go | 35 - messagequeue/mock/doc.go | 10 + messagequeue/mock/messagequeue_mock.go | 479 +++++++++++++ messagequeue/mock/mock.go | 53 -- messagequeue/pubsub/consumer_test.go | 12 +- messagequeue/pubsub/publisher_test.go | 47 +- messagequeue/redis/consumer_test.go | 12 +- messagequeue/redis/publisher_test.go | 97 +-- messagequeue/sqs/consumer_test.go | 141 ++-- messagequeue/sqs/publisher_test.go | 100 +-- notifications/mobile/apns/apns_sender_test.go | 33 +- notifications/mobile/fcm/fcm_sender_test.go | 33 +- observability/metrics/mock/doc.go | 6 +- observability/metrics/mock/int64_counter.go | 20 - observability/metrics/mock/provider.go | 74 -- .../metrics/{mock2 => mock}/provider_mock.go | 81 ++- observability/metrics/mock2/doc.go | 10 - observability/metrics/mock_provider.go | 72 -- observability/metrics/noop.go | 5 +- panicking/mock/doc.go | 9 + panicking/mock/mock.go | 25 - panicking/mock/panicker_mock.go | 126 ++++ random/mock/doc.go | 9 + random/mock/generator_mock.go | 233 +++++++ random/mock/mock_random.go | 42 -- ratelimiting/redis/redis_test.go | 89 ++- routing/mock/doc.go | 9 + routing/mock/route_param_manager.go | 34 - routing/mock/route_param_manager_mock.go | 134 ++++ search/text/algolia/index_test.go | 131 ++-- search/text/config/config_test.go | 13 +- .../text/elasticsearch/elasticsearch_test.go | 62 +- search/text/elasticsearch/index_test.go | 135 ++-- search/text/indexing/do.go | 2 +- search/text/indexing/do_test.go | 26 +- search/text/indexing/indexer_test.go | 348 +++++++--- search/text/mock/doc.go | 6 +- search/text/mock/index_manager.go | 38 - search/text/mock/index_manager_test.go | 78 --- search/text/mock/index_mock.go | 233 +++++++ search/vector/config/config_test.go | 13 +- search/vector/mock/doc.go | 6 + search/vector/mock/index.go | 40 -- search/vector/mock/index_mock.go | 227 ++++++ search/vector/pgvector/pgvector_test.go | 101 +-- search/vector/qdrant/qdrant_test.go | 29 +- secrets/config/config_test.go | 47 +- secrets/env/env_test.go | 27 +- secrets/gcp/gcp_test.go | 45 +- secrets/kubectl/kubectl_test.go | 45 +- secrets/ssm/ssm_test.go | 45 +- server/grpc/server_test.go | 17 +- server/http/http_server_test.go | 17 +- testutils/mock_handler.go | 19 - testutils/mock_http_response_writer.go | 31 - testutils/mock_io_read_closer.go | 26 - testutils/mock_io_writer.go | 20 - testutils/mock_matchers.go | 24 - uploads/images/images_test.go | 20 +- uploads/images/mock.go | 28 - uploads/mock/doc.go | 9 + uploads/mock/mock.go | 28 - uploads/mock/upload_manager_mock.go | 139 ++++ uploads/objectstorage/files_test.go | 24 +- uploads/objectstorage/mock_uploader.go | 30 - uploads/objectstorage/mock_uploader_test.go | 41 -- 134 files changed, 6528 insertions(+), 2889 deletions(-) create mode 100644 .claude/skills/convert-assertions/SKILL.md create mode 100644 .claude/skills/convert-mocks/SKILL.md create mode 100644 analytics/mock/doc.go create mode 100644 analytics/mock/event_reporter_mock.go delete mode 100644 analytics/mock/mock.go create mode 100644 capitalism/mock/doc.go delete mode 100644 capitalism/mock/mock_payment_manager.go create mode 100644 capitalism/mock/payment_manager_mock.go delete mode 100644 capitalism/stripe/mock_backend_test.go rename circuitbreaking/{mock2 => mock}/circuitbreaker_mock.go (99%) create mode 100644 circuitbreaking/mock/doc.go delete mode 100644 circuitbreaking/mock/mock.go delete mode 100644 circuitbreaking/mock2/doc.go create mode 100644 cryptography/encryption/mock/encryptor_decryptor_mock.go delete mode 100644 cryptography/encryption/mock/mock.go delete mode 100644 database/database_mock.go delete mode 100644 database/database_mock_test.go create mode 100644 database/mock/database_mock.go create mode 100644 database/mock/doc.go create mode 100644 distributedlock/mock/doc.go delete mode 100644 distributedlock/mock/locker.go create mode 100644 distributedlock/mock/locker_mock.go create mode 100644 email/mock/doc.go create mode 100644 email/mock/emailer_mock.go delete mode 100644 email/mock/mock_emailer.go create mode 100644 embeddings/mock/doc.go create mode 100644 embeddings/mock/embedder_mock.go delete mode 100644 embeddings/mock/mock.go delete mode 100644 embeddings/mock/mock_test.go create mode 100644 encoding/mock/encoder_decoder_mock.go delete mode 100644 encoding/mock/mock_client_encoder.go delete mode 100644 encoding/mock/mock_encoding.go create mode 100644 featureflags/mock/doc.go create mode 100644 featureflags/mock/feature_flag_manager_mock.go delete mode 100644 featureflags/mock/mock_feature_flag_manager.go create mode 100644 llm/mock/doc.go delete mode 100644 llm/mock/mock.go delete mode 100644 llm/mock/mock_test.go create mode 100644 llm/mock/provider_mock.go delete mode 100644 messagequeue/mock/consumer.go create mode 100644 messagequeue/mock/doc.go create mode 100644 messagequeue/mock/messagequeue_mock.go delete mode 100644 messagequeue/mock/mock.go delete mode 100644 observability/metrics/mock/int64_counter.go delete mode 100644 observability/metrics/mock/provider.go rename observability/metrics/{mock2 => mock}/provider_mock.go (89%) delete mode 100644 observability/metrics/mock2/doc.go delete mode 100644 observability/metrics/mock_provider.go create mode 100644 panicking/mock/doc.go delete mode 100644 panicking/mock/mock.go create mode 100644 panicking/mock/panicker_mock.go create mode 100644 random/mock/doc.go create mode 100644 random/mock/generator_mock.go delete mode 100644 random/mock/mock_random.go create mode 100644 routing/mock/doc.go delete mode 100644 routing/mock/route_param_manager.go create mode 100644 routing/mock/route_param_manager_mock.go delete mode 100644 search/text/mock/index_manager.go delete mode 100644 search/text/mock/index_manager_test.go create mode 100644 search/text/mock/index_mock.go create mode 100644 search/vector/mock/doc.go delete mode 100644 search/vector/mock/index.go create mode 100644 search/vector/mock/index_mock.go delete mode 100644 testutils/mock_handler.go delete mode 100644 testutils/mock_http_response_writer.go delete mode 100644 testutils/mock_io_read_closer.go delete mode 100644 testutils/mock_io_writer.go delete mode 100644 testutils/mock_matchers.go delete mode 100644 uploads/images/mock.go create mode 100644 uploads/mock/doc.go delete mode 100644 uploads/mock/mock.go create mode 100644 uploads/mock/upload_manager_mock.go delete mode 100644 uploads/objectstorage/mock_uploader.go delete mode 100644 uploads/objectstorage/mock_uploader_test.go diff --git a/.claude/skills/convert-assertions/SKILL.md b/.claude/skills/convert-assertions/SKILL.md new file mode 100644 index 0000000..9cbb0fd --- /dev/null +++ b/.claude/skills/convert-assertions/SKILL.md @@ -0,0 +1,208 @@ +--- +name: convert-assertions +description: > + Mechanical rewrite of testify/{assert,require} to shoenig/{test,must} in a + single Go package of this repo. Does NOT touch mocks, file layout, or any + other package. Invoke as /convert-assertions , e.g. + /convert-assertions ./ratelimiting/redis/. Paired with /convert-mocks which + handles the testify/mock → matryer/moq half of the migration independently. +--- + +# convert-assertions + +Migrate a single Go package's test assertions from `testify/{assert,require}` to `github.com/shoenig/test` (non-fatal, package name `test`) + `github.com/shoenig/test/must` (fatal). This is Phase 1 of the shoenig/moq migration, executed one package at a time. Mocks are out of scope — see `/convert-mocks` for that half. + +## Argument + +The user invokes this with a package path relative to the repo root (e.g. `./ratelimiting/redis/`). That's the target. Work only inside it. + +## Hard rules — do not violate + +1. **Do NOT touch mocks.** Leave every `stretchr/testify/mock` import, `.On(...)`, `.Return(...)`, `mock.Anything`, `mock.MatchedBy(...)`, `mock.AssertExpectationsForObjects(...)`, and `/mock.Mock{}` call site exactly as-is. If a test file mixes testify mocks and testify assertions, this skill converts only the assertions; the file becomes a hybrid that still compiles. `/convert-mocks` will clean up the mock half later. +2. **Do NOT touch other packages.** Only edit `*_test.go` files inside the target package. +3. **Do NOT move or split tests.** File layout cleanup is a separate concern; it's not this skill's job. +4. **Do NOT rename tests or subtests.** No changing `Test_Foo_Unit` → `Test_Foo`, no dropping `_Unit` suffixes. Preserve existing naming. +5. **Do NOT use sed/awk/grep-replace.** Use the `Edit` tool with exact strings. The API has traps (length-first argument order on `SliceLen`/`MapLen`, `Eq` vs `EqOp` distinction) that a dumb find/replace will get wrong. +6. **Do NOT compile or leave binaries.** Verify via `go vet` and `go test`. No `go build -o` without immediate cleanup. +7. **Preserve `T` vs `t`.** Repo convention: top-level test functions take `T *testing.T`, subtests take `t *testing.T`. Don't touch this. See `feedback_test_variable_naming.md` in user memory. +8. **One package per invocation.** Never go beyond the target, even if you notice neighboring packages would benefit. + +## Preflight + +Before editing anything: + +1. Read `CLAUDE.md` at the repo root if you haven't this session — it has the module path and import ordering rules. +2. Read `project_moq_shoenig_pilot.md` in user memory — it captures the per-quirk details about the pilot. +3. `ls /` to know what test files exist. +4. `Grep` for `stretchr/testify/assert` and `stretchr/testify/require` within the target to enumerate files that need rewriting. If neither appears, report back immediately — nothing to do. +5. Baseline: `go build .//... && CGO_ENABLED=1 go test -race -vet=all -shuffle=on .//... 2>&1`. It must pass BEFORE any edit. If the baseline is broken, stop and tell the user — the breakage isn't yours to fix under this skill. + +## Import rewrites + +| Testify import | Shoenig replacement | +|---|---| +| `"github.com/stretchr/testify/assert"` | `"github.com/shoenig/test"` (package name: `test`) | +| `"github.com/stretchr/testify/require"` | `"github.com/shoenig/test/must"` (package name: `must`) | + +**There is no `github.com/shoenig/test/test` subpackage.** The non-fatal package lives at the module root with package name `test`. If you ever write the import path `github.com/shoenig/test/test`, you made a mistake — `go build` will tell you immediately. The correct form is: + +```go +import ( + "github.com/shoenig/test" // for test.Eq, test.NoError, etc (non-fatal) + "github.com/shoenig/test/must" // for must.NoError, must.NotNil, etc (fatal) +) +``` + +If a file used only `assert`, it gets only the root import. If a file used only `require`, it gets only `/must`. If it used both, it gets both imports. + +## Call-site rewrites + +`assert.*` → `test.*` (non-fatal). `require.*` → `must.*` (fatal). The function names match between `test` and `must` — so `assert.NoError` → `test.NoError`, `require.NoError` → `must.NoError`. + +### Direct one-to-one mappings + +| Testify | Shoenig | Notes | +|---|---|---| +| `assert.NoError(t, err)` | `test.NoError(t, err)` | | +| `require.NoError(t, err)` | `must.NoError(t, err)` | | +| `assert.Error(t, err)` | `test.Error(t, err)` | | +| `require.Error(t, err)` | `must.Error(t, err)` | | +| `assert.ErrorIs(t, err, target)` | `test.ErrorIs(t, err, target)` | | +| `assert.EqualError(t, err, "msg")` | `test.EqError(t, err, "msg")` | name change: `EqualError` → `EqError` | +| `assert.Nil(t, x)` | `test.Nil(t, x)` | | +| `assert.NotNil(t, x)` | `test.NotNil(t, x)` | | +| `assert.True(t, x)` | `test.True(t, x)` | | +| `assert.False(t, x)` | `test.False(t, x)` | | + +### Equality — `Eq` vs `EqOp` + +`assert.Equal` has two shoenig equivalents and choosing the right one matters: + +- **`test.EqOp(t, want, got)`** — uses Go `==`. Only works for **comparable** types: strings, numbers, bools, pointers, `time.Duration`, channels, and structs whose fields are all comparable. Faster, catches type mismatches at compile time. Prefer this when possible. +- **`test.Eq(t, want, got)`** — uses reflect/go-cmp. Works on slices, maps, structs with non-comparable fields, pointers-to-structs where you want deep comparison. + +**Rules of thumb:** +- Primitives (`string`, `int`, `bool`, `time.Duration`) → `EqOp` +- Named comparable types (e.g., custom string types, enum-like consts) → `EqOp` +- Slices or maps → `Eq` (`EqOp` won't compile) +- Structs → usually `Eq` (safer); use `EqOp` only if the struct is explicitly comparable and small +- Pointers-to-structs — you almost always want the underlying struct value compared → `test.Eq(t, want, got)` works (go-cmp handles it) +- Errors — prefer `test.ErrorIs(t, err, target)` or `test.EqError(t, err, "msg string")` over `Eq`/`EqOp` on error values directly + +If unsure, use `Eq`. If lint/tests complain with `EqOp` (type is not comparable), switch to `Eq`. + +### Length and collection checks — argument order is FLIPPED + +⚠️ **This is the single biggest trap in the migration.** In testify, length comes AFTER the collection. In shoenig, length comes FIRST. + +| Testify | Shoenig | +|---|---| +| `assert.Len(t, slice, 3)` | `test.SliceLen(t, 3, slice)` | +| `assert.Len(t, m, 3)` (where `m` is a map) | `test.MapLen(t, 3, m)` | +| `assert.Empty(t, slice)` | `test.SliceEmpty(t, slice)` | +| `assert.Empty(t, m)` (map) | `test.MapEmpty(t, m)` | +| `assert.Contains(t, slice, elem)` | `test.SliceContains(t, slice, elem)` | +| `assert.Contains(t, str, substr)` | `test.StrContains(t, str, substr)` | +| `assert.Contains(t, m, key)` (map) | `test.MapContainsKey(t, m, key)` | +| `assert.NotContains(t, slice, elem)` | `test.SliceNotContains(t, slice, elem)` | + +**There is no polymorphic `Len` or `Contains` or `Empty`.** You must pick the right variant based on the collection type. When the existing testify call is `assert.Len(t, x, n)` and you can't immediately tell whether `x` is a slice or a map, read the surrounding code: check where `x` is declared or assigned. + +### Less common but worth knowing + +| Testify | Shoenig | +|---|---| +| `assert.NotEqual(t, a, b)` | `test.NotEq(t, a, b)` | +| `assert.Same(t, a, b)` (same pointer) | `test.Eq(t, a, b)` works; no distinct `Same` in shoenig | +| `assert.IsType(t, expected, actual)` | `test.EqOp(t, reflect.TypeOf(expected), reflect.TypeOf(actual))` or use a type assertion + `test.True` | +| `assert.Panics(t, f)` | shoenig does not have a direct equivalent — rare; if encountered, flag in the report and leave the testify call alone, or implement with `defer recover()` inline | +| `assert.Fail(t, "msg")` | `t.Fatalf("msg")` or `t.Errorf("msg")` — use stdlib | +| `assert.FailNow(t, "msg")` | `t.Fatalf("msg")` — use stdlib | +| `assert.Greater(t, a, b)` | shoenig's `test.Greater(t, b, a)` — verify argument order in shoenig docs before committing; if uncertain, use `test.True(t, a > b)` | +| `assert.Less(t, a, b)` | same caveat — use `test.True(t, a < b)` if uncertain | + +If you encounter a testify assertion not listed here, do ONE of: +1. Use `test.True(t, condition)` with the condition inlined. +2. Flag it in the report and leave the testify call untouched so the user can decide. + +Do not guess a name you're not confident about. + +## Import ordering (gci with --custom-order) + +After editing, the final import block must satisfy gci's custom-order rules: + +1. Standard library +2. `prefix(github.com/verygoodsoftwarenotvirus/platform)` — this module +3. `prefix(github.com/verygoodsoftwarenotvirus)` — org-level (usually empty for this repo) +4. default — third-party + +shoenig/test imports go in section 4 (third-party), separated from the platform imports by a blank line. Example: + +```go +import ( + "context" + "testing" + + "github.com/verygoodsoftwarenotvirus/platform/v5/some/package" + + "github.com/shoenig/test" + "github.com/shoenig/test/must" +) +``` + +The `Edit` tool doesn't auto-reorder imports. If you're moving imports between sections, do the edit explicitly and then run the verification step below — it includes `gci diff` which will catch ordering mistakes. + +## Verification + +Every check must pass before you declare success. Run them in this order: + +```bash +go build .//... 2>&1 +go vet .//... 2>&1 +CGO_ENABLED=1 go test -race -vet=all -shuffle=on .//... 2>&1 +gofmt -l / +go tool github.com/daixiang0/gci diff --skip-generated --custom-order \ + --section standard \ + --section "prefix(github.com/verygoodsoftwarenotvirus/platform)" \ + --section "prefix(github.com/verygoodsoftwarenotvirus)" \ + --section default \ + / +``` + +If Docker is available and you have time, also lint: + +```bash +docker run --rm --volume "$PWD:$PWD" --workdir="$PWD" \ + golangci/golangci-lint:v2.10.1 \ + golangci-lint run --timeout 5m .//... +``` + +Must report `0 issues.` If anything fails, **fix it before reporting success**. Do not use `//nolint` comments as a workaround. + +## Pitfalls that burned the pilot + +- **`SliceLen`/`MapLen` argument order is flipped from testify.** Always `(t, n, collection)`. After the rewrite pass, re-scan the file for these calls and double-check. +- **`Eq` vs `EqOp`.** `EqOp` on a slice/map is a compile error. Switch to `Eq` and try again. +- **`github.com/shoenig/test/test` does not exist.** The non-fatal package is the module root. +- **LSP diagnostics lag.** If the IDE reports errors that don't match your recent edits, trust `go build` instead — if build exits 0, the LSP is stale. +- **Don't touch `T` / `t` conventions.** Top-level test functions use `T *testing.T`; subtests use `t *testing.T`. Never rename. + +## Report format + +Post a concise summary (under 30 lines): + +1. **Files modified** — list the `*_test.go` files touched. +2. **Counts** — approximate: `assert.* → test.*: N`, `require.* → must.*: M`. +3. **Any tricky calls** — list any assertions where you picked `Eq` over `EqOp` for a non-obvious reason, any `Panics`/`Greater`/`Less` calls you left alone, any calls you had to replace with stdlib. +4. **What WASN'T touched** — confirm mocks were left alone (testify mock imports still present if any were before). +5. **Verification results** — ✓ on each check in the list, or what failed. +6. **Git status summary** — short list of modified files (no diff). + +Keep it tight. The user will invoke this across many packages and wants each report readable at a glance. + +## Edge cases + +- **Target has no testify assertions** — report "nothing to do, package already converted (or never used testify)" and exit. +- **Target has testify assertions AND testify mocks** — convert assertions only, leave mocks. The file becomes a hybrid. Explicitly mention this in the report so the user knows `/convert-mocks` is still needed. +- **Target has a shared helper file** (e.g., `helpers_test.go`) that uses testify assertions and is imported by tests in sub-packages — convert it, but flag in the report that it's a shared helper so the user can check sub-package tests compile. +- **A call site is ambiguous about whether `x` is a slice or a map** — read more context (declaration, surrounding ops like `append` or `[key]` indexing) to determine. If still unclear, flag it in the report with the file/line and leave the call untouched. diff --git a/.claude/skills/convert-mocks/SKILL.md b/.claude/skills/convert-mocks/SKILL.md new file mode 100644 index 0000000..8ada821 --- /dev/null +++ b/.claude/skills/convert-mocks/SKILL.md @@ -0,0 +1,285 @@ +--- +name: convert-mocks +description: > + Convert testify mock usage to matryer/moq in a single Go package of this + repo. Generates new moq files in shared mock/ packages as needed (additively + — never deletes testify types). Paired with /convert-assertions; either can + run first. Invoke as /convert-mocks , e.g. + /convert-mocks ./ratelimiting/redis/. +--- + +# convert-mocks + +Migrate a single Go package's tests from `stretchr/testify/mock` (and hand-written testify-based mock packages) to `github.com/matryer/moq`. This is Phase 2 of the shoenig/moq migration, executed one package at a time. Assertions are out of scope — see `/convert-assertions` for that half. Either skill can run before the other; the file becomes a hybrid during a partial migration and that's fine. + +## Argument + +The user invokes this with a package path relative to the repo root (e.g. `./ratelimiting/redis/`). That's the target. Work only inside it, with one permitted exception: you may ADD (never delete or modify) files inside a shared `/mock/` sub-package to generate a moq mock the target depends on. + +## Hard rules — do not violate + +1. **Do NOT delete testify mock types from shared mock packages.** ~30+ other consumers may still rely on them. The testify types die in the final PR of the migration, not in this one. Every `/mock/` package during the transition holds BOTH the hand-written testify types and the moq-generated types side by side. +2. **Do NOT modify existing hand-written testify mock files** (e.g., `circuitbreaking/mock/mock.go`, `observability/metrics/mock/provider.go`). Only ADD new `*_mock.go` files or update `doc.go` to include a new `//go:generate` directive. +3. **Do NOT introduce `mock2` naming.** Phase 0 of the migration killed that experiment. moq files live directly in the canonical `mock/` packages alongside testify files, with distinct type names (`CircuitBreakerMock` vs `MockCircuitBreaker`, `ProviderMock` vs `MetricsProvider`). +4. **Do NOT touch assertions.** Leave every `assert.*`, `require.*`, `test.*`, and `must.*` call exactly as you found it. If you need to add new assertions (e.g., verifying call counts after dropping `AssertExpectationsForObjects`), write them in shoenig form (`test.SliceLen(t, n, mock.XCalls())`) — shoenig is the repo's target state. If the file doesn't already import `github.com/shoenig/test`, add the import. +5. **Do NOT export internal test-seam interfaces.** If the target has an unexported interface like `redisClient` used only for mock injection in tests, keep its moq mock INLINE in the target package (a `*_test.go` file with `//go:generate` + `-skip-ensure` + alias syntax). Don't move it into a sibling `mock/` package. +6. **Do NOT move or split tests.** File layout cleanup is a separate concern. +7. **Do NOT compile or leave binaries.** Use `go vet` / `go test` / `go build` without `-o`. If you do produce a binary for any reason, delete it immediately. +8. **Do NOT add empty method stubs** to satisfy moq generation. See `feedback_no_empty_methods.md` in user memory. If moq generates `XxxFunc` fields the test doesn't need, leave them nil — calling a nil `XxxFunc` panics, which is the correct failure mode for "this method must not be called." +9. **One package per invocation.** Never convert multiple packages in one run. + +## Preflight + +Before editing anything: + +1. Read `CLAUDE.md` at the repo root if you haven't this session. +2. Read `project_moq_shoenig_pilot.md` in user memory — critical for the naming quirks (e.g., `mockmetrics` package name being non-standard). +3. `ls /` to know what test files exist. +4. `Grep` for `stretchr/testify/mock`, `testify/mock`, `.On(`, `.Return(`, `mock.Anything`, `AssertExpectations`, and any `.Mock` struct-literal pattern within the target to enumerate what needs rewriting. +5. `Grep` for imports of shared mock packages (e.g., `circuitbreaking/mock`, `observability/metrics/mock`, `encoding/mock`, `messagequeue/mock`, `routing/mock`, `uploads/mock`, `uploads/images/mock`) to know which shared mocks the target relies on. +6. Baseline: `go build .//... && CGO_ENABLED=1 go test -race -vet=all -shuffle=on .//... 2>&1`. Must pass before any edit. If broken, stop and tell the user. + +If no testify mock usage surfaces in preflight, report "nothing to do" and exit. + +## Step 1: inventory the mocks the target uses + +Categorize every mock the target references into one of four buckets. The handling differs per bucket. + +**Bucket A: moq version already exists in the shared mock package.** No generation needed; just flip call sites. + +Known-good as of the pilot: + +| Shared package | Testify type | Moq type | Import | +|---|---|---|---| +| `circuitbreaking/mock` | `MockCircuitBreaker` | `CircuitBreakerMock` | `mockcircuitbreaking "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking/mock"` (alias required — package name is `mock`) | +| `observability/metrics/mock` | `MetricsProvider` | `ProviderMock` | `"github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock"` (no alias — package name is already `mockmetrics`) | + +To confirm a shared mock package has a moq version: look for a `*_mock.go` file (not `_test.go`) in the package, or grep for `//go:generate go tool github.com/matryer/moq` in its `doc.go`. + +**Bucket B: moq version does not exist yet in the shared mock package, must be generated.** This is the "additive" workflow: + +1. Read the interface declaration in the parent package (e.g., read `encoding/encoding.go` to find the interface definition for a mock needed from `encoding/mock`). +2. Open the shared mock package's `doc.go` — create one if it doesn't exist. +3. Add (or append) a `//go:generate` directive: + ```go + //go:generate go tool github.com/matryer/moq -out _mock.go -pkg -rm -fmt goimports .. :Mock + ``` + - `` must match the existing `package` declaration in the mock package's files. For `circuitbreaking/mock` it's `mock`; for `observability/metrics/mock` it's `mockmetrics`; for others, check. + - The source-dir `..` assumes the interface lives in the parent package of `mock/`. If the interface is elsewhere, adjust the relative path. + - The alias `InterfaceName:InterfaceNameMock` gives the mock a suffix-style name that won't collide with the testify prefix-style names (e.g., `MockCircuitBreaker` vs `CircuitBreakerMock`). +4. Run `go generate .//mock/` to produce the file. +5. Verify the generated file compiles: it should contain a `var _ . = &Mock{}` ensure-line. + +**Bucket C: the target defines its own local testify mock inline** (e.g., a `mockRedisClient struct { mock.Mock }` in a `*_test.go` file). This is the unexported-interface-inline case: + +1. Identify the interface being mocked (in the target's source, not a test file). +2. If the interface is UNEXPORTED, keep the mock inline. Add a `mocks_gen_test.go` file (or update an existing one) with a `//go:generate` directive pointing at the current directory and using `-skip-ensure` + alias if the interface is unexported: + ```go + //go:generate go tool github.com/matryer/moq -out _mock_test.go -pkg -rm -skip-ensure -fmt goimports . :Mock + ``` + Use `-skip-ensure` only for unexported interfaces. For exported interfaces in the same package, drop `-skip-ensure`. +3. Run `go generate .//`. +4. Delete the old hand-written `mockFoo struct { mock.Mock }` definition and any methods on it. + +**Bucket D: the target uses a generic `stretchr/testify/mock.Mock` directly without a wrapper** — rare. Usually a sign someone cut a corner. Treat as Bucket C: identify what it's pretending to be, define a proper interface if needed, generate a moq. + +## Step 2: rewrite call sites + +For each testify mock usage in the target's test files: + +### Struct instantiation + +| Testify | moq | +|---|---| +| `m := &mockcircuitbreaking.MockCircuitBreaker{}` | `m := &mockcircuitbreaking.CircuitBreakerMock{}` | +| `m := &mockmetrics.MetricsProvider{}` | `m := &mockmetrics.ProviderMock{}` | +| `m := &mockRedisClient{}` (local) | `m := &redisClientMock{}` (whatever moq's alias produced) | + +### Method expectations — `.On(...).Return(...)` → `XxxFunc` closures + +The single biggest conceptual shift. Testify sets up expectations at construction time; moq sets a function field that runs whenever the method is called. + +**Simple return value:** +```go +// testify +m.On("CannotProceed").Return(false) + +// moq +m.CannotProceedFunc = func() bool { return false } +``` + +**Void return:** +```go +// testify +m.On("Succeeded").Return() + +// moq +m.SucceededFunc = func() {} +``` + +**Return with computed value based on args:** +```go +// testify (using argument matchers) +client.On("Get", mock.Anything, "expected-key").Return(someResult, nil) + +// moq +client.GetFunc = func(_ context.Context, key string) (ResultType, error) { + test.EqOp(t, "expected-key", key) + return someResult, nil +} +``` + +**Different returns for different inputs — prefer a dispatch map:** +```go +// testify +m.On("NewInt64Counter", "name_hits", /* opts */).Return(okCounter, nil) +m.On("NewInt64Counter", "name_misses", /* opts */).Return(nil, errors.New("fail")) + +// moq — declare the dispatch map, then the closure reads from it +results := map[string]struct{ counter metrics.Int64Counter; err error }{ + "name_hits": {counter: okCounter}, + "name_misses": {counter: nil, err: errors.New("fail")}, +} +m.NewInt64CounterFunc = func(metricName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + res, ok := results[metricName] + if !ok { t.Fatalf("unexpected NewInt64Counter call: %q", metricName) } + return res.counter, res.err +} +``` + +See `cache/redis/redis_test.go`'s `newCounterProviderMock` helper for this pattern in action. + +**Sequential returns for the same args (rare):** +```go +// testify +m.On("X").Return(first).Once() +m.On("X").Return(second).Once() + +// moq — stateful closure with a counter +var calls int +m.XFunc = func() ResultType { + calls++ + if calls == 1 { return first } + return second +} +``` + +### Argument matchers + +| Testify | moq | +|---|---| +| `mock.Anything` | `_` in closure param list (ignored) | +| `mock.MatchedBy(func(x X) bool { return predicate })` | explicit check inside closure body | +| `testutils.ContextMatcher` | `_ context.Context` in closure (ignored) | +| `testutils.QueryFilterMatcher` | `_ *filtering.QueryFilter` in closure (ignored) | +| any literal value (e.g., `"foo"`, `42`) | explicit check inside closure: `test.EqOp(t, "foo", got)` | + +The closure IS the matcher. There is no matcher library. + +### Expectation verification — `AssertExpectations*` and `.Times()` + +| Testify | moq equivalent | +|---|---| +| `mock.AssertExpectationsForObjects(t, m)` | usually drop. The functional check (did the code return the right result?) already proves the mock was invoked. Only add explicit call-count checks if the test was specifically verifying invocation count. | +| `m.AssertCalled(t, "Method", args)` | `test.SliceLen(t, ≥1, m.MethodCalls())` and optionally inspect `m.MethodCalls()[i]` for args | +| `m.AssertNotCalled(t, "Method")` | `test.SliceLen(t, 0, m.MethodCalls())` | +| `.Times(n)` set at expectation time | `test.SliceLen(t, n, m.MethodCalls())` at end of test | +| `.Once()` set at expectation time | `test.SliceLen(t, 1, m.MethodCalls())` at end of test | + +`m.XxxCalls()` returns a slice of typed structs, one element per call, with fields matching the method's parameters (minus variadics which get a slice field). You can index into it to assert per-call arguments if needed. + +If the new call-count assertion requires `github.com/shoenig/test`, add the import. Don't be shy about it — shoenig is the repo's target state even for the mock-migrated half. + +## Step 3: imports + +After editing, the file's imports may need adjustment: + +- **Remove** `"github.com/stretchr/testify/mock"` if there are no more `mock.Anything`, `mock.AssertExpectations*`, or similar references. +- **Remove** the old shared mock import alias (`mockcircuitbreaking "..."`) only if you're flipping every usage in the file; if some references still reach for `MockCircuitBreaker` (e.g., a shared helper in the same file is still testify-bound), keep the import. +- **Add** `"github.com/shoenig/test"` (and/or `/must`) if you introduced new shoenig assertions for call-count verification. + +Import ordering follows `gci --custom-order`: + +1. std +2. `prefix(github.com/verygoodsoftwarenotvirus/platform)` — blank-line separator after +3. `prefix(github.com/verygoodsoftwarenotvirus)` — org-level (usually empty) +4. default (third-party) + +## Verification + +Every check must pass. Run in this order: + +```bash +go build .//... 2>&1 +go vet .//... 2>&1 +CGO_ENABLED=1 go test -race -vet=all -shuffle=on .//... 2>&1 +gofmt -l / +``` + +If you touched a shared mock package (Bucket B) or generated inline moqs (Bucket C), also verify generation is idempotent: + +```bash +make generate 2>&1 +git status --porcelain +``` + +The diff after `make generate` should match what you intended and include NO surprise changes to other packages. Run `make generate` a second time and confirm no new drift. + +gci format check: + +```bash +go tool github.com/daixiang0/gci diff --skip-generated --custom-order \ + --section standard \ + --section "prefix(github.com/verygoodsoftwarenotvirus/platform)" \ + --section "prefix(github.com/verygoodsoftwarenotvirus)" \ + --section default \ + / +``` + +If Docker is available, lint: + +```bash +docker run --rm --volume "$PWD:$PWD" --workdir="$PWD" \ + golangci/golangci-lint:v2.10.1 \ + golangci-lint run --timeout 5m .//... +``` + +If you touched a shared mock package, also lint it: +```bash +docker run ... golangci-lint run --timeout 5m .//mock/... +``` + +Must report `0 issues.` + +## Pitfalls that burned the pilot + +- **`mockmetrics` is the package name; the directory is `mock`.** Import `observability/metrics/mock` WITHOUT an alias; the identifier you use in code is `mockmetrics.ProviderMock`. Adding an alias is redundant and linters may flag it. +- **`circuitbreaking/mock` package name is `mock`.** You need an alias (`mockcircuitbreaking`) to avoid ambiguity when multiple files or packages in scope could shadow the identifier. +- **`-skip-ensure` is only for unexported interfaces.** Do not use it for exported ones — you'd lose the compile-time interface-satisfaction check. +- **moq's generated files trip `fieldalignment`** on anonymous struct returns (the `XxxCalls()` return type). Harmless: `fieldalignment -fix` returns 0, `make format` converges, and `golangci-lint` skips generated files via `exclusions.generated: lax`. Don't try to "fix" the generated file. +- **LSP diagnostics lag.** After deletes or moves, the IDE may report phantom errors. Trust `go build` — if it exits 0, the diagnostics are stale. +- **Don't break tests to make the mock pass.** If moq's nil-XxxFunc-panics behavior is firing during a test run, the test was exercising a code path that the testify version silently ignored. Either add the missing `XxxFunc` to the setup (if the call is legitimate) or understand why it's happening (the code under test may have changed). +- **The new call-count check replaces AssertExpectations semantically, not behaviorally.** testify's AssertExpectations fails if unexpected methods are called; moq's `XCalls()` length check only verifies the count of the one method you checked. If a test relied on "fail loudly if anything unexpected happens", you may need to explicitly set `XxxFunc = func(...) { t.Fatalf("unexpected call to X") }` for every method that should NOT be called. +- **Test file imports may look wrong right after a rewrite.** gci will reorder them when `make format` runs; gofmt will handle whitespace. But if the file doesn't compile, trust the compile error over the LSP. + +## Report format + +Under 40 lines. + +1. **Files modified** — grouped by: target test files touched, shared mock packages extended (if any), inline mocks added (if any). +2. **Generated files added** — list any new `*_mock.go` or `*_mock_test.go` files. +3. **Counts** — approximate: `.On/.Return → XxxFunc: N`, `AssertExpectations dropped: M`, `mock types renamed: K`. +4. **Bucket breakdown** — how many mocks fell into each bucket (A: already existed, B: generated into shared, C: inline generated). +5. **What WASN'T touched** — confirm assertions were left alone (testify assertions may still be present if `/convert-assertions` hasn't run on this package). +6. **Verification results** — ✓ on each check, or what failed. +7. **Manual follow-ups the user should review** — judgment calls on: dropped AssertExpectations that may have been exhaustive checks, dispatch maps where the mapping is non-obvious, any sequential-return translations, any new `t.Fatalf` guards added to catch "should not be called" cases. +8. **Git status summary** — short list of modified + untracked files. + +## Edge cases + +- **Target has no testify mocks** — report "nothing to do" and exit. +- **Target uses a mock from a shared package, and that package has multiple mock types (e.g., `observability/metrics/mock` has `MetricsProvider` AND `Int64Counter`), but the target only uses one** — only flip the one the target uses. Don't migrate the others preemptively; let another consumer's run of this skill pick them up. +- **Target has a shared helper** (e.g., `helpers_test.go`) that takes a `*mockX.MockFoo` parameter and is called from files in sub-packages — this is a widening change. Stop, describe the shared helper and its callers, and ask the user whether to (a) migrate the helper signature + all callers in one PR, (b) add an interface-typed parameter so both old and new mocks satisfy it during transition, or (c) skip this target and pick one without a cross-package helper entanglement. +- **Target has a TestMain or init() that sets up global mocks** — rare; if encountered, flag it and proceed carefully. +- **Target has generic interfaces** (like `cache.Cache[T]`) — moq handles generics. The generated mock will be `CacheMock[T any]`, call sites write `&mock.CacheMock[ConcreteType]{}`. If generation is needed, proceed normally; if call sites need the type parameter added, flag it in the report. +- **Target tests are a mix of real-backend (testcontainers) and mock-backed (testify/moq)** — convert only the mock-backed ones. Container tests don't use mocks. diff --git a/analytics/config/config_test.go b/analytics/config/config_test.go index 0cf5f67..fe45a89 100644 --- a/analytics/config/config_test.go +++ b/analytics/config/config_test.go @@ -12,10 +12,9 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/verygoodsoftwarenotvirus/platform/v5/reflection" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" ) @@ -177,19 +176,18 @@ func TestConfig_ProvideCollector(T *testing.T) { }, } - i64Counter := &mockmetrics.Int64Counter{} - mp := &mockmetrics.MetricsProvider{} - mp.On( - reflection.GetMethodName(mp.NewInt64Counter), - mock.AnythingOfType("string"), - []metric.Int64CounterOption(nil), - ).Return(i64Counter, errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(_ string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.SliceEmpty(t, options) + return nil, errors.New("arbitrary") + }, + } reporter, err := cfg.ProvideCollector(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp) assert.Nil(t, reporter) assert.Error(t, err) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) } diff --git a/analytics/mock/doc.go b/analytics/mock/doc.go new file mode 100644 index 0000000..f7acd5c --- /dev/null +++ b/analytics/mock/doc.go @@ -0,0 +1,8 @@ +/* +Package analyticsmock provides moq-generated mocks for the analytics package. +*/ +package analyticsmock + +// Regenerate the moq mocks via `go generate ./analytics/mock/`. + +//go:generate go tool github.com/matryer/moq -out event_reporter_mock.go -pkg analyticsmock -rm -fmt goimports .. EventReporter:EventReporterMock diff --git a/analytics/mock/event_reporter_mock.go b/analytics/mock/event_reporter_mock.go new file mode 100644 index 0000000..1012b5e --- /dev/null +++ b/analytics/mock/event_reporter_mock.go @@ -0,0 +1,250 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package analyticsmock + +import ( + "context" + "sync" + + "github.com/verygoodsoftwarenotvirus/platform/v5/analytics" +) + +// Ensure, that EventReporterMock does implement analytics.EventReporter. +// If this is not the case, regenerate this file with moq. +var _ analytics.EventReporter = &EventReporterMock{} + +// EventReporterMock is a mock implementation of analytics.EventReporter. +// +// func TestSomethingThatUsesEventReporter(t *testing.T) { +// +// // make and configure a mocked analytics.EventReporter +// mockedEventReporter := &EventReporterMock{ +// AddUserFunc: func(ctx context.Context, userID string, properties map[string]any) error { +// panic("mock out the AddUser method") +// }, +// CloseFunc: func() { +// panic("mock out the Close method") +// }, +// EventOccurredFunc: func(ctx context.Context, event string, userID string, properties map[string]any) error { +// panic("mock out the EventOccurred method") +// }, +// EventOccurredAnonymousFunc: func(ctx context.Context, event string, anonymousID string, properties map[string]any) error { +// panic("mock out the EventOccurredAnonymous method") +// }, +// } +// +// // use mockedEventReporter in code that requires analytics.EventReporter +// // and then make assertions. +// +// } +type EventReporterMock struct { + // AddUserFunc mocks the AddUser method. + AddUserFunc func(ctx context.Context, userID string, properties map[string]any) error + + // CloseFunc mocks the Close method. + CloseFunc func() + + // EventOccurredFunc mocks the EventOccurred method. + EventOccurredFunc func(ctx context.Context, event string, userID string, properties map[string]any) error + + // EventOccurredAnonymousFunc mocks the EventOccurredAnonymous method. + EventOccurredAnonymousFunc func(ctx context.Context, event string, anonymousID string, properties map[string]any) error + + // calls tracks calls to the methods. + calls struct { + // AddUser holds details about calls to the AddUser method. + AddUser []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // UserID is the userID argument value. + UserID string + // Properties is the properties argument value. + Properties map[string]any + } + // Close holds details about calls to the Close method. + Close []struct { + } + // EventOccurred holds details about calls to the EventOccurred method. + EventOccurred []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Event is the event argument value. + Event string + // UserID is the userID argument value. + UserID string + // Properties is the properties argument value. + Properties map[string]any + } + // EventOccurredAnonymous holds details about calls to the EventOccurredAnonymous method. + EventOccurredAnonymous []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Event is the event argument value. + Event string + // AnonymousID is the anonymousID argument value. + AnonymousID string + // Properties is the properties argument value. + Properties map[string]any + } + } + lockAddUser sync.RWMutex + lockClose sync.RWMutex + lockEventOccurred sync.RWMutex + lockEventOccurredAnonymous sync.RWMutex +} + +// AddUser calls AddUserFunc. +func (mock *EventReporterMock) AddUser(ctx context.Context, userID string, properties map[string]any) error { + if mock.AddUserFunc == nil { + panic("EventReporterMock.AddUserFunc: method is nil but EventReporter.AddUser was just called") + } + callInfo := struct { + Ctx context.Context + UserID string + Properties map[string]any + }{ + Ctx: ctx, + UserID: userID, + Properties: properties, + } + mock.lockAddUser.Lock() + mock.calls.AddUser = append(mock.calls.AddUser, callInfo) + mock.lockAddUser.Unlock() + return mock.AddUserFunc(ctx, userID, properties) +} + +// AddUserCalls gets all the calls that were made to AddUser. +// Check the length with: +// +// len(mockedEventReporter.AddUserCalls()) +func (mock *EventReporterMock) AddUserCalls() []struct { + Ctx context.Context + UserID string + Properties map[string]any +} { + var calls []struct { + Ctx context.Context + UserID string + Properties map[string]any + } + mock.lockAddUser.RLock() + calls = mock.calls.AddUser + mock.lockAddUser.RUnlock() + return calls +} + +// Close calls CloseFunc. +func (mock *EventReporterMock) Close() { + if mock.CloseFunc == nil { + panic("EventReporterMock.CloseFunc: method is nil but EventReporter.Close was just called") + } + callInfo := struct { + }{} + mock.lockClose.Lock() + mock.calls.Close = append(mock.calls.Close, callInfo) + mock.lockClose.Unlock() + mock.CloseFunc() +} + +// CloseCalls gets all the calls that were made to Close. +// Check the length with: +// +// len(mockedEventReporter.CloseCalls()) +func (mock *EventReporterMock) CloseCalls() []struct { +} { + var calls []struct { + } + mock.lockClose.RLock() + calls = mock.calls.Close + mock.lockClose.RUnlock() + return calls +} + +// EventOccurred calls EventOccurredFunc. +func (mock *EventReporterMock) EventOccurred(ctx context.Context, event string, userID string, properties map[string]any) error { + if mock.EventOccurredFunc == nil { + panic("EventReporterMock.EventOccurredFunc: method is nil but EventReporter.EventOccurred was just called") + } + callInfo := struct { + Ctx context.Context + Event string + UserID string + Properties map[string]any + }{ + Ctx: ctx, + Event: event, + UserID: userID, + Properties: properties, + } + mock.lockEventOccurred.Lock() + mock.calls.EventOccurred = append(mock.calls.EventOccurred, callInfo) + mock.lockEventOccurred.Unlock() + return mock.EventOccurredFunc(ctx, event, userID, properties) +} + +// EventOccurredCalls gets all the calls that were made to EventOccurred. +// Check the length with: +// +// len(mockedEventReporter.EventOccurredCalls()) +func (mock *EventReporterMock) EventOccurredCalls() []struct { + Ctx context.Context + Event string + UserID string + Properties map[string]any +} { + var calls []struct { + Ctx context.Context + Event string + UserID string + Properties map[string]any + } + mock.lockEventOccurred.RLock() + calls = mock.calls.EventOccurred + mock.lockEventOccurred.RUnlock() + return calls +} + +// EventOccurredAnonymous calls EventOccurredAnonymousFunc. +func (mock *EventReporterMock) EventOccurredAnonymous(ctx context.Context, event string, anonymousID string, properties map[string]any) error { + if mock.EventOccurredAnonymousFunc == nil { + panic("EventReporterMock.EventOccurredAnonymousFunc: method is nil but EventReporter.EventOccurredAnonymous was just called") + } + callInfo := struct { + Ctx context.Context + Event string + AnonymousID string + Properties map[string]any + }{ + Ctx: ctx, + Event: event, + AnonymousID: anonymousID, + Properties: properties, + } + mock.lockEventOccurredAnonymous.Lock() + mock.calls.EventOccurredAnonymous = append(mock.calls.EventOccurredAnonymous, callInfo) + mock.lockEventOccurredAnonymous.Unlock() + return mock.EventOccurredAnonymousFunc(ctx, event, anonymousID, properties) +} + +// EventOccurredAnonymousCalls gets all the calls that were made to EventOccurredAnonymous. +// Check the length with: +// +// len(mockedEventReporter.EventOccurredAnonymousCalls()) +func (mock *EventReporterMock) EventOccurredAnonymousCalls() []struct { + Ctx context.Context + Event string + AnonymousID string + Properties map[string]any +} { + var calls []struct { + Ctx context.Context + Event string + AnonymousID string + Properties map[string]any + } + mock.lockEventOccurredAnonymous.RLock() + calls = mock.calls.EventOccurredAnonymous + mock.lockEventOccurredAnonymous.RUnlock() + return calls +} diff --git a/analytics/mock/mock.go b/analytics/mock/mock.go deleted file mode 100644 index 112be41..0000000 --- a/analytics/mock/mock.go +++ /dev/null @@ -1,38 +0,0 @@ -package analyticsmock - -import ( - "context" - - "github.com/verygoodsoftwarenotvirus/platform/v5/analytics" - - "github.com/stretchr/testify/mock" -) - -var _ analytics.EventReporter = (*EventReporter)(nil) - -type ( - // EventReporter represents a service that can collect customer data. - EventReporter struct { - mock.Mock - } -) - -// Close implements the EventReporter interface. -func (m *EventReporter) Close() { - m.Called() -} - -// AddUser implements the EventReporter interface. -func (m *EventReporter) AddUser(ctx context.Context, userID string, properties map[string]any) error { - return m.Called(ctx, userID, properties).Error(0) -} - -// EventOccurred implements the EventReporter interface. -func (m *EventReporter) EventOccurred(ctx context.Context, event, userID string, properties map[string]any) error { - return m.Called(ctx, event, userID, properties).Error(0) -} - -// EventOccurredAnonymous implements the EventReporter interface. -func (m *EventReporter) EventOccurredAnonymous(ctx context.Context, event, anonymousID string, properties map[string]any) error { - return m.Called(ctx, event, anonymousID, properties).Error(0) -} diff --git a/analytics/multisource/reporter_test.go b/analytics/multisource/reporter_test.go index faa5f2a..28acfe2 100644 --- a/analytics/multisource/reporter_test.go +++ b/analytics/multisource/reporter_test.go @@ -8,8 +8,8 @@ import ( analyticsmock "github.com/verygoodsoftwarenotvirus/platform/v5/analytics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/analytics/noop" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -80,10 +80,15 @@ func TestMultiSourceEventReporter_TrackEvent(T *testing.T) { T.Run("delegates to correct reporter", func(t *testing.T) { t.Parallel() - mockReporter := &analyticsmock.EventReporter{} - mockReporter.On("EventOccurred", mock.AnythingOfType("*context.valueCtx"), "signup", "user1", mock.MatchedBy(func(props map[string]any) bool { - return props[SourcePropertyKey] == "ios" && props["plan"] == "pro" - })).Return(nil) + mockReporter := &analyticsmock.EventReporterMock{ + EventOccurredFunc: func(_ context.Context, event, userID string, properties map[string]any) error { + assert.Equal(t, "signup", event) + assert.Equal(t, "user1", userID) + assert.Equal(t, "ios", properties[SourcePropertyKey]) + assert.Equal(t, "pro", properties["plan"]) + return nil + }, + } reporters := map[string]analytics.EventReporter{ "ios": mockReporter, @@ -93,7 +98,7 @@ func TestMultiSourceEventReporter_TrackEvent(T *testing.T) { err := m.TrackEvent(context.Background(), "ios", "signup", "user1", map[string]any{"plan": "pro"}) assert.NoError(t, err) - mock.AssertExpectationsForObjects(t, mockReporter) + test.SliceLen(t, 1, mockReporter.EventOccurredCalls()) }) T.Run("uses noop for unknown source", func(t *testing.T) { @@ -112,10 +117,14 @@ func TestMultiSourceEventReporter_TrackAnonymousEvent(T *testing.T) { T.Run("delegates to correct reporter", func(t *testing.T) { t.Parallel() - mockReporter := &analyticsmock.EventReporter{} - mockReporter.On("EventOccurredAnonymous", mock.AnythingOfType("*context.valueCtx"), "page_view", "anon1", mock.MatchedBy(func(props map[string]any) bool { - return props[SourcePropertyKey] == "web" - })).Return(nil) + mockReporter := &analyticsmock.EventReporterMock{ + EventOccurredAnonymousFunc: func(_ context.Context, event, anonymousID string, properties map[string]any) error { + assert.Equal(t, "page_view", event) + assert.Equal(t, "anon1", anonymousID) + assert.Equal(t, "web", properties[SourcePropertyKey]) + return nil + }, + } reporters := map[string]analytics.EventReporter{ "web": mockReporter, @@ -125,7 +134,7 @@ func TestMultiSourceEventReporter_TrackAnonymousEvent(T *testing.T) { err := m.TrackAnonymousEvent(context.Background(), "web", "page_view", "anon1", map[string]any{}) assert.NoError(t, err) - mock.AssertExpectationsForObjects(t, mockReporter) + test.SliceLen(t, 1, mockReporter.EventOccurredAnonymousCalls()) }) } diff --git a/analytics/posthog/posthog_test.go b/analytics/posthog/posthog_test.go index 2b7e628..85c28f8 100644 --- a/analytics/posthog/posthog_test.go +++ b/analytics/posthog/posthog_test.go @@ -11,7 +11,7 @@ import ( mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/mock" + "github.com/shoenig/test" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" ) @@ -44,28 +44,41 @@ func TestNewPostHogEventReporter(T *testing.T) { T.Run("with error creating event counter", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_events", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, name+"_events", counterName) + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + }, + } collector, err := NewPostHogEventReporter(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, t.Name(), cbnoop.NewCircuitBreaker()) require.Error(t, err) require.Nil(t, collector) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("with error creating error counter", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_events", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_errors", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + switch counterName { + case name + "_events": + return metrics.Int64CounterForTest(t, "x"), nil + case name + "_errors": + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + } + t.Fatalf("unexpected NewInt64Counter call: %q", counterName) + return nil, nil + }, + } collector, err := NewPostHogEventReporter(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, t.Name(), cbnoop.NewCircuitBreaker()) require.Error(t, err) require.Nil(t, collector) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) } diff --git a/analytics/rudderstack/rudderstack_test.go b/analytics/rudderstack/rudderstack_test.go index d19c119..5436fc1 100644 --- a/analytics/rudderstack/rudderstack_test.go +++ b/analytics/rudderstack/rudderstack_test.go @@ -11,7 +11,7 @@ import ( mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/mock" + "github.com/shoenig/test" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" ) @@ -79,14 +79,18 @@ func TestNewRudderstackEventReporter(T *testing.T) { DataPlaneURL: t.Name(), } - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_events", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, name+"_events", counterName) + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + }, + } collector, err := NewRudderstackEventReporter(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, cfg, cbnoop.NewCircuitBreaker()) require.Error(t, err) require.Nil(t, collector) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("with error creating error counter", func(t *testing.T) { @@ -97,15 +101,24 @@ func TestNewRudderstackEventReporter(T *testing.T) { DataPlaneURL: t.Name(), } - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_events", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_errors", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + switch counterName { + case name + "_events": + return metrics.Int64CounterForTest(t, "x"), nil + case name + "_errors": + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + } + t.Fatalf("unexpected NewInt64Counter call: %q", counterName) + return nil, nil + }, + } collector, err := NewRudderstackEventReporter(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, cfg, cbnoop.NewCircuitBreaker()) require.Error(t, err) require.Nil(t, collector) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) } diff --git a/analytics/segment/segment_test.go b/analytics/segment/segment_test.go index f45e8bc..b03d21a 100644 --- a/analytics/segment/segment_test.go +++ b/analytics/segment/segment_test.go @@ -11,7 +11,7 @@ import ( mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/mock" + "github.com/shoenig/test" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" ) @@ -42,28 +42,41 @@ func TestNewSegmentEventReporter(T *testing.T) { T.Run("with error creating event counter", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_events", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, name+"_events", counterName) + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + }, + } collector, err := NewSegmentEventReporter(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, t.Name(), cbnoop.NewCircuitBreaker()) require.Error(t, err) require.Nil(t, collector) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("with error creating error counter", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_events", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_errors", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + switch counterName { + case name + "_events": + return metrics.Int64CounterForTest(t, "x"), nil + case name + "_errors": + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + } + t.Fatalf("unexpected NewInt64Counter call: %q", counterName) + return nil, nil + }, + } collector, err := NewSegmentEventReporter(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, t.Name(), cbnoop.NewCircuitBreaker()) require.Error(t, err) require.Nil(t, collector) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) } diff --git a/cache/config/config_test.go b/cache/config/config_test.go index 4ef6685..98d9590 100644 --- a/cache/config/config_test.go +++ b/cache/config/config_test.go @@ -8,7 +8,7 @@ import ( circuitbreakingcfg "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking/config" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" - metricsmock2 "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock2" + mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/shoenig/test" @@ -129,7 +129,7 @@ func TestProvideCache(T *testing.T) { }, } - mp := &metricsmock2.ProviderMock{ + mp := &mockmetrics.ProviderMock{ NewInt64CounterFunc: func(name string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { test.EqOp(t, "redis-cache-breaker_circuit_breaker_tripped", name) return nil, errors.New("counter init failure") diff --git a/cache/redis/mocks_gen_test.go b/cache/redis/mocks_gen_test.go index 8748962..51ff344 100644 --- a/cache/redis/mocks_gen_test.go +++ b/cache/redis/mocks_gen_test.go @@ -4,7 +4,7 @@ package redis // redisClient interface is unexported (it's a test seam), so its mock lives // in-package as a *_test.go file rather than under a sibling mock package. // Mocks for the external interfaces (metrics.Provider, circuitbreaking.CircuitBreaker) -// live alongside those interfaces in observability/metrics/mock2 and -// circuitbreaking/mock2. +// live alongside those interfaces in observability/metrics/mock and +// circuitbreaking/mock. //go:generate go tool github.com/matryer/moq -out redisclient_mock_test.go -pkg redis -rm -fmt goimports . redisClient diff --git a/cache/redis/redis_test.go b/cache/redis/redis_test.go index c2eda40..d06f6b8 100644 --- a/cache/redis/redis_test.go +++ b/cache/redis/redis_test.go @@ -10,10 +10,10 @@ import ( "time" "github.com/verygoodsoftwarenotvirus/platform/v5/cache" - circuitbreakingmock2 "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking/mock2" + mockcircuitbreaking "github.com/verygoodsoftwarenotvirus/platform/v5/circuitbreaking/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" - metricsmock2 "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock2" + mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/go-redis/redis/v8" @@ -41,7 +41,7 @@ func gobEncodeExample(t *testing.T, e *example) string { return buf.String() } -func buildTestImpl(t *testing.T) (*redisCacheImpl[example], *redisClientMock, *circuitbreakingmock2.CircuitBreakerMock) { +func buildTestImpl(t *testing.T) (*redisCacheImpl[example], *redisClientMock, *mockcircuitbreaking.CircuitBreakerMock) { t.Helper() mp := metrics.NewNoopMetricsProvider() @@ -65,7 +65,7 @@ func buildTestImpl(t *testing.T) (*redisCacheImpl[example], *redisClientMock, *c must.NoError(t, err) client := &redisClientMock{} - cb := &circuitbreakingmock2.CircuitBreakerMock{} + cb := &mockcircuitbreaking.CircuitBreakerMock{} return &redisCacheImpl[example]{ logger: logging.NewNoopLogger(), @@ -91,9 +91,9 @@ type counterResult struct { // newCounterProviderMock returns a metrics.Provider mock whose NewInt64Counter // implementation looks up the result keyed on the counter name. Unknown names // fail the test. -func newCounterProviderMock(t *testing.T, results map[string]counterResult) *metricsmock2.ProviderMock { +func newCounterProviderMock(t *testing.T, results map[string]counterResult) *mockmetrics.ProviderMock { t.Helper() - return &metricsmock2.ProviderMock{ + return &mockmetrics.ProviderMock{ NewInt64CounterFunc: func(metricName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { res, ok := results[metricName] if !ok { @@ -256,7 +256,7 @@ func TestNewRedisCache(T *testing.T) { h, histErr := noopMP.NewFloat64Histogram("test") must.NoError(t, histErr) - mp := &metricsmock2.ProviderMock{ + mp := &mockmetrics.ProviderMock{ NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { return metrics.Int64CounterForTest(t, "x"), nil }, diff --git a/capitalism/mock/doc.go b/capitalism/mock/doc.go new file mode 100644 index 0000000..2446205 --- /dev/null +++ b/capitalism/mock/doc.go @@ -0,0 +1,9 @@ +// Package capitalismmock provides mock implementations of the capitalism package's +// interfaces. Both the hand-written testify-based MockPaymentManager and the +// moq-generated PaymentManagerMock live here during the testify → moq migration. +// New test code should prefer PaymentManagerMock. +package capitalismmock + +// Regenerate the moq mocks via `go generate ./capitalism/mock/`. + +//go:generate go tool github.com/matryer/moq -out payment_manager_mock.go -pkg capitalismmock -rm -fmt goimports .. PaymentManager:PaymentManagerMock diff --git a/capitalism/mock/mock_payment_manager.go b/capitalism/mock/mock_payment_manager.go deleted file mode 100644 index e1e4a11..0000000 --- a/capitalism/mock/mock_payment_manager.go +++ /dev/null @@ -1,26 +0,0 @@ -package capitalismmock - -import ( - "net/http" - - "github.com/verygoodsoftwarenotvirus/platform/v5/capitalism" - - "github.com/stretchr/testify/mock" -) - -var _ capitalism.PaymentManager = (*MockPaymentManager)(nil) - -// MockPaymentManager is a mockable capitalism.PaymentManager. -type MockPaymentManager struct { - mock.Mock -} - -// NewMockPaymentManager returns a mockable capitalism.PaymentManager. -func NewMockPaymentManager() *MockPaymentManager { - return &MockPaymentManager{} -} - -// HandleEventWebhook satisfies our interface contract. -func (m *MockPaymentManager) HandleEventWebhook(req *http.Request) error { - return m.Called(req).Error(0) -} diff --git a/capitalism/mock/payment_manager_mock.go b/capitalism/mock/payment_manager_mock.go new file mode 100644 index 0000000..81a23e3 --- /dev/null +++ b/capitalism/mock/payment_manager_mock.go @@ -0,0 +1,77 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package capitalismmock + +import ( + "net/http" + "sync" + + "github.com/verygoodsoftwarenotvirus/platform/v5/capitalism" +) + +// Ensure, that PaymentManagerMock does implement capitalism.PaymentManager. +// If this is not the case, regenerate this file with moq. +var _ capitalism.PaymentManager = &PaymentManagerMock{} + +// PaymentManagerMock is a mock implementation of capitalism.PaymentManager. +// +// func TestSomethingThatUsesPaymentManager(t *testing.T) { +// +// // make and configure a mocked capitalism.PaymentManager +// mockedPaymentManager := &PaymentManagerMock{ +// HandleEventWebhookFunc: func(req *http.Request) error { +// panic("mock out the HandleEventWebhook method") +// }, +// } +// +// // use mockedPaymentManager in code that requires capitalism.PaymentManager +// // and then make assertions. +// +// } +type PaymentManagerMock struct { + // HandleEventWebhookFunc mocks the HandleEventWebhook method. + HandleEventWebhookFunc func(req *http.Request) error + + // calls tracks calls to the methods. + calls struct { + // HandleEventWebhook holds details about calls to the HandleEventWebhook method. + HandleEventWebhook []struct { + // Req is the req argument value. + Req *http.Request + } + } + lockHandleEventWebhook sync.RWMutex +} + +// HandleEventWebhook calls HandleEventWebhookFunc. +func (mock *PaymentManagerMock) HandleEventWebhook(req *http.Request) error { + if mock.HandleEventWebhookFunc == nil { + panic("PaymentManagerMock.HandleEventWebhookFunc: method is nil but PaymentManager.HandleEventWebhook was just called") + } + callInfo := struct { + Req *http.Request + }{ + Req: req, + } + mock.lockHandleEventWebhook.Lock() + mock.calls.HandleEventWebhook = append(mock.calls.HandleEventWebhook, callInfo) + mock.lockHandleEventWebhook.Unlock() + return mock.HandleEventWebhookFunc(req) +} + +// HandleEventWebhookCalls gets all the calls that were made to HandleEventWebhook. +// Check the length with: +// +// len(mockedPaymentManager.HandleEventWebhookCalls()) +func (mock *PaymentManagerMock) HandleEventWebhookCalls() []struct { + Req *http.Request +} { + var calls []struct { + Req *http.Request + } + mock.lockHandleEventWebhook.RLock() + calls = mock.calls.HandleEventWebhook + mock.lockHandleEventWebhook.RUnlock() + return calls +} diff --git a/capitalism/stripe/mock_backend_test.go b/capitalism/stripe/mock_backend_test.go deleted file mode 100644 index 9fb9ed6..0000000 --- a/capitalism/stripe/mock_backend_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package stripe - -import ( - "bytes" - "encoding/json" - "testing" - - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - "github.com/stripe/stripe-go/v75" - "github.com/stripe/stripe-go/v75/form" -) - -var _ stripe.Backend = (*mockBackend)(nil) - -type mockBackend struct { - mock.Mock - - anticipatedReturns [][]byte -} - -func (m *mockBackend) AnticipateCall(t *testing.T, v any) { - t.Helper() - - b, err := json.Marshal(v) - require.NoError(t, err) - - m.anticipatedReturns = append(m.anticipatedReturns, b) -} - -func (m *mockBackend) Call(method, path, key string, params stripe.ParamsContainer, v stripe.LastResponseSetter) error { - b := m.anticipatedReturns[0] - m.anticipatedReturns = append(m.anticipatedReturns[:0], m.anticipatedReturns[1:]...) - - if err := json.Unmarshal(b, v); err != nil { - panic(err) - } - - return m.Called(method, path, key, params, v).Error(0) -} - -func (m *mockBackend) CallRaw(method, path, key string, body *form.Values, params *stripe.Params, v stripe.LastResponseSetter) error { - b := m.anticipatedReturns[0] - m.anticipatedReturns = append(m.anticipatedReturns[:0], m.anticipatedReturns[1:]...) - - if err := json.Unmarshal(b, v); err != nil { - panic(err) - } - - return m.Called(method, path, key, body, params, v).Error(0) -} - -func (m *mockBackend) CallStreaming(method, path, key string, params stripe.ParamsContainer, v stripe.StreamingLastResponseSetter) error { - b := m.anticipatedReturns[0] - m.anticipatedReturns = append(m.anticipatedReturns[:0], m.anticipatedReturns[1:]...) - - if err := json.Unmarshal(b, v); err != nil { - panic(err) - } - - return m.Called(method, path, key, params, v).Error(0) -} - -func (m *mockBackend) CallMultipart(method, path, key, boundary string, body *bytes.Buffer, params *stripe.Params, v stripe.LastResponseSetter) error { - b := m.anticipatedReturns[0] - m.anticipatedReturns = append(m.anticipatedReturns[:0], m.anticipatedReturns[1:]...) - - if err := json.Unmarshal(b, v); err != nil { - panic(err) - } - - return m.Called(method, path, key, boundary, body, params, v).Error(0) -} - -func (m *mockBackend) SetMaxNetworkRetries(maxNetworkRetries int64) { - m.Called(maxNetworkRetries) -} diff --git a/capitalism/stripe/stripe_test.go b/capitalism/stripe/stripe_test.go index 977358a..8701ff8 100644 --- a/capitalism/stripe/stripe_test.go +++ b/capitalism/stripe/stripe_test.go @@ -2,6 +2,7 @@ package stripe import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -13,8 +14,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/verygoodsoftwarenotvirus/platform/v5/random" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stripe/stripe-go/v75" "github.com/stripe/stripe-go/v75/webhook" @@ -110,8 +111,11 @@ func Test_stripePaymentManager_HandleSubscriptionEventWebhook(T *testing.T) { require.NoError(t, err) eventPayload := pm.encoderDecoder.MustEncode(ctx, event) - encoderDecoder := mockencoding.NewMockEncoderDecoder() - encoderDecoder.On("DecodeBytes", mock.Anything, mock.Anything, mock.Anything).Return(nil) + encoderDecoder := &mockencoding.ServerEncoderDecoderMock{ + DecodeBytesFunc: func(_ context.Context, _ []byte, _ any) error { + return nil + }, + } pm.encoderDecoder = encoderDecoder req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://whatever.whocares.gov", bytes.NewReader(eventPayload)) @@ -122,7 +126,7 @@ func Test_stripePaymentManager_HandleSubscriptionEventWebhook(T *testing.T) { err = pm.HandleEventWebhook(req) assert.NoError(t, err) - mock.AssertExpectationsForObjects(t, encoderDecoder) + test.SliceLen(t, 1, encoderDecoder.DecodeBytesCalls()) }) T.Run("with error reading body", func(t *testing.T) { @@ -191,8 +195,11 @@ func Test_stripePaymentManager_HandleSubscriptionEventWebhook(T *testing.T) { require.NoError(t, err) eventPayload := pm.encoderDecoder.MustEncode(ctx, event) - encoderDecoder := mockencoding.NewMockEncoderDecoder() - encoderDecoder.On("DecodeBytes", mock.Anything, mock.Anything, mock.Anything).Return(fmt.Errorf("decode error")) + encoderDecoder := &mockencoding.ServerEncoderDecoderMock{ + DecodeBytesFunc: func(_ context.Context, _ []byte, _ any) error { + return fmt.Errorf("decode error") + }, + } pm.encoderDecoder = encoderDecoder req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://whatever.whocares.gov", bytes.NewReader(eventPayload)) @@ -203,7 +210,7 @@ func Test_stripePaymentManager_HandleSubscriptionEventWebhook(T *testing.T) { err = pm.HandleEventWebhook(req) assert.Error(t, err) - mock.AssertExpectationsForObjects(t, encoderDecoder) + test.SliceLen(t, 1, encoderDecoder.DecodeBytesCalls()) }) T.Run("with unhandled event type", func(t *testing.T) { diff --git a/circuitbreaking/config/config_test.go b/circuitbreaking/config/config_test.go index 87c227f..40581ed 100644 --- a/circuitbreaking/config/config_test.go +++ b/circuitbreaking/config/config_test.go @@ -1,6 +1,7 @@ package circuitbreakingcfg import ( + "context" "errors" "fmt" "testing" @@ -10,11 +11,10 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" - "github.com/verygoodsoftwarenotvirus/platform/v5/reflection" circuit "github.com/rubyist/circuitbreaker" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "go.opentelemetry.io/otel/metric" ) @@ -110,16 +110,19 @@ func TestProvideCircuitBreakerFromConfig(T *testing.T) { cfg.EnsureDefaults() ctx := t.Context() - i64Counter := &mockmetrics.Int64Counter{} - mp := &mockmetrics.MetricsProvider{} - mp.On(reflection.GetMethodName(mp.NewInt64Counter), fmt.Sprintf("%s_circuit_breaker_tripped", cfg.Name), []metric.Int64CounterOption(nil)).Return(i64Counter, errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, fmt.Sprintf("%s_circuit_breaker_tripped", cfg.Name), counterName) + return &mockmetrics.Int64CounterMock{}, errors.New("arbitrary") + }, + } cb, err := ProvideCircuitBreakerFromConfig(ctx, cfg, logging.NewNoopLogger(), mp) assert.Nil(t, cb) assert.Error(t, err) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("with error providing second metric", func(t *testing.T) { @@ -127,17 +130,25 @@ func TestProvideCircuitBreakerFromConfig(T *testing.T) { cfg.EnsureDefaults() ctx := t.Context() - i64Counter := &mockmetrics.Int64Counter{} - mp := &mockmetrics.MetricsProvider{} - mp.On(reflection.GetMethodName(mp.NewInt64Counter), fmt.Sprintf("%s_circuit_breaker_tripped", cfg.Name), []metric.Int64CounterOption(nil)).Return(i64Counter, nil) - mp.On(reflection.GetMethodName(mp.NewInt64Counter), fmt.Sprintf("%s_circuit_breaker_failed", cfg.Name), []metric.Int64CounterOption(nil)).Return(i64Counter, errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + switch counterName { + case fmt.Sprintf("%s_circuit_breaker_tripped", cfg.Name): + return &mockmetrics.Int64CounterMock{}, nil + case fmt.Sprintf("%s_circuit_breaker_failed", cfg.Name): + return &mockmetrics.Int64CounterMock{}, errors.New("arbitrary") + } + t.Fatalf("unexpected NewInt64Counter call: %q", counterName) + return nil, nil + }, + } cb, err := ProvideCircuitBreakerFromConfig(ctx, cfg, logging.NewNoopLogger(), mp) assert.Nil(t, cb) assert.Error(t, err) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) T.Run("with error providing third metric", func(t *testing.T) { @@ -145,18 +156,26 @@ func TestProvideCircuitBreakerFromConfig(T *testing.T) { cfg.EnsureDefaults() ctx := t.Context() - i64Counter := &mockmetrics.Int64Counter{} - mp := &mockmetrics.MetricsProvider{} - mp.On(reflection.GetMethodName(mp.NewInt64Counter), fmt.Sprintf("%s_circuit_breaker_tripped", cfg.Name), []metric.Int64CounterOption(nil)).Return(i64Counter, nil) - mp.On(reflection.GetMethodName(mp.NewInt64Counter), fmt.Sprintf("%s_circuit_breaker_failed", cfg.Name), []metric.Int64CounterOption(nil)).Return(i64Counter, nil) - mp.On(reflection.GetMethodName(mp.NewInt64Counter), fmt.Sprintf("%s_circuit_breaker_reset", cfg.Name), []metric.Int64CounterOption(nil)).Return(i64Counter, errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + switch counterName { + case fmt.Sprintf("%s_circuit_breaker_tripped", cfg.Name), + fmt.Sprintf("%s_circuit_breaker_failed", cfg.Name): + return &mockmetrics.Int64CounterMock{}, nil + case fmt.Sprintf("%s_circuit_breaker_reset", cfg.Name): + return &mockmetrics.Int64CounterMock{}, errors.New("arbitrary") + } + t.Fatalf("unexpected NewInt64Counter call: %q", counterName) + return nil, nil + }, + } cb, err := ProvideCircuitBreakerFromConfig(ctx, cfg, logging.NewNoopLogger(), mp) assert.Nil(t, cb) assert.Error(t, err) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 3, mp.NewInt64CounterCalls()) }) } @@ -271,13 +290,20 @@ func TestHandleCircuitBreakerEvents(T *testing.T) { T.Run("handles all event types and exits on channel close", func(t *testing.T) { ctx := t.Context() - i64Counter := &mockmetrics.Int64Counter{} - i64Counter.On("Add", mock.Anything, int64(1), []metric.AddOption(nil)).Return() + i64Counter := &mockmetrics.Int64CounterMock{ + AddFunc: func(_ context.Context, _ int64, _ ...metric.AddOption) {}, + } - mp := &mockmetrics.MetricsProvider{} - mp.On(reflection.GetMethodName(mp.NewInt64Counter), "failure", []metric.Int64CounterOption(nil)).Return(i64Counter, nil) - mp.On(reflection.GetMethodName(mp.NewInt64Counter), "reset", []metric.Int64CounterOption(nil)).Return(i64Counter, nil) - mp.On(reflection.GetMethodName(mp.NewInt64Counter), "broken", []metric.Int64CounterOption(nil)).Return(i64Counter, nil) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + switch counterName { + case "failure", "reset", "broken": + return i64Counter, nil + } + t.Fatalf("unexpected NewInt64Counter call: %q", counterName) + return nil, nil + }, + } failure, err := mp.NewInt64Counter("failure") assert.NoError(t, err) @@ -295,7 +321,8 @@ func TestHandleCircuitBreakerEvents(T *testing.T) { handleCircuitBreakerEvents(ctx, logging.NewNoopLogger(), events, failure, reset, broken) - mock.AssertExpectationsForObjects(t, i64Counter, mp) + test.SliceLen(t, 3, mp.NewInt64CounterCalls()) + test.SliceLen(t, 3, i64Counter.AddCalls()) }) } diff --git a/circuitbreaking/mock2/circuitbreaker_mock.go b/circuitbreaking/mock/circuitbreaker_mock.go similarity index 99% rename from circuitbreaking/mock2/circuitbreaker_mock.go rename to circuitbreaking/mock/circuitbreaker_mock.go index 23de9f7..eeffec8 100644 --- a/circuitbreaking/mock2/circuitbreaker_mock.go +++ b/circuitbreaking/mock/circuitbreaker_mock.go @@ -1,7 +1,7 @@ // Code generated by moq; DO NOT EDIT. // github.com/matryer/moq -package mock2 +package mock import ( "sync" diff --git a/circuitbreaking/mock/doc.go b/circuitbreaking/mock/doc.go new file mode 100644 index 0000000..a803e94 --- /dev/null +++ b/circuitbreaking/mock/doc.go @@ -0,0 +1,9 @@ +// Package mock provides mock implementations of the circuitbreaking package's +// interfaces. Both the hand-written testify-based MockCircuitBreaker and the +// moq-generated CircuitBreakerMock live here during the testify → moq +// migration. New test code should prefer CircuitBreakerMock. +package mock + +// Regenerate the moq mocks via `go generate ./circuitbreaking/mock/`. + +//go:generate go tool github.com/matryer/moq -out circuitbreaker_mock.go -pkg mock -rm -fmt goimports .. CircuitBreaker:CircuitBreakerMock diff --git a/circuitbreaking/mock/mock.go b/circuitbreaking/mock/mock.go deleted file mode 100644 index 78d3df7..0000000 --- a/circuitbreaking/mock/mock.go +++ /dev/null @@ -1,23 +0,0 @@ -package mock - -import "github.com/stretchr/testify/mock" - -type MockCircuitBreaker struct { - mock.Mock -} - -func (n *MockCircuitBreaker) Failed() { - n.Called() -} - -func (n *MockCircuitBreaker) Succeeded() { - n.Called() -} - -func (n *MockCircuitBreaker) CanProceed() bool { - return n.Called().Bool(0) -} - -func (n *MockCircuitBreaker) CannotProceed() bool { - return n.Called().Bool(0) -} diff --git a/circuitbreaking/mock2/doc.go b/circuitbreaking/mock2/doc.go deleted file mode 100644 index 3912fcf..0000000 --- a/circuitbreaking/mock2/doc.go +++ /dev/null @@ -1,10 +0,0 @@ -// Package mock2 provides moq-generated mock implementations of interfaces in -// the circuitbreaking package. It exists alongside the hand-written -// testify-based package circuitbreaking/mock and is a pilot of the -// matryer/moq workflow; consumers that want the moq style should import this -// package instead of circuitbreaking/mock. -package mock2 - -// Regenerate via `go generate ./circuitbreaking/mock2/`. - -//go:generate go tool github.com/matryer/moq -out circuitbreaker_mock.go -pkg mock2 -rm -fmt goimports .. CircuitBreaker:CircuitBreakerMock diff --git a/cryptography/encryption/mock/doc.go b/cryptography/encryption/mock/doc.go index 2e27cb3..27ea356 100644 --- a/cryptography/encryption/mock/doc.go +++ b/cryptography/encryption/mock/doc.go @@ -1,4 +1,11 @@ /* -Package encryptionmock contains the interfaces and implementations for encrypting and decrypting data. +Package encryptionmock provides mock implementations of the encryption package's +interfaces. Both the hand-written testify-based MockImpl and the moq-generated +EncryptorDecryptorMock live here during the testify → moq migration. New test +code should prefer EncryptorDecryptorMock. */ package encryptionmock + +// Regenerate the moq mocks via `go generate ./cryptography/encryption/mock/`. + +//go:generate go tool github.com/matryer/moq -out encryptor_decryptor_mock.go -pkg encryptionmock -rm -fmt goimports .. EncryptorDecryptor:EncryptorDecryptorMock diff --git a/cryptography/encryption/mock/encryptor_decryptor_mock.go b/cryptography/encryption/mock/encryptor_decryptor_mock.go new file mode 100644 index 0000000..deca8d8 --- /dev/null +++ b/cryptography/encryption/mock/encryptor_decryptor_mock.go @@ -0,0 +1,133 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package encryptionmock + +import ( + "context" + "sync" + + "github.com/verygoodsoftwarenotvirus/platform/v5/cryptography/encryption" +) + +// Ensure, that EncryptorDecryptorMock does implement encryption.EncryptorDecryptor. +// If this is not the case, regenerate this file with moq. +var _ encryption.EncryptorDecryptor = &EncryptorDecryptorMock{} + +// EncryptorDecryptorMock is a mock implementation of encryption.EncryptorDecryptor. +// +// func TestSomethingThatUsesEncryptorDecryptor(t *testing.T) { +// +// // make and configure a mocked encryption.EncryptorDecryptor +// mockedEncryptorDecryptor := &EncryptorDecryptorMock{ +// DecryptFunc: func(ctx context.Context, content string) (string, error) { +// panic("mock out the Decrypt method") +// }, +// EncryptFunc: func(ctx context.Context, content string) (string, error) { +// panic("mock out the Encrypt method") +// }, +// } +// +// // use mockedEncryptorDecryptor in code that requires encryption.EncryptorDecryptor +// // and then make assertions. +// +// } +type EncryptorDecryptorMock struct { + // DecryptFunc mocks the Decrypt method. + DecryptFunc func(ctx context.Context, content string) (string, error) + + // EncryptFunc mocks the Encrypt method. + EncryptFunc func(ctx context.Context, content string) (string, error) + + // calls tracks calls to the methods. + calls struct { + // Decrypt holds details about calls to the Decrypt method. + Decrypt []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Content is the content argument value. + Content string + } + // Encrypt holds details about calls to the Encrypt method. + Encrypt []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Content is the content argument value. + Content string + } + } + lockDecrypt sync.RWMutex + lockEncrypt sync.RWMutex +} + +// Decrypt calls DecryptFunc. +func (mock *EncryptorDecryptorMock) Decrypt(ctx context.Context, content string) (string, error) { + if mock.DecryptFunc == nil { + panic("EncryptorDecryptorMock.DecryptFunc: method is nil but EncryptorDecryptor.Decrypt was just called") + } + callInfo := struct { + Ctx context.Context + Content string + }{ + Ctx: ctx, + Content: content, + } + mock.lockDecrypt.Lock() + mock.calls.Decrypt = append(mock.calls.Decrypt, callInfo) + mock.lockDecrypt.Unlock() + return mock.DecryptFunc(ctx, content) +} + +// DecryptCalls gets all the calls that were made to Decrypt. +// Check the length with: +// +// len(mockedEncryptorDecryptor.DecryptCalls()) +func (mock *EncryptorDecryptorMock) DecryptCalls() []struct { + Ctx context.Context + Content string +} { + var calls []struct { + Ctx context.Context + Content string + } + mock.lockDecrypt.RLock() + calls = mock.calls.Decrypt + mock.lockDecrypt.RUnlock() + return calls +} + +// Encrypt calls EncryptFunc. +func (mock *EncryptorDecryptorMock) Encrypt(ctx context.Context, content string) (string, error) { + if mock.EncryptFunc == nil { + panic("EncryptorDecryptorMock.EncryptFunc: method is nil but EncryptorDecryptor.Encrypt was just called") + } + callInfo := struct { + Ctx context.Context + Content string + }{ + Ctx: ctx, + Content: content, + } + mock.lockEncrypt.Lock() + mock.calls.Encrypt = append(mock.calls.Encrypt, callInfo) + mock.lockEncrypt.Unlock() + return mock.EncryptFunc(ctx, content) +} + +// EncryptCalls gets all the calls that were made to Encrypt. +// Check the length with: +// +// len(mockedEncryptorDecryptor.EncryptCalls()) +func (mock *EncryptorDecryptorMock) EncryptCalls() []struct { + Ctx context.Context + Content string +} { + var calls []struct { + Ctx context.Context + Content string + } + mock.lockEncrypt.RLock() + calls = mock.calls.Encrypt + mock.lockEncrypt.RUnlock() + return calls +} diff --git a/cryptography/encryption/mock/mock.go b/cryptography/encryption/mock/mock.go deleted file mode 100644 index 35d156d..0000000 --- a/cryptography/encryption/mock/mock.go +++ /dev/null @@ -1,28 +0,0 @@ -package encryptionmock - -import ( - "context" - - "github.com/stretchr/testify/mock" -) - -// MockImpl is the mock EncryptorDecryptor implementation. -type MockImpl struct { - mock.Mock -} - -func NewMockEncryptorDecryptor() *MockImpl { - return &MockImpl{} -} - -// Encrypt is a mock method. -func (m *MockImpl) Encrypt(ctx context.Context, content string) (string, error) { - returnVals := m.Called(ctx, content) - return returnVals.String(0), returnVals.Error(1) -} - -// Decrypt is a mock method. -func (m *MockImpl) Decrypt(ctx context.Context, content string) (string, error) { - returnVals := m.Called(ctx, content) - return returnVals.String(0), returnVals.Error(1) -} diff --git a/database/database_mock.go b/database/database_mock.go deleted file mode 100644 index afcc010..0000000 --- a/database/database_mock.go +++ /dev/null @@ -1,155 +0,0 @@ -package database - -import ( - "context" - "database/sql" - "time" - - "github.com/stretchr/testify/mock" -) - -// NewMockDatabase builds a mock database. -func NewMockDatabase() *MockDatabase { - return &MockDatabase{} -} - -// MockDatabase is our mock database structure. Note, when using this in tests, you must directly access the type name of all the implicit fields. -// So `mockDB.On(reflection.GetMethodName(mockDB.GetUserByUsername)...)` is destined to fail, whereas `mockDB.UserDataManagerMock.On(reflection.GetMethodName(UserDataManagerMock.GetUserByUsername)...)` would do what you want it to do. -type MockDatabase struct { - mock.Mock -} - -// Migrate satisfies the DataManager interface. -func (m *MockDatabase) Migrate(ctx context.Context) error { - return m.Called(ctx).Error(0) -} - -// Close satisfies the DataManager interface. -func (m *MockDatabase) Close() { - m.Called() -} - -// DB satisfies the DataManager interface. -func (m *MockDatabase) DB() *sql.DB { - return m.Called().Get(0).(*sql.DB) -} - -// ReadDB satisfies the DataManager interface. -func (m *MockDatabase) ReadDB() *sql.DB { - return m.Called().Get(0).(*sql.DB) -} - -// WriteDB satisfies the DataManager interface. -func (m *MockDatabase) WriteDB() *sql.DB { - return m.Called().Get(0).(*sql.DB) -} - -// IsReady satisfies the DataManager interface. -func (m *MockDatabase) IsReady(ctx context.Context) (ready bool) { - return m.Called(ctx).Bool(0) -} - -// BeginTx satisfies the DataManager interface. -func (m *MockDatabase) BeginTx(ctx context.Context, options *sql.TxOptions) (*sql.Tx, error) { - args := m.Called(ctx, options) - return args.Get(0).(*sql.Tx), args.Error(1) -} - -var _ ResultIterator = (*MockResultIterator)(nil) - -// MockResultIterator is our mock sql.Rows structure. -type MockResultIterator struct { - mock.Mock -} - -// Scan satisfies the ResultIterator interface. -func (m *MockResultIterator) Scan(dest ...any) error { - return m.Called(dest...).Error(0) -} - -// Next satisfies the ResultIterator interface. -func (m *MockResultIterator) Next() bool { - return m.Called().Bool(0) -} - -// Err satisfies the ResultIterator interface. -func (m *MockResultIterator) Err() error { - return m.Called().Error(0) -} - -// Close satisfies the ResultIterator interface. -func (m *MockResultIterator) Close() error { - return m.Called().Error(0) -} - -// MockSQLResult mocks a sql.Result. -type MockSQLResult struct { - mock.Mock -} - -// LastInsertId implements our interface. -func (m *MockSQLResult) LastInsertId() (int64, error) { - args := m.Called() - return args.Get(0).(int64), args.Error(1) -} - -// RowsAffected implements our interface. -func (m *MockSQLResult) RowsAffected() (int64, error) { - args := m.Called() - return args.Get(0).(int64), args.Error(1) -} - -var _ SQLQueryExecutor = (*MockQueryExecutor)(nil) - -// MockQueryExecutor mocks a sql.Tx|DB. -type MockQueryExecutor struct { - mock.Mock -} - -// ExecContext is a mock function. -func (m *MockQueryExecutor) ExecContext(ctx context.Context, query string, queryArgs ...any) (sql.Result, error) { - args := m.Called(ctx, query, queryArgs) - return args.Get(0).(sql.Result), args.Error(1) -} - -// PrepareContext is a mock function. -func (m *MockQueryExecutor) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { - args := m.Called(ctx, query) - return args.Get(0).(*sql.Stmt), args.Error(1) -} - -// QueryContext is a mock function. -func (m *MockQueryExecutor) QueryContext(ctx context.Context, query string, queryArgs ...any) (*sql.Rows, error) { - args := m.Called(ctx, query, queryArgs) - return args.Get(0).(*sql.Rows), args.Error(1) -} - -// QueryRowContext is a mock function. -func (m *MockQueryExecutor) QueryRowContext(ctx context.Context, query string, queryArgs ...any) *sql.Row { - args := m.Called(ctx, query, queryArgs) - return args.Get(0).(*sql.Row) -} - -type MockClient struct { - mock.Mock -} - -func (m *MockClient) ReadDB() *sql.DB { - return nil -} - -func (m *MockClient) WriteDB() *sql.DB { - return nil -} - -func (m *MockClient) Close() error { - return m.Called().Error(0) -} - -func (m *MockClient) CurrentTime() time.Time { - return m.Called().Get(0).(time.Time) -} - -func (m *MockClient) RollbackTransaction(ctx context.Context, tx SQLQueryExecutorAndTransactionManager) { - m.Called(ctx, tx) -} diff --git a/database/database_mock_test.go b/database/database_mock_test.go deleted file mode 100644 index d343eb4..0000000 --- a/database/database_mock_test.go +++ /dev/null @@ -1,322 +0,0 @@ -package database - -import ( - "context" - "database/sql" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestNewMockDatabase(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - assert.NotNil(t, NewMockDatabase()) - }) -} - -func TestMockDatabase_Migrate(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := NewMockDatabase() - m.On("Migrate", mock.Anything).Return(nil) - - assert.NoError(t, m.Migrate(context.Background())) - }) -} - -func TestMockDatabase_Close(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := NewMockDatabase() - m.On("Close").Return() - - m.Close() - m.AssertExpectations(t) - }) -} - -func TestMockDatabase_DB(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := NewMockDatabase() - m.On("DB").Return((*sql.DB)(nil)) - - assert.Nil(t, m.DB()) - }) -} - -func TestMockDatabase_ReadDB(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := NewMockDatabase() - m.On("ReadDB").Return((*sql.DB)(nil)) - - assert.Nil(t, m.ReadDB()) - }) -} - -func TestMockDatabase_WriteDB(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := NewMockDatabase() - m.On("WriteDB").Return((*sql.DB)(nil)) - - assert.Nil(t, m.WriteDB()) - }) -} - -func TestMockDatabase_IsReady(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := NewMockDatabase() - m.On("IsReady", mock.Anything).Return(true) - - assert.True(t, m.IsReady(context.Background())) - }) -} - -func TestMockDatabase_BeginTx(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := NewMockDatabase() - m.On("BeginTx", mock.Anything, mock.Anything).Return((*sql.Tx)(nil), nil) - - tx, err := m.BeginTx(context.Background(), nil) - assert.NoError(t, err) - assert.Nil(t, tx) - }) -} - -func TestMockResultIterator_Scan(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := &MockResultIterator{} - m.On("Scan").Return(nil) - - assert.NoError(t, m.Scan()) - }) -} - -func TestMockResultIterator_Next(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := &MockResultIterator{} - m.On("Next").Return(true) - - assert.True(t, m.Next()) - }) -} - -func TestMockResultIterator_Err(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := &MockResultIterator{} - m.On("Err").Return(nil) - - assert.NoError(t, m.Err()) - }) -} - -func TestMockResultIterator_Close(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := &MockResultIterator{} - m.On("Close").Return(nil) - - assert.NoError(t, m.Close()) - }) -} - -func TestMockSQLResult_LastInsertId(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := &MockSQLResult{} - m.On("LastInsertId").Return(int64(1), nil) - - id, err := m.LastInsertId() - assert.Equal(t, int64(1), id) - assert.NoError(t, err) - }) -} - -func TestMockSQLResult_RowsAffected(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := &MockSQLResult{} - m.On("RowsAffected").Return(int64(5), nil) - - count, err := m.RowsAffected() - assert.Equal(t, int64(5), count) - assert.NoError(t, err) - }) -} - -func TestMockQueryExecutor_ExecContext(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - mockResult := &MockSQLResult{} - m := &MockQueryExecutor{} - m.On("ExecContext", mock.Anything, mock.Anything, mock.Anything).Return(mockResult, nil) - - result, err := m.ExecContext(context.Background(), "SELECT 1") - assert.NoError(t, err) - assert.NotNil(t, result) - }) -} - -func TestMockQueryExecutor_PrepareContext(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := &MockQueryExecutor{} - m.On("PrepareContext", mock.Anything, mock.Anything).Return((*sql.Stmt)(nil), nil) - - stmt, err := m.PrepareContext(context.Background(), "SELECT 1") - assert.NoError(t, err) - assert.Nil(t, stmt) - }) -} - -func TestMockQueryExecutor_QueryContext(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := &MockQueryExecutor{} - m.On("QueryContext", mock.Anything, mock.Anything, mock.Anything).Return((*sql.Rows)(nil), nil) - - rows, err := m.QueryContext(context.Background(), "SELECT 1") - assert.NoError(t, err) - assert.Nil(t, rows) - }) -} - -func TestMockQueryExecutor_QueryRowContext(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := &MockQueryExecutor{} - m.On("QueryRowContext", mock.Anything, mock.Anything, mock.Anything).Return((*sql.Row)(nil)) - - row := m.QueryRowContext(context.Background(), "SELECT 1") - assert.Nil(t, row) - }) -} - -func TestMockClient_ReadDB(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := &MockClient{} - - assert.Nil(t, m.ReadDB()) - }) -} - -func TestMockClient_WriteDB(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := &MockClient{} - - assert.Nil(t, m.WriteDB()) - }) -} - -func TestMockClient_Close(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := &MockClient{} - m.On("Close").Return(nil) - - assert.NoError(t, m.Close()) - }) -} - -func TestMockClient_CurrentTime(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - now := time.Now() - m := &MockClient{} - m.On("CurrentTime").Return(now) - - assert.Equal(t, now, m.CurrentTime()) - }) -} - -func TestMockClient_RollbackTransaction(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := &MockClient{} - m.On("RollbackTransaction", mock.Anything, mock.Anything).Return() - - m.RollbackTransaction(context.Background(), nil) - m.AssertExpectations(t) - }) -} diff --git a/database/mock/database_mock.go b/database/mock/database_mock.go new file mode 100644 index 0000000..c1cb1fb --- /dev/null +++ b/database/mock/database_mock.go @@ -0,0 +1,650 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package mockdatabase + +import ( + "context" + "database/sql" + "sync" + "time" + + "github.com/verygoodsoftwarenotvirus/platform/v5/database" +) + +// Ensure, that ClientMock does implement database.Client. +// If this is not the case, regenerate this file with moq. +var _ database.Client = &ClientMock{} + +// ClientMock is a mock implementation of database.Client. +// +// func TestSomethingThatUsesClient(t *testing.T) { +// +// // make and configure a mocked database.Client +// mockedClient := &ClientMock{ +// CloseFunc: func() error { +// panic("mock out the Close method") +// }, +// CurrentTimeFunc: func() time.Time { +// panic("mock out the CurrentTime method") +// }, +// ReadDBFunc: func() *sql.DB { +// panic("mock out the ReadDB method") +// }, +// RollbackTransactionFunc: func(ctx context.Context, tx database.SQLQueryExecutorAndTransactionManager) { +// panic("mock out the RollbackTransaction method") +// }, +// WriteDBFunc: func() *sql.DB { +// panic("mock out the WriteDB method") +// }, +// } +// +// // use mockedClient in code that requires database.Client +// // and then make assertions. +// +// } +type ClientMock struct { + // CloseFunc mocks the Close method. + CloseFunc func() error + + // CurrentTimeFunc mocks the CurrentTime method. + CurrentTimeFunc func() time.Time + + // ReadDBFunc mocks the ReadDB method. + ReadDBFunc func() *sql.DB + + // RollbackTransactionFunc mocks the RollbackTransaction method. + RollbackTransactionFunc func(ctx context.Context, tx database.SQLQueryExecutorAndTransactionManager) + + // WriteDBFunc mocks the WriteDB method. + WriteDBFunc func() *sql.DB + + // calls tracks calls to the methods. + calls struct { + // Close holds details about calls to the Close method. + Close []struct { + } + // CurrentTime holds details about calls to the CurrentTime method. + CurrentTime []struct { + } + // ReadDB holds details about calls to the ReadDB method. + ReadDB []struct { + } + // RollbackTransaction holds details about calls to the RollbackTransaction method. + RollbackTransaction []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Tx is the tx argument value. + Tx database.SQLQueryExecutorAndTransactionManager + } + // WriteDB holds details about calls to the WriteDB method. + WriteDB []struct { + } + } + lockClose sync.RWMutex + lockCurrentTime sync.RWMutex + lockReadDB sync.RWMutex + lockRollbackTransaction sync.RWMutex + lockWriteDB sync.RWMutex +} + +// Close calls CloseFunc. +func (mock *ClientMock) Close() error { + if mock.CloseFunc == nil { + panic("ClientMock.CloseFunc: method is nil but Client.Close was just called") + } + callInfo := struct { + }{} + mock.lockClose.Lock() + mock.calls.Close = append(mock.calls.Close, callInfo) + mock.lockClose.Unlock() + return mock.CloseFunc() +} + +// CloseCalls gets all the calls that were made to Close. +// Check the length with: +// +// len(mockedClient.CloseCalls()) +func (mock *ClientMock) CloseCalls() []struct { +} { + var calls []struct { + } + mock.lockClose.RLock() + calls = mock.calls.Close + mock.lockClose.RUnlock() + return calls +} + +// CurrentTime calls CurrentTimeFunc. +func (mock *ClientMock) CurrentTime() time.Time { + if mock.CurrentTimeFunc == nil { + panic("ClientMock.CurrentTimeFunc: method is nil but Client.CurrentTime was just called") + } + callInfo := struct { + }{} + mock.lockCurrentTime.Lock() + mock.calls.CurrentTime = append(mock.calls.CurrentTime, callInfo) + mock.lockCurrentTime.Unlock() + return mock.CurrentTimeFunc() +} + +// CurrentTimeCalls gets all the calls that were made to CurrentTime. +// Check the length with: +// +// len(mockedClient.CurrentTimeCalls()) +func (mock *ClientMock) CurrentTimeCalls() []struct { +} { + var calls []struct { + } + mock.lockCurrentTime.RLock() + calls = mock.calls.CurrentTime + mock.lockCurrentTime.RUnlock() + return calls +} + +// ReadDB calls ReadDBFunc. +func (mock *ClientMock) ReadDB() *sql.DB { + if mock.ReadDBFunc == nil { + panic("ClientMock.ReadDBFunc: method is nil but Client.ReadDB was just called") + } + callInfo := struct { + }{} + mock.lockReadDB.Lock() + mock.calls.ReadDB = append(mock.calls.ReadDB, callInfo) + mock.lockReadDB.Unlock() + return mock.ReadDBFunc() +} + +// ReadDBCalls gets all the calls that were made to ReadDB. +// Check the length with: +// +// len(mockedClient.ReadDBCalls()) +func (mock *ClientMock) ReadDBCalls() []struct { +} { + var calls []struct { + } + mock.lockReadDB.RLock() + calls = mock.calls.ReadDB + mock.lockReadDB.RUnlock() + return calls +} + +// RollbackTransaction calls RollbackTransactionFunc. +func (mock *ClientMock) RollbackTransaction(ctx context.Context, tx database.SQLQueryExecutorAndTransactionManager) { + if mock.RollbackTransactionFunc == nil { + panic("ClientMock.RollbackTransactionFunc: method is nil but Client.RollbackTransaction was just called") + } + callInfo := struct { + Ctx context.Context + Tx database.SQLQueryExecutorAndTransactionManager + }{ + Ctx: ctx, + Tx: tx, + } + mock.lockRollbackTransaction.Lock() + mock.calls.RollbackTransaction = append(mock.calls.RollbackTransaction, callInfo) + mock.lockRollbackTransaction.Unlock() + mock.RollbackTransactionFunc(ctx, tx) +} + +// RollbackTransactionCalls gets all the calls that were made to RollbackTransaction. +// Check the length with: +// +// len(mockedClient.RollbackTransactionCalls()) +func (mock *ClientMock) RollbackTransactionCalls() []struct { + Ctx context.Context + Tx database.SQLQueryExecutorAndTransactionManager +} { + var calls []struct { + Ctx context.Context + Tx database.SQLQueryExecutorAndTransactionManager + } + mock.lockRollbackTransaction.RLock() + calls = mock.calls.RollbackTransaction + mock.lockRollbackTransaction.RUnlock() + return calls +} + +// WriteDB calls WriteDBFunc. +func (mock *ClientMock) WriteDB() *sql.DB { + if mock.WriteDBFunc == nil { + panic("ClientMock.WriteDBFunc: method is nil but Client.WriteDB was just called") + } + callInfo := struct { + }{} + mock.lockWriteDB.Lock() + mock.calls.WriteDB = append(mock.calls.WriteDB, callInfo) + mock.lockWriteDB.Unlock() + return mock.WriteDBFunc() +} + +// WriteDBCalls gets all the calls that were made to WriteDB. +// Check the length with: +// +// len(mockedClient.WriteDBCalls()) +func (mock *ClientMock) WriteDBCalls() []struct { +} { + var calls []struct { + } + mock.lockWriteDB.RLock() + calls = mock.calls.WriteDB + mock.lockWriteDB.RUnlock() + return calls +} + +// Ensure, that ResultIteratorMock does implement database.ResultIterator. +// If this is not the case, regenerate this file with moq. +var _ database.ResultIterator = &ResultIteratorMock{} + +// ResultIteratorMock is a mock implementation of database.ResultIterator. +// +// func TestSomethingThatUsesResultIterator(t *testing.T) { +// +// // make and configure a mocked database.ResultIterator +// mockedResultIterator := &ResultIteratorMock{ +// CloseFunc: func() error { +// panic("mock out the Close method") +// }, +// ErrFunc: func() error { +// panic("mock out the Err method") +// }, +// NextFunc: func() bool { +// panic("mock out the Next method") +// }, +// ScanFunc: func(dest ...any) error { +// panic("mock out the Scan method") +// }, +// } +// +// // use mockedResultIterator in code that requires database.ResultIterator +// // and then make assertions. +// +// } +type ResultIteratorMock struct { + // CloseFunc mocks the Close method. + CloseFunc func() error + + // ErrFunc mocks the Err method. + ErrFunc func() error + + // NextFunc mocks the Next method. + NextFunc func() bool + + // ScanFunc mocks the Scan method. + ScanFunc func(dest ...any) error + + // calls tracks calls to the methods. + calls struct { + // Close holds details about calls to the Close method. + Close []struct { + } + // Err holds details about calls to the Err method. + Err []struct { + } + // Next holds details about calls to the Next method. + Next []struct { + } + // Scan holds details about calls to the Scan method. + Scan []struct { + // Dest is the dest argument value. + Dest []any + } + } + lockClose sync.RWMutex + lockErr sync.RWMutex + lockNext sync.RWMutex + lockScan sync.RWMutex +} + +// Close calls CloseFunc. +func (mock *ResultIteratorMock) Close() error { + if mock.CloseFunc == nil { + panic("ResultIteratorMock.CloseFunc: method is nil but ResultIterator.Close was just called") + } + callInfo := struct { + }{} + mock.lockClose.Lock() + mock.calls.Close = append(mock.calls.Close, callInfo) + mock.lockClose.Unlock() + return mock.CloseFunc() +} + +// CloseCalls gets all the calls that were made to Close. +// Check the length with: +// +// len(mockedResultIterator.CloseCalls()) +func (mock *ResultIteratorMock) CloseCalls() []struct { +} { + var calls []struct { + } + mock.lockClose.RLock() + calls = mock.calls.Close + mock.lockClose.RUnlock() + return calls +} + +// Err calls ErrFunc. +func (mock *ResultIteratorMock) Err() error { + if mock.ErrFunc == nil { + panic("ResultIteratorMock.ErrFunc: method is nil but ResultIterator.Err was just called") + } + callInfo := struct { + }{} + mock.lockErr.Lock() + mock.calls.Err = append(mock.calls.Err, callInfo) + mock.lockErr.Unlock() + return mock.ErrFunc() +} + +// ErrCalls gets all the calls that were made to Err. +// Check the length with: +// +// len(mockedResultIterator.ErrCalls()) +func (mock *ResultIteratorMock) ErrCalls() []struct { +} { + var calls []struct { + } + mock.lockErr.RLock() + calls = mock.calls.Err + mock.lockErr.RUnlock() + return calls +} + +// Next calls NextFunc. +func (mock *ResultIteratorMock) Next() bool { + if mock.NextFunc == nil { + panic("ResultIteratorMock.NextFunc: method is nil but ResultIterator.Next was just called") + } + callInfo := struct { + }{} + mock.lockNext.Lock() + mock.calls.Next = append(mock.calls.Next, callInfo) + mock.lockNext.Unlock() + return mock.NextFunc() +} + +// NextCalls gets all the calls that were made to Next. +// Check the length with: +// +// len(mockedResultIterator.NextCalls()) +func (mock *ResultIteratorMock) NextCalls() []struct { +} { + var calls []struct { + } + mock.lockNext.RLock() + calls = mock.calls.Next + mock.lockNext.RUnlock() + return calls +} + +// Scan calls ScanFunc. +func (mock *ResultIteratorMock) Scan(dest ...any) error { + if mock.ScanFunc == nil { + panic("ResultIteratorMock.ScanFunc: method is nil but ResultIterator.Scan was just called") + } + callInfo := struct { + Dest []any + }{ + Dest: dest, + } + mock.lockScan.Lock() + mock.calls.Scan = append(mock.calls.Scan, callInfo) + mock.lockScan.Unlock() + return mock.ScanFunc(dest...) +} + +// ScanCalls gets all the calls that were made to Scan. +// Check the length with: +// +// len(mockedResultIterator.ScanCalls()) +func (mock *ResultIteratorMock) ScanCalls() []struct { + Dest []any +} { + var calls []struct { + Dest []any + } + mock.lockScan.RLock() + calls = mock.calls.Scan + mock.lockScan.RUnlock() + return calls +} + +// Ensure, that SQLQueryExecutorMock does implement database.SQLQueryExecutor. +// If this is not the case, regenerate this file with moq. +var _ database.SQLQueryExecutor = &SQLQueryExecutorMock{} + +// SQLQueryExecutorMock is a mock implementation of database.SQLQueryExecutor. +// +// func TestSomethingThatUsesSQLQueryExecutor(t *testing.T) { +// +// // make and configure a mocked database.SQLQueryExecutor +// mockedSQLQueryExecutor := &SQLQueryExecutorMock{ +// ExecContextFunc: func(ctx context.Context, query string, args ...any) (sql.Result, error) { +// panic("mock out the ExecContext method") +// }, +// PrepareContextFunc: func(contextMoqParam context.Context, s string) (*sql.Stmt, error) { +// panic("mock out the PrepareContext method") +// }, +// QueryContextFunc: func(ctx context.Context, query string, args ...any) (*sql.Rows, error) { +// panic("mock out the QueryContext method") +// }, +// QueryRowContextFunc: func(ctx context.Context, query string, args ...any) *sql.Row { +// panic("mock out the QueryRowContext method") +// }, +// } +// +// // use mockedSQLQueryExecutor in code that requires database.SQLQueryExecutor +// // and then make assertions. +// +// } +type SQLQueryExecutorMock struct { + // ExecContextFunc mocks the ExecContext method. + ExecContextFunc func(ctx context.Context, query string, args ...any) (sql.Result, error) + + // PrepareContextFunc mocks the PrepareContext method. + PrepareContextFunc func(contextMoqParam context.Context, s string) (*sql.Stmt, error) + + // QueryContextFunc mocks the QueryContext method. + QueryContextFunc func(ctx context.Context, query string, args ...any) (*sql.Rows, error) + + // QueryRowContextFunc mocks the QueryRowContext method. + QueryRowContextFunc func(ctx context.Context, query string, args ...any) *sql.Row + + // calls tracks calls to the methods. + calls struct { + // ExecContext holds details about calls to the ExecContext method. + ExecContext []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Query is the query argument value. + Query string + // Args is the args argument value. + Args []any + } + // PrepareContext holds details about calls to the PrepareContext method. + PrepareContext []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // S is the s argument value. + S string + } + // QueryContext holds details about calls to the QueryContext method. + QueryContext []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Query is the query argument value. + Query string + // Args is the args argument value. + Args []any + } + // QueryRowContext holds details about calls to the QueryRowContext method. + QueryRowContext []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Query is the query argument value. + Query string + // Args is the args argument value. + Args []any + } + } + lockExecContext sync.RWMutex + lockPrepareContext sync.RWMutex + lockQueryContext sync.RWMutex + lockQueryRowContext sync.RWMutex +} + +// ExecContext calls ExecContextFunc. +func (mock *SQLQueryExecutorMock) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { + if mock.ExecContextFunc == nil { + panic("SQLQueryExecutorMock.ExecContextFunc: method is nil but SQLQueryExecutor.ExecContext was just called") + } + callInfo := struct { + Ctx context.Context + Query string + Args []any + }{ + Ctx: ctx, + Query: query, + Args: args, + } + mock.lockExecContext.Lock() + mock.calls.ExecContext = append(mock.calls.ExecContext, callInfo) + mock.lockExecContext.Unlock() + return mock.ExecContextFunc(ctx, query, args...) +} + +// ExecContextCalls gets all the calls that were made to ExecContext. +// Check the length with: +// +// len(mockedSQLQueryExecutor.ExecContextCalls()) +func (mock *SQLQueryExecutorMock) ExecContextCalls() []struct { + Ctx context.Context + Query string + Args []any +} { + var calls []struct { + Ctx context.Context + Query string + Args []any + } + mock.lockExecContext.RLock() + calls = mock.calls.ExecContext + mock.lockExecContext.RUnlock() + return calls +} + +// PrepareContext calls PrepareContextFunc. +func (mock *SQLQueryExecutorMock) PrepareContext(contextMoqParam context.Context, s string) (*sql.Stmt, error) { + if mock.PrepareContextFunc == nil { + panic("SQLQueryExecutorMock.PrepareContextFunc: method is nil but SQLQueryExecutor.PrepareContext was just called") + } + callInfo := struct { + ContextMoqParam context.Context + S string + }{ + ContextMoqParam: contextMoqParam, + S: s, + } + mock.lockPrepareContext.Lock() + mock.calls.PrepareContext = append(mock.calls.PrepareContext, callInfo) + mock.lockPrepareContext.Unlock() + return mock.PrepareContextFunc(contextMoqParam, s) +} + +// PrepareContextCalls gets all the calls that were made to PrepareContext. +// Check the length with: +// +// len(mockedSQLQueryExecutor.PrepareContextCalls()) +func (mock *SQLQueryExecutorMock) PrepareContextCalls() []struct { + ContextMoqParam context.Context + S string +} { + var calls []struct { + ContextMoqParam context.Context + S string + } + mock.lockPrepareContext.RLock() + calls = mock.calls.PrepareContext + mock.lockPrepareContext.RUnlock() + return calls +} + +// QueryContext calls QueryContextFunc. +func (mock *SQLQueryExecutorMock) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + if mock.QueryContextFunc == nil { + panic("SQLQueryExecutorMock.QueryContextFunc: method is nil but SQLQueryExecutor.QueryContext was just called") + } + callInfo := struct { + Ctx context.Context + Query string + Args []any + }{ + Ctx: ctx, + Query: query, + Args: args, + } + mock.lockQueryContext.Lock() + mock.calls.QueryContext = append(mock.calls.QueryContext, callInfo) + mock.lockQueryContext.Unlock() + return mock.QueryContextFunc(ctx, query, args...) +} + +// QueryContextCalls gets all the calls that were made to QueryContext. +// Check the length with: +// +// len(mockedSQLQueryExecutor.QueryContextCalls()) +func (mock *SQLQueryExecutorMock) QueryContextCalls() []struct { + Ctx context.Context + Query string + Args []any +} { + var calls []struct { + Ctx context.Context + Query string + Args []any + } + mock.lockQueryContext.RLock() + calls = mock.calls.QueryContext + mock.lockQueryContext.RUnlock() + return calls +} + +// QueryRowContext calls QueryRowContextFunc. +func (mock *SQLQueryExecutorMock) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row { + if mock.QueryRowContextFunc == nil { + panic("SQLQueryExecutorMock.QueryRowContextFunc: method is nil but SQLQueryExecutor.QueryRowContext was just called") + } + callInfo := struct { + Ctx context.Context + Query string + Args []any + }{ + Ctx: ctx, + Query: query, + Args: args, + } + mock.lockQueryRowContext.Lock() + mock.calls.QueryRowContext = append(mock.calls.QueryRowContext, callInfo) + mock.lockQueryRowContext.Unlock() + return mock.QueryRowContextFunc(ctx, query, args...) +} + +// QueryRowContextCalls gets all the calls that were made to QueryRowContext. +// Check the length with: +// +// len(mockedSQLQueryExecutor.QueryRowContextCalls()) +func (mock *SQLQueryExecutorMock) QueryRowContextCalls() []struct { + Ctx context.Context + Query string + Args []any +} { + var calls []struct { + Ctx context.Context + Query string + Args []any + } + mock.lockQueryRowContext.RLock() + calls = mock.calls.QueryRowContext + mock.lockQueryRowContext.RUnlock() + return calls +} diff --git a/database/mock/doc.go b/database/mock/doc.go new file mode 100644 index 0000000..55c7d7f --- /dev/null +++ b/database/mock/doc.go @@ -0,0 +1,8 @@ +/* +Package mockdatabase provides moq-generated mocks for the database package. +*/ +package mockdatabase + +// Regenerate the moq mocks via `go generate ./database/mock/`. + +//go:generate go tool github.com/matryer/moq -out database_mock.go -pkg mockdatabase -rm -fmt goimports .. Client:ClientMock ResultIterator:ResultIteratorMock SQLQueryExecutor:SQLQueryExecutorMock diff --git a/database/mysql/mysql_test.go b/database/mysql/mysql_test.go index ce75430..1f5f3bf 100644 --- a/database/mysql/mysql_test.go +++ b/database/mysql/mysql_test.go @@ -13,7 +13,6 @@ import ( "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -62,15 +61,7 @@ func (c *testClientConfig) GetConnMaxLifetime() time.Duration { return 30 * time.Minute } -type sqlmockExpecterWrapper struct { - sqlmock.Sqlmock -} - -func (e *sqlmockExpecterWrapper) AssertExpectations(t mock.TestingT) bool { - return assert.NoError(t, e.ExpectationsWereMet(), "not all database expectations were met") -} - -func buildTestClient(t *testing.T) (*Client, *sqlmockExpecterWrapper) { +func buildTestClient(t *testing.T) (*Client, sqlmock.Sqlmock) { t.Helper() fakeDB, sqlMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) @@ -88,7 +79,7 @@ func buildTestClient(t *testing.T) (*Client, *sqlmockExpecterWrapper) { tracer: tracing.NewTracerForTest("test"), } - return c, &sqlmockExpecterWrapper{Sqlmock: sqlMock} + return c, sqlMock } // end helper funcs diff --git a/database/postgres/postgres_test.go b/database/postgres/postgres_test.go index f0b42e4..544d76c 100644 --- a/database/postgres/postgres_test.go +++ b/database/postgres/postgres_test.go @@ -13,7 +13,6 @@ import ( "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -62,15 +61,7 @@ func (c *testClientConfig) GetConnMaxLifetime() time.Duration { return 30 * time.Minute } -type sqlmockExpecterWrapper struct { - sqlmock.Sqlmock -} - -func (e *sqlmockExpecterWrapper) AssertExpectations(t mock.TestingT) bool { - return assert.NoError(t, e.ExpectationsWereMet(), "not all database expectations were met") -} - -func buildTestClient(t *testing.T) (*Client, *sqlmockExpecterWrapper) { +func buildTestClient(t *testing.T) (*Client, sqlmock.Sqlmock) { t.Helper() fakeDB, sqlMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) @@ -88,7 +79,7 @@ func buildTestClient(t *testing.T) (*Client, *sqlmockExpecterWrapper) { tracer: tracing.NewTracerForTest("test"), } - return c, &sqlmockExpecterWrapper{Sqlmock: sqlMock} + return c, sqlMock } // end helper funcs diff --git a/database/sqlite/sqlite_test.go b/database/sqlite/sqlite_test.go index af74379..dac4e86 100644 --- a/database/sqlite/sqlite_test.go +++ b/database/sqlite/sqlite_test.go @@ -13,7 +13,6 @@ import ( "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -62,15 +61,7 @@ func (c *testClientConfig) GetConnMaxLifetime() time.Duration { return 30 * time.Minute } -type sqlmockExpecterWrapper struct { - sqlmock.Sqlmock -} - -func (e *sqlmockExpecterWrapper) AssertExpectations(t mock.TestingT) bool { - return assert.NoError(t, e.ExpectationsWereMet(), "not all database expectations were met") -} - -func buildTestClient(t *testing.T) (*Client, *sqlmockExpecterWrapper) { +func buildTestClient(t *testing.T) (*Client, sqlmock.Sqlmock) { t.Helper() fakeDB, sqlMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) @@ -88,7 +79,7 @@ func buildTestClient(t *testing.T) (*Client, *sqlmockExpecterWrapper) { tracer: tracing.NewTracerForTest("test"), } - return c, &sqlmockExpecterWrapper{Sqlmock: sqlMock} + return c, sqlMock } // end helper funcs diff --git a/distributedlock/config/config_test.go b/distributedlock/config/config_test.go index 394b36d..856df70 100644 --- a/distributedlock/config/config_test.go +++ b/distributedlock/config/config_test.go @@ -17,6 +17,7 @@ import ( mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" @@ -232,9 +233,12 @@ func TestProvideLocker(T *testing.T) { }, } - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", "dlock-breaker_circuit_breaker_tripped", []metric.Int64CounterOption(nil)). - Return(&mockmetrics.Int64Counter{}, fmt.Errorf("counter init failure")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, "dlock-breaker_circuit_breaker_tripped", counterName) + return &mockmetrics.Int64CounterMock{}, fmt.Errorf("counter init failure") + }, + } l, err := ProvideLocker( t.Context(), @@ -247,6 +251,7 @@ func TestProvideLocker(T *testing.T) { require.Error(t, err) assert.Nil(t, l) assert.Contains(t, err.Error(), "distributedlock circuit breaker") - mp.AssertExpectations(t) + + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) } diff --git a/distributedlock/mock/doc.go b/distributedlock/mock/doc.go new file mode 100644 index 0000000..ff266fc --- /dev/null +++ b/distributedlock/mock/doc.go @@ -0,0 +1,9 @@ +// Package mock provides mock implementations of the distributedlock package's +// interfaces. Both the hand-written testify-based Locker/Lock types and the +// moq-generated LockerMock/LockMock types live here during the testify → moq +// migration. New test code should prefer the moq-generated types. +package mock + +// Regenerate the moq mocks via `go generate ./distributedlock/mock/`. + +//go:generate go tool github.com/matryer/moq -out locker_mock.go -pkg mock -rm -fmt goimports .. Locker:LockerMock Lock:LockMock diff --git a/distributedlock/mock/locker.go b/distributedlock/mock/locker.go deleted file mode 100644 index 865810a..0000000 --- a/distributedlock/mock/locker.go +++ /dev/null @@ -1,68 +0,0 @@ -package mock - -import ( - "context" - "time" - - "github.com/verygoodsoftwarenotvirus/platform/v5/distributedlock" - - "github.com/stretchr/testify/mock" -) - -var ( - _ distributedlock.Locker = (*Locker)(nil) - _ distributedlock.Lock = (*Lock)(nil) -) - -// Locker is a testify-backed mock of distributedlock.Locker. -type Locker struct { - mock.Mock -} - -// Acquire implements distributedlock.Locker. -func (m *Locker) Acquire(ctx context.Context, key string, ttl time.Duration) (distributedlock.Lock, error) { - args := m.Called(ctx, key, ttl) - if v := args.Get(0); v != nil { - return v.(distributedlock.Lock), args.Error(1) - } - return nil, args.Error(1) -} - -// Ping implements distributedlock.Locker. -func (m *Locker) Ping(ctx context.Context) error { - return m.Called(ctx).Error(0) -} - -// Close implements distributedlock.Locker. -func (m *Locker) Close() error { - return m.Called().Error(0) -} - -// Lock is a testify-backed mock of distributedlock.Lock. -type Lock struct { - mock.Mock -} - -// Key implements distributedlock.Lock. -func (m *Lock) Key() string { - return m.Called().String(0) -} - -// TTL implements distributedlock.Lock. -func (m *Lock) TTL() time.Duration { - args := m.Called() - if v, ok := args.Get(0).(time.Duration); ok { - return v - } - return 0 -} - -// Release implements distributedlock.Lock. -func (m *Lock) Release(ctx context.Context) error { - return m.Called(ctx).Error(0) -} - -// Refresh implements distributedlock.Lock. -func (m *Lock) Refresh(ctx context.Context, ttl time.Duration) error { - return m.Called(ctx, ttl).Error(0) -} diff --git a/distributedlock/mock/locker_mock.go b/distributedlock/mock/locker_mock.go new file mode 100644 index 0000000..713f11a --- /dev/null +++ b/distributedlock/mock/locker_mock.go @@ -0,0 +1,361 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package mock + +import ( + "context" + "sync" + "time" + + "github.com/verygoodsoftwarenotvirus/platform/v5/distributedlock" +) + +// Ensure, that LockerMock does implement distributedlock.Locker. +// If this is not the case, regenerate this file with moq. +var _ distributedlock.Locker = &LockerMock{} + +// LockerMock is a mock implementation of distributedlock.Locker. +// +// func TestSomethingThatUsesLocker(t *testing.T) { +// +// // make and configure a mocked distributedlock.Locker +// mockedLocker := &LockerMock{ +// AcquireFunc: func(ctx context.Context, key string, ttl time.Duration) (distributedlock.Lock, error) { +// panic("mock out the Acquire method") +// }, +// CloseFunc: func() error { +// panic("mock out the Close method") +// }, +// PingFunc: func(ctx context.Context) error { +// panic("mock out the Ping method") +// }, +// } +// +// // use mockedLocker in code that requires distributedlock.Locker +// // and then make assertions. +// +// } +type LockerMock struct { + // AcquireFunc mocks the Acquire method. + AcquireFunc func(ctx context.Context, key string, ttl time.Duration) (distributedlock.Lock, error) + + // CloseFunc mocks the Close method. + CloseFunc func() error + + // PingFunc mocks the Ping method. + PingFunc func(ctx context.Context) error + + // calls tracks calls to the methods. + calls struct { + // Acquire holds details about calls to the Acquire method. + Acquire []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Key is the key argument value. + Key string + // TTL is the ttl argument value. + TTL time.Duration + } + // Close holds details about calls to the Close method. + Close []struct { + } + // Ping holds details about calls to the Ping method. + Ping []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + } + lockAcquire sync.RWMutex + lockClose sync.RWMutex + lockPing sync.RWMutex +} + +// Acquire calls AcquireFunc. +func (mock *LockerMock) Acquire(ctx context.Context, key string, ttl time.Duration) (distributedlock.Lock, error) { + if mock.AcquireFunc == nil { + panic("LockerMock.AcquireFunc: method is nil but Locker.Acquire was just called") + } + callInfo := struct { + Ctx context.Context + Key string + TTL time.Duration + }{ + Ctx: ctx, + Key: key, + TTL: ttl, + } + mock.lockAcquire.Lock() + mock.calls.Acquire = append(mock.calls.Acquire, callInfo) + mock.lockAcquire.Unlock() + return mock.AcquireFunc(ctx, key, ttl) +} + +// AcquireCalls gets all the calls that were made to Acquire. +// Check the length with: +// +// len(mockedLocker.AcquireCalls()) +func (mock *LockerMock) AcquireCalls() []struct { + Ctx context.Context + Key string + TTL time.Duration +} { + var calls []struct { + Ctx context.Context + Key string + TTL time.Duration + } + mock.lockAcquire.RLock() + calls = mock.calls.Acquire + mock.lockAcquire.RUnlock() + return calls +} + +// Close calls CloseFunc. +func (mock *LockerMock) Close() error { + if mock.CloseFunc == nil { + panic("LockerMock.CloseFunc: method is nil but Locker.Close was just called") + } + callInfo := struct { + }{} + mock.lockClose.Lock() + mock.calls.Close = append(mock.calls.Close, callInfo) + mock.lockClose.Unlock() + return mock.CloseFunc() +} + +// CloseCalls gets all the calls that were made to Close. +// Check the length with: +// +// len(mockedLocker.CloseCalls()) +func (mock *LockerMock) CloseCalls() []struct { +} { + var calls []struct { + } + mock.lockClose.RLock() + calls = mock.calls.Close + mock.lockClose.RUnlock() + return calls +} + +// Ping calls PingFunc. +func (mock *LockerMock) Ping(ctx context.Context) error { + if mock.PingFunc == nil { + panic("LockerMock.PingFunc: method is nil but Locker.Ping was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockPing.Lock() + mock.calls.Ping = append(mock.calls.Ping, callInfo) + mock.lockPing.Unlock() + return mock.PingFunc(ctx) +} + +// PingCalls gets all the calls that were made to Ping. +// Check the length with: +// +// len(mockedLocker.PingCalls()) +func (mock *LockerMock) PingCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockPing.RLock() + calls = mock.calls.Ping + mock.lockPing.RUnlock() + return calls +} + +// Ensure, that LockMock does implement distributedlock.Lock. +// If this is not the case, regenerate this file with moq. +var _ distributedlock.Lock = &LockMock{} + +// LockMock is a mock implementation of distributedlock.Lock. +// +// func TestSomethingThatUsesLock(t *testing.T) { +// +// // make and configure a mocked distributedlock.Lock +// mockedLock := &LockMock{ +// KeyFunc: func() string { +// panic("mock out the Key method") +// }, +// RefreshFunc: func(ctx context.Context, ttl time.Duration) error { +// panic("mock out the Refresh method") +// }, +// ReleaseFunc: func(ctx context.Context) error { +// panic("mock out the Release method") +// }, +// TTLFunc: func() time.Duration { +// panic("mock out the TTL method") +// }, +// } +// +// // use mockedLock in code that requires distributedlock.Lock +// // and then make assertions. +// +// } +type LockMock struct { + // KeyFunc mocks the Key method. + KeyFunc func() string + + // RefreshFunc mocks the Refresh method. + RefreshFunc func(ctx context.Context, ttl time.Duration) error + + // ReleaseFunc mocks the Release method. + ReleaseFunc func(ctx context.Context) error + + // TTLFunc mocks the TTL method. + TTLFunc func() time.Duration + + // calls tracks calls to the methods. + calls struct { + // Key holds details about calls to the Key method. + Key []struct { + } + // Refresh holds details about calls to the Refresh method. + Refresh []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // TTL is the ttl argument value. + TTL time.Duration + } + // Release holds details about calls to the Release method. + Release []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // TTL holds details about calls to the TTL method. + TTL []struct { + } + } + lockKey sync.RWMutex + lockRefresh sync.RWMutex + lockRelease sync.RWMutex + lockTTL sync.RWMutex +} + +// Key calls KeyFunc. +func (mock *LockMock) Key() string { + if mock.KeyFunc == nil { + panic("LockMock.KeyFunc: method is nil but Lock.Key was just called") + } + callInfo := struct { + }{} + mock.lockKey.Lock() + mock.calls.Key = append(mock.calls.Key, callInfo) + mock.lockKey.Unlock() + return mock.KeyFunc() +} + +// KeyCalls gets all the calls that were made to Key. +// Check the length with: +// +// len(mockedLock.KeyCalls()) +func (mock *LockMock) KeyCalls() []struct { +} { + var calls []struct { + } + mock.lockKey.RLock() + calls = mock.calls.Key + mock.lockKey.RUnlock() + return calls +} + +// Refresh calls RefreshFunc. +func (mock *LockMock) Refresh(ctx context.Context, ttl time.Duration) error { + if mock.RefreshFunc == nil { + panic("LockMock.RefreshFunc: method is nil but Lock.Refresh was just called") + } + callInfo := struct { + Ctx context.Context + TTL time.Duration + }{ + Ctx: ctx, + TTL: ttl, + } + mock.lockRefresh.Lock() + mock.calls.Refresh = append(mock.calls.Refresh, callInfo) + mock.lockRefresh.Unlock() + return mock.RefreshFunc(ctx, ttl) +} + +// RefreshCalls gets all the calls that were made to Refresh. +// Check the length with: +// +// len(mockedLock.RefreshCalls()) +func (mock *LockMock) RefreshCalls() []struct { + Ctx context.Context + TTL time.Duration +} { + var calls []struct { + Ctx context.Context + TTL time.Duration + } + mock.lockRefresh.RLock() + calls = mock.calls.Refresh + mock.lockRefresh.RUnlock() + return calls +} + +// Release calls ReleaseFunc. +func (mock *LockMock) Release(ctx context.Context) error { + if mock.ReleaseFunc == nil { + panic("LockMock.ReleaseFunc: method is nil but Lock.Release was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockRelease.Lock() + mock.calls.Release = append(mock.calls.Release, callInfo) + mock.lockRelease.Unlock() + return mock.ReleaseFunc(ctx) +} + +// ReleaseCalls gets all the calls that were made to Release. +// Check the length with: +// +// len(mockedLock.ReleaseCalls()) +func (mock *LockMock) ReleaseCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockRelease.RLock() + calls = mock.calls.Release + mock.lockRelease.RUnlock() + return calls +} + +// TTL calls TTLFunc. +func (mock *LockMock) TTL() time.Duration { + if mock.TTLFunc == nil { + panic("LockMock.TTLFunc: method is nil but Lock.TTL was just called") + } + callInfo := struct { + }{} + mock.lockTTL.Lock() + mock.calls.TTL = append(mock.calls.TTL, callInfo) + mock.lockTTL.Unlock() + return mock.TTLFunc() +} + +// TTLCalls gets all the calls that were made to TTL. +// Check the length with: +// +// len(mockedLock.TTLCalls()) +func (mock *LockMock) TTLCalls() []struct { +} { + var calls []struct { + } + mock.lockTTL.RLock() + calls = mock.calls.TTL + mock.lockTTL.RUnlock() + return calls +} diff --git a/distributedlock/postgres/postgres_test.go b/distributedlock/postgres/postgres_test.go index d899b13..31863d9 100644 --- a/distributedlock/postgres/postgres_test.go +++ b/distributedlock/postgres/postgres_test.go @@ -236,12 +236,13 @@ func TestLocker_Acquire_Unit(T *testing.T) { t.Parallel() client, _ := buildSqlmockClient(t) t.Cleanup(func() { _ = client.Close() }) - cb := &cbmock.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true) + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return true }, + } l := newTestLockerWithCB(t, client, cb) _, err := l.Acquire(t.Context(), "k", time.Minute) require.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - cb.AssertExpectations(t) + require.NotEmpty(t, cb.CannotProceedCalls()) }) T.Run("Conn reservation failure", func(t *testing.T) { @@ -313,10 +314,14 @@ func TestLocker_Release_Unit(T *testing.T) { t.Parallel() client, mock := buildSqlmockClient(t) t.Cleanup(func() { _ = client.Close() }) - cb := &cbmock.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false).Once() - cb.On("Succeeded").Return().Once() - cb.On("CannotProceed").Return(true).Once() + var cannotProceedCalls int + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { + cannotProceedCalls++ + return cannotProceedCalls > 1 // first call (Acquire) proceeds, second (Release) is blocked + }, + SucceededFunc: func() {}, + } l := newTestLockerWithCB(t, client, cb) mock.ExpectQuery(`SELECT pg_try_advisory_lock`). @@ -326,7 +331,8 @@ func TestLocker_Release_Unit(T *testing.T) { h, err := l.Acquire(t.Context(), "k", time.Minute) require.NoError(t, err) require.ErrorIs(t, h.Release(t.Context()), circuitbreaking.ErrCircuitBroken) - cb.AssertExpectations(t) + require.Len(t, cb.CannotProceedCalls(), 2) + require.Len(t, cb.SucceededCalls(), 1) }) T.Run("double release returns ErrLockNotHeld", func(t *testing.T) { @@ -375,10 +381,11 @@ func TestLocker_Release_Unit(T *testing.T) { t.Parallel() client, mock := buildSqlmockClient(t) t.Cleanup(func() { _ = client.Close() }) - cb := &cbmock.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return().Once() - cb.On("Failed").Return().Once() + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + SucceededFunc: func() {}, + FailedFunc: func() {}, + } l := newTestLockerWithCB(t, client, cb) mock.ExpectQuery(`SELECT pg_try_advisory_lock`). @@ -391,7 +398,8 @@ func TestLocker_Release_Unit(T *testing.T) { h, err := l.Acquire(t.Context(), "k", time.Minute) require.NoError(t, err) require.Error(t, h.Release(t.Context())) - cb.AssertExpectations(t) + require.Len(t, cb.SucceededCalls(), 1) + require.Len(t, cb.FailedCalls(), 1) }) } @@ -436,10 +444,14 @@ func TestLocker_Refresh_Unit(T *testing.T) { t.Parallel() client, mock := buildSqlmockClient(t) t.Cleanup(func() { _ = client.Close() }) - cb := &cbmock.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false).Once() - cb.On("Succeeded").Return().Once() - cb.On("CannotProceed").Return(true).Once() + var cannotProceedCalls int + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { + cannotProceedCalls++ + return cannotProceedCalls > 1 // first call (Acquire) proceeds, second (Refresh) is blocked + }, + SucceededFunc: func() {}, + } l := newTestLockerWithCB(t, client, cb) mock.ExpectQuery(`SELECT pg_try_advisory_lock`). @@ -449,7 +461,8 @@ func TestLocker_Refresh_Unit(T *testing.T) { h, err := l.Acquire(t.Context(), "k", time.Minute) require.NoError(t, err) require.ErrorIs(t, h.Refresh(t.Context(), time.Minute), circuitbreaking.ErrCircuitBroken) - cb.AssertExpectations(t) + require.Len(t, cb.CannotProceedCalls(), 2) + require.Len(t, cb.SucceededCalls(), 1) }) T.Run("refresh after release returns ErrLockNotHeld", func(t *testing.T) { diff --git a/distributedlock/redis/redis_test.go b/distributedlock/redis/redis_test.go index 5b96adb..21cc2b3 100644 --- a/distributedlock/redis/redis_test.go +++ b/distributedlock/redis/redis_test.go @@ -289,39 +289,44 @@ func TestLocker_Acquire(T *testing.T) { T.Run("blocked by circuit breaker", func(t *testing.T) { t.Parallel() - cb := &cbmock.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true) + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return true }, + } l := newUnitLocker(t, &fakeRedisClient{}, cb) _, err := l.Acquire(t.Context(), "k", time.Minute) require.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - cb.AssertExpectations(t) + require.NotEmpty(t, cb.CannotProceedCalls()) }) T.Run("SetNX backend error trips breaker", func(t *testing.T) { t.Parallel() fc := &fakeRedisClient{setNXErr: errors.New("redis down")} - cb := &cbmock.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + FailedFunc: func() {}, + } l := newUnitLocker(t, fc, cb) _, err := l.Acquire(t.Context(), "k", time.Minute) require.Error(t, err) - cb.AssertExpectations(t) + require.NotEmpty(t, cb.CannotProceedCalls()) + require.NotEmpty(t, cb.FailedCalls()) }) T.Run("contention does not fail breaker", func(t *testing.T) { t.Parallel() fc := &fakeRedisClient{setNXResult: false} - cb := &cbmock.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + SucceededFunc: func() {}, + } l := newUnitLocker(t, fc, cb) _, err := l.Acquire(t.Context(), "k", time.Minute) require.ErrorIs(t, err, distributedlock.ErrLockNotAcquired) - cb.AssertExpectations(t) + require.NotEmpty(t, cb.CannotProceedCalls()) + require.NotEmpty(t, cb.SucceededCalls()) }) } @@ -352,13 +357,12 @@ func TestLocker_Release(T *testing.T) { T.Run("eval backend error trips breaker", func(t *testing.T) { t.Parallel() fc := &fakeRedisClient{setNXResult: true} - cb := &cbmock.MockCircuitBreaker{} - // Acquire path - cb.On("CannotProceed").Return(false).Once() - cb.On("Succeeded").Return().Once() - // Release path: proceed, then evalErr triggers Failed. - cb.On("CannotProceed").Return(false).Once() - cb.On("Failed").Return().Once() + // Acquire path: proceed + succeeded. Release path: proceed + failed. + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + SucceededFunc: func() {}, + FailedFunc: func() {}, + } l := newUnitLocker(t, fc, cb) h, err := l.Acquire(t.Context(), "k", time.Minute) @@ -366,22 +370,29 @@ func TestLocker_Release(T *testing.T) { fc.evalErr = errors.New("eval boom") require.Error(t, h.Release(t.Context())) - cb.AssertExpectations(t) + require.Len(t, cb.CannotProceedCalls(), 2) + require.Len(t, cb.SucceededCalls(), 1) + require.Len(t, cb.FailedCalls(), 1) }) T.Run("blocked by circuit breaker", func(t *testing.T) { t.Parallel() fc := &fakeRedisClient{setNXResult: true} - cb := &cbmock.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false).Once() - cb.On("Succeeded").Return().Once() - cb.On("CannotProceed").Return(true).Once() + var cannotProceedCalls int + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { + cannotProceedCalls++ + return cannotProceedCalls > 1 // first call (Acquire) proceeds, second (Release) is blocked + }, + SucceededFunc: func() {}, + } l := newUnitLocker(t, fc, cb) h, err := l.Acquire(t.Context(), "k", time.Minute) require.NoError(t, err) require.ErrorIs(t, h.Release(t.Context()), circuitbreaking.ErrCircuitBroken) - cb.AssertExpectations(t) + require.Len(t, cb.CannotProceedCalls(), 2) + require.Len(t, cb.SucceededCalls(), 1) }) } @@ -427,11 +438,11 @@ func TestLocker_Refresh(T *testing.T) { T.Run("eval backend error trips breaker", func(t *testing.T) { t.Parallel() fc := &fakeRedisClient{setNXResult: true} - cb := &cbmock.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false).Once() - cb.On("Succeeded").Return().Once() - cb.On("CannotProceed").Return(false).Once() - cb.On("Failed").Return().Once() + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + SucceededFunc: func() {}, + FailedFunc: func() {}, + } l := newUnitLocker(t, fc, cb) h, err := l.Acquire(t.Context(), "k", time.Minute) @@ -439,22 +450,29 @@ func TestLocker_Refresh(T *testing.T) { fc.evalErr = errors.New("eval boom") require.Error(t, h.Refresh(t.Context(), 5*time.Minute)) - cb.AssertExpectations(t) + require.Len(t, cb.CannotProceedCalls(), 2) + require.Len(t, cb.SucceededCalls(), 1) + require.Len(t, cb.FailedCalls(), 1) }) T.Run("blocked by circuit breaker", func(t *testing.T) { t.Parallel() fc := &fakeRedisClient{setNXResult: true} - cb := &cbmock.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false).Once() - cb.On("Succeeded").Return().Once() - cb.On("CannotProceed").Return(true).Once() + var cannotProceedCalls int + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { + cannotProceedCalls++ + return cannotProceedCalls > 1 // first call (Acquire) proceeds, second (Refresh) is blocked + }, + SucceededFunc: func() {}, + } l := newUnitLocker(t, fc, cb) h, err := l.Acquire(t.Context(), "k", time.Minute) require.NoError(t, err) require.ErrorIs(t, h.Refresh(t.Context(), time.Minute), circuitbreaking.ErrCircuitBroken) - cb.AssertExpectations(t) + require.Len(t, cb.CannotProceedCalls(), 2) + require.Len(t, cb.SucceededCalls(), 1) }) } diff --git a/email/config/config_test.go b/email/config/config_test.go index 9fb1147..cd20b2f 100644 --- a/email/config/config_test.go +++ b/email/config/config_test.go @@ -18,6 +18,7 @@ import ( mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" @@ -235,9 +236,12 @@ func TestProvideEmailer(T *testing.T) { cfg.CircuitBreaker.ErrorRate = 50 cfg.CircuitBreaker.MinimumSampleThreshold = 10 - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", "email-breaker_circuit_breaker_tripped", []metric.Int64CounterOption(nil)). - Return(&mockmetrics.Int64Counter{}, fmt.Errorf("counter init failure")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, "email-breaker_circuit_breaker_tripped", counterName) + return &mockmetrics.Int64CounterMock{}, fmt.Errorf("counter init failure") + }, + } emailer, err := ProvideEmailer( t.Context(), @@ -249,6 +253,7 @@ func TestProvideEmailer(T *testing.T) { ) require.Error(t, err) assert.Nil(t, emailer) - mp.AssertExpectations(t) + + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) } diff --git a/email/mock/doc.go b/email/mock/doc.go new file mode 100644 index 0000000..6e9fda1 --- /dev/null +++ b/email/mock/doc.go @@ -0,0 +1,9 @@ +// Package emailmock provides mock implementations of the email package's +// interfaces. Both the hand-written testify-based Emailer type and the +// moq-generated EmailerMock type live here during the testify → moq +// migration. New test code should prefer the moq-generated types. +package emailmock + +// Regenerate the moq mocks via `go generate ./email/mock/`. + +//go:generate go tool github.com/matryer/moq -out emailer_mock.go -pkg emailmock -rm -fmt goimports .. Emailer:EmailerMock diff --git a/email/mock/emailer_mock.go b/email/mock/emailer_mock.go new file mode 100644 index 0000000..14786ad --- /dev/null +++ b/email/mock/emailer_mock.go @@ -0,0 +1,83 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package emailmock + +import ( + "context" + "sync" + + "github.com/verygoodsoftwarenotvirus/platform/v5/email" +) + +// Ensure, that EmailerMock does implement email.Emailer. +// If this is not the case, regenerate this file with moq. +var _ email.Emailer = &EmailerMock{} + +// EmailerMock is a mock implementation of email.Emailer. +// +// func TestSomethingThatUsesEmailer(t *testing.T) { +// +// // make and configure a mocked email.Emailer +// mockedEmailer := &EmailerMock{ +// SendEmailFunc: func(ctx context.Context, details *email.OutboundEmailMessage) error { +// panic("mock out the SendEmail method") +// }, +// } +// +// // use mockedEmailer in code that requires email.Emailer +// // and then make assertions. +// +// } +type EmailerMock struct { + // SendEmailFunc mocks the SendEmail method. + SendEmailFunc func(ctx context.Context, details *email.OutboundEmailMessage) error + + // calls tracks calls to the methods. + calls struct { + // SendEmail holds details about calls to the SendEmail method. + SendEmail []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Details is the details argument value. + Details *email.OutboundEmailMessage + } + } + lockSendEmail sync.RWMutex +} + +// SendEmail calls SendEmailFunc. +func (mock *EmailerMock) SendEmail(ctx context.Context, details *email.OutboundEmailMessage) error { + if mock.SendEmailFunc == nil { + panic("EmailerMock.SendEmailFunc: method is nil but Emailer.SendEmail was just called") + } + callInfo := struct { + Ctx context.Context + Details *email.OutboundEmailMessage + }{ + Ctx: ctx, + Details: details, + } + mock.lockSendEmail.Lock() + mock.calls.SendEmail = append(mock.calls.SendEmail, callInfo) + mock.lockSendEmail.Unlock() + return mock.SendEmailFunc(ctx, details) +} + +// SendEmailCalls gets all the calls that were made to SendEmail. +// Check the length with: +// +// len(mockedEmailer.SendEmailCalls()) +func (mock *EmailerMock) SendEmailCalls() []struct { + Ctx context.Context + Details *email.OutboundEmailMessage +} { + var calls []struct { + Ctx context.Context + Details *email.OutboundEmailMessage + } + mock.lockSendEmail.RLock() + calls = mock.calls.SendEmail + mock.lockSendEmail.RUnlock() + return calls +} diff --git a/email/mock/mock_emailer.go b/email/mock/mock_emailer.go deleted file mode 100644 index b151881..0000000 --- a/email/mock/mock_emailer.go +++ /dev/null @@ -1,23 +0,0 @@ -package emailmock - -import ( - "context" - - "github.com/verygoodsoftwarenotvirus/platform/v5/email" - - "github.com/stretchr/testify/mock" -) - -var _ email.Emailer = (*Emailer)(nil) - -type ( - // Emailer represents a service that can send emails. - Emailer struct { - mock.Mock - } -) - -// SendEmail is a mock function. -func (m *Emailer) SendEmail(ctx context.Context, details *email.OutboundEmailMessage) error { - return m.Called(ctx, details).Error(0) -} diff --git a/embeddings/mock/doc.go b/embeddings/mock/doc.go new file mode 100644 index 0000000..ed3ece6 --- /dev/null +++ b/embeddings/mock/doc.go @@ -0,0 +1,9 @@ +// Package mock provides mock implementations of the embeddings package's +// interfaces. Both the hand-written testify-based Embedder and the moq-generated +// EmbedderMock live here during the testify → moq migration. New test code +// should prefer EmbedderMock. +package mock + +// Regenerate the moq mocks via `go generate ./embeddings/mock/`. + +//go:generate go tool github.com/matryer/moq -out embedder_mock.go -pkg mock -rm -fmt goimports .. Embedder:EmbedderMock diff --git a/embeddings/mock/embedder_mock.go b/embeddings/mock/embedder_mock.go new file mode 100644 index 0000000..9e13e89 --- /dev/null +++ b/embeddings/mock/embedder_mock.go @@ -0,0 +1,83 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package mock + +import ( + "context" + "sync" + + "github.com/verygoodsoftwarenotvirus/platform/v5/embeddings" +) + +// Ensure, that EmbedderMock does implement embeddings.Embedder. +// If this is not the case, regenerate this file with moq. +var _ embeddings.Embedder = &EmbedderMock{} + +// EmbedderMock is a mock implementation of embeddings.Embedder. +// +// func TestSomethingThatUsesEmbedder(t *testing.T) { +// +// // make and configure a mocked embeddings.Embedder +// mockedEmbedder := &EmbedderMock{ +// GenerateEmbeddingFunc: func(ctx context.Context, input *embeddings.Input) (*embeddings.Embedding, error) { +// panic("mock out the GenerateEmbedding method") +// }, +// } +// +// // use mockedEmbedder in code that requires embeddings.Embedder +// // and then make assertions. +// +// } +type EmbedderMock struct { + // GenerateEmbeddingFunc mocks the GenerateEmbedding method. + GenerateEmbeddingFunc func(ctx context.Context, input *embeddings.Input) (*embeddings.Embedding, error) + + // calls tracks calls to the methods. + calls struct { + // GenerateEmbedding holds details about calls to the GenerateEmbedding method. + GenerateEmbedding []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Input is the input argument value. + Input *embeddings.Input + } + } + lockGenerateEmbedding sync.RWMutex +} + +// GenerateEmbedding calls GenerateEmbeddingFunc. +func (mock *EmbedderMock) GenerateEmbedding(ctx context.Context, input *embeddings.Input) (*embeddings.Embedding, error) { + if mock.GenerateEmbeddingFunc == nil { + panic("EmbedderMock.GenerateEmbeddingFunc: method is nil but Embedder.GenerateEmbedding was just called") + } + callInfo := struct { + Ctx context.Context + Input *embeddings.Input + }{ + Ctx: ctx, + Input: input, + } + mock.lockGenerateEmbedding.Lock() + mock.calls.GenerateEmbedding = append(mock.calls.GenerateEmbedding, callInfo) + mock.lockGenerateEmbedding.Unlock() + return mock.GenerateEmbeddingFunc(ctx, input) +} + +// GenerateEmbeddingCalls gets all the calls that were made to GenerateEmbedding. +// Check the length with: +// +// len(mockedEmbedder.GenerateEmbeddingCalls()) +func (mock *EmbedderMock) GenerateEmbeddingCalls() []struct { + Ctx context.Context + Input *embeddings.Input +} { + var calls []struct { + Ctx context.Context + Input *embeddings.Input + } + mock.lockGenerateEmbedding.RLock() + calls = mock.calls.GenerateEmbedding + mock.lockGenerateEmbedding.RUnlock() + return calls +} diff --git a/embeddings/mock/mock.go b/embeddings/mock/mock.go deleted file mode 100644 index b838f3d..0000000 --- a/embeddings/mock/mock.go +++ /dev/null @@ -1,25 +0,0 @@ -package mock - -import ( - "context" - - "github.com/verygoodsoftwarenotvirus/platform/v5/embeddings" - - "github.com/stretchr/testify/mock" -) - -var _ embeddings.Embedder = (*Embedder)(nil) - -// Embedder is a mock embeddings.Embedder for use in tests. -type Embedder struct { - mock.Mock -} - -// GenerateEmbedding satisfies the embeddings.Embedder interface. -func (m *Embedder) GenerateEmbedding(ctx context.Context, input *embeddings.Input) (*embeddings.Embedding, error) { - args := m.Called(ctx, input) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).(*embeddings.Embedding), args.Error(1) -} diff --git a/embeddings/mock/mock_test.go b/embeddings/mock/mock_test.go deleted file mode 100644 index 61e247e..0000000 --- a/embeddings/mock/mock_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package mock - -import ( - "fmt" - "testing" - - "github.com/verygoodsoftwarenotvirus/platform/v5/embeddings" - - "github.com/stretchr/testify/require" -) - -func TestEmbedder_GenerateEmbedding(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := &Embedder{} - input := &embeddings.Input{Content: "hello", Model: "test"} - expected := &embeddings.Embedding{ - SourceText: "hello", - Model: "test", - Provider: "mock", - } - - m.On("GenerateEmbedding", t.Context(), input).Return(expected, nil) - - ctx := t.Context() - result, err := m.GenerateEmbedding(ctx, input) - - require.NoError(t, err) - require.Equal(t, expected, result) - m.AssertExpectations(t) - }) - - T.Run("with nil result", func(t *testing.T) { - t.Parallel() - - m := &Embedder{} - input := &embeddings.Input{Content: "hello", Model: "test"} - - m.On("GenerateEmbedding", t.Context(), input).Return(nil, fmt.Errorf("embedding failed")) - - ctx := t.Context() - result, err := m.GenerateEmbedding(ctx, input) - - require.Error(t, err) - require.Nil(t, result) - m.AssertExpectations(t) - }) -} diff --git a/encoding/client_encoder_test.go b/encoding/client_encoder_test.go index 22baad5..797f2cd 100644 --- a/encoding/client_encoder_test.go +++ b/encoding/client_encoder_test.go @@ -10,11 +10,9 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/verygoodsoftwarenotvirus/platform/v5/reflection" "github.com/keith-turner/ecoji/v2" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -108,8 +106,11 @@ func Test_clientEncoder_Encode(T *testing.T) { ctx := t.Context() e := ProvideClientEncoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), ct) - mw := &mockWriter{} - mw.On(reflection.GetMethodName(mw.Write), mock.Anything).Return(0, errors.New("blah")) + mw := &mockWriter{ + WriteFunc: func(_ []byte) (int, error) { + return 0, errors.New("blah") + }, + } assert.Error(t, e.Encode(ctx, mw, &example{Name: t.Name()})) }) diff --git a/encoding/mock/doc.go b/encoding/mock/doc.go index 3dbce97..200bfd0 100644 --- a/encoding/mock/doc.go +++ b/encoding/mock/doc.go @@ -1,5 +1,8 @@ /* -Package mockencoding provides mockable implementations of every interface -defined in the outer encoding package. +Package mockencoding provides moq-generated mocks for the encoding package. */ package mockencoding + +// Regenerate the moq mocks via `go generate ./encoding/mock/`. + +//go:generate go tool github.com/matryer/moq -out encoder_decoder_mock.go -pkg mockencoding -rm -fmt goimports .. ServerEncoderDecoder:ServerEncoderDecoderMock ClientEncoder:ClientEncoderMock diff --git a/encoding/mock/encoder_decoder_mock.go b/encoding/mock/encoder_decoder_mock.go new file mode 100644 index 0000000..692cda3 --- /dev/null +++ b/encoding/mock/encoder_decoder_mock.go @@ -0,0 +1,530 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package mockencoding + +import ( + "context" + "io" + "net/http" + "sync" + + "github.com/verygoodsoftwarenotvirus/platform/v5/encoding" +) + +// Ensure, that ServerEncoderDecoderMock does implement encoding.ServerEncoderDecoder. +// If this is not the case, regenerate this file with moq. +var _ encoding.ServerEncoderDecoder = &ServerEncoderDecoderMock{} + +// ServerEncoderDecoderMock is a mock implementation of encoding.ServerEncoderDecoder. +// +// func TestSomethingThatUsesServerEncoderDecoder(t *testing.T) { +// +// // make and configure a mocked encoding.ServerEncoderDecoder +// mockedServerEncoderDecoder := &ServerEncoderDecoderMock{ +// DecodeBytesFunc: func(ctx context.Context, payload []byte, dest any) error { +// panic("mock out the DecodeBytes method") +// }, +// DecodeRequestFunc: func(ctx context.Context, req *http.Request, dest any) error { +// panic("mock out the DecodeRequest method") +// }, +// EncodeResponseWithStatusFunc: func(ctx context.Context, res http.ResponseWriter, val any, statusCode int) { +// panic("mock out the EncodeResponseWithStatus method") +// }, +// MustEncodeFunc: func(ctx context.Context, v any) []byte { +// panic("mock out the MustEncode method") +// }, +// MustEncodeJSONFunc: func(ctx context.Context, v any) []byte { +// panic("mock out the MustEncodeJSON method") +// }, +// } +// +// // use mockedServerEncoderDecoder in code that requires encoding.ServerEncoderDecoder +// // and then make assertions. +// +// } +type ServerEncoderDecoderMock struct { + // DecodeBytesFunc mocks the DecodeBytes method. + DecodeBytesFunc func(ctx context.Context, payload []byte, dest any) error + + // DecodeRequestFunc mocks the DecodeRequest method. + DecodeRequestFunc func(ctx context.Context, req *http.Request, dest any) error + + // EncodeResponseWithStatusFunc mocks the EncodeResponseWithStatus method. + EncodeResponseWithStatusFunc func(ctx context.Context, res http.ResponseWriter, val any, statusCode int) + + // MustEncodeFunc mocks the MustEncode method. + MustEncodeFunc func(ctx context.Context, v any) []byte + + // MustEncodeJSONFunc mocks the MustEncodeJSON method. + MustEncodeJSONFunc func(ctx context.Context, v any) []byte + + // calls tracks calls to the methods. + calls struct { + // DecodeBytes holds details about calls to the DecodeBytes method. + DecodeBytes []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Payload is the payload argument value. + Payload []byte + // Dest is the dest argument value. + Dest any + } + // DecodeRequest holds details about calls to the DecodeRequest method. + DecodeRequest []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Req is the req argument value. + Req *http.Request + // Dest is the dest argument value. + Dest any + } + // EncodeResponseWithStatus holds details about calls to the EncodeResponseWithStatus method. + EncodeResponseWithStatus []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Res is the res argument value. + Res http.ResponseWriter + // Val is the val argument value. + Val any + // StatusCode is the statusCode argument value. + StatusCode int + } + // MustEncode holds details about calls to the MustEncode method. + MustEncode []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // V is the v argument value. + V any + } + // MustEncodeJSON holds details about calls to the MustEncodeJSON method. + MustEncodeJSON []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // V is the v argument value. + V any + } + } + lockDecodeBytes sync.RWMutex + lockDecodeRequest sync.RWMutex + lockEncodeResponseWithStatus sync.RWMutex + lockMustEncode sync.RWMutex + lockMustEncodeJSON sync.RWMutex +} + +// DecodeBytes calls DecodeBytesFunc. +func (mock *ServerEncoderDecoderMock) DecodeBytes(ctx context.Context, payload []byte, dest any) error { + if mock.DecodeBytesFunc == nil { + panic("ServerEncoderDecoderMock.DecodeBytesFunc: method is nil but ServerEncoderDecoder.DecodeBytes was just called") + } + callInfo := struct { + Ctx context.Context + Payload []byte + Dest any + }{ + Ctx: ctx, + Payload: payload, + Dest: dest, + } + mock.lockDecodeBytes.Lock() + mock.calls.DecodeBytes = append(mock.calls.DecodeBytes, callInfo) + mock.lockDecodeBytes.Unlock() + return mock.DecodeBytesFunc(ctx, payload, dest) +} + +// DecodeBytesCalls gets all the calls that were made to DecodeBytes. +// Check the length with: +// +// len(mockedServerEncoderDecoder.DecodeBytesCalls()) +func (mock *ServerEncoderDecoderMock) DecodeBytesCalls() []struct { + Ctx context.Context + Payload []byte + Dest any +} { + var calls []struct { + Ctx context.Context + Payload []byte + Dest any + } + mock.lockDecodeBytes.RLock() + calls = mock.calls.DecodeBytes + mock.lockDecodeBytes.RUnlock() + return calls +} + +// DecodeRequest calls DecodeRequestFunc. +func (mock *ServerEncoderDecoderMock) DecodeRequest(ctx context.Context, req *http.Request, dest any) error { + if mock.DecodeRequestFunc == nil { + panic("ServerEncoderDecoderMock.DecodeRequestFunc: method is nil but ServerEncoderDecoder.DecodeRequest was just called") + } + callInfo := struct { + Ctx context.Context + Req *http.Request + Dest any + }{ + Ctx: ctx, + Req: req, + Dest: dest, + } + mock.lockDecodeRequest.Lock() + mock.calls.DecodeRequest = append(mock.calls.DecodeRequest, callInfo) + mock.lockDecodeRequest.Unlock() + return mock.DecodeRequestFunc(ctx, req, dest) +} + +// DecodeRequestCalls gets all the calls that were made to DecodeRequest. +// Check the length with: +// +// len(mockedServerEncoderDecoder.DecodeRequestCalls()) +func (mock *ServerEncoderDecoderMock) DecodeRequestCalls() []struct { + Ctx context.Context + Req *http.Request + Dest any +} { + var calls []struct { + Ctx context.Context + Req *http.Request + Dest any + } + mock.lockDecodeRequest.RLock() + calls = mock.calls.DecodeRequest + mock.lockDecodeRequest.RUnlock() + return calls +} + +// EncodeResponseWithStatus calls EncodeResponseWithStatusFunc. +func (mock *ServerEncoderDecoderMock) EncodeResponseWithStatus(ctx context.Context, res http.ResponseWriter, val any, statusCode int) { + if mock.EncodeResponseWithStatusFunc == nil { + panic("ServerEncoderDecoderMock.EncodeResponseWithStatusFunc: method is nil but ServerEncoderDecoder.EncodeResponseWithStatus was just called") + } + callInfo := struct { + Ctx context.Context + Res http.ResponseWriter + Val any + StatusCode int + }{ + Ctx: ctx, + Res: res, + Val: val, + StatusCode: statusCode, + } + mock.lockEncodeResponseWithStatus.Lock() + mock.calls.EncodeResponseWithStatus = append(mock.calls.EncodeResponseWithStatus, callInfo) + mock.lockEncodeResponseWithStatus.Unlock() + mock.EncodeResponseWithStatusFunc(ctx, res, val, statusCode) +} + +// EncodeResponseWithStatusCalls gets all the calls that were made to EncodeResponseWithStatus. +// Check the length with: +// +// len(mockedServerEncoderDecoder.EncodeResponseWithStatusCalls()) +func (mock *ServerEncoderDecoderMock) EncodeResponseWithStatusCalls() []struct { + Ctx context.Context + Res http.ResponseWriter + Val any + StatusCode int +} { + var calls []struct { + Ctx context.Context + Res http.ResponseWriter + Val any + StatusCode int + } + mock.lockEncodeResponseWithStatus.RLock() + calls = mock.calls.EncodeResponseWithStatus + mock.lockEncodeResponseWithStatus.RUnlock() + return calls +} + +// MustEncode calls MustEncodeFunc. +func (mock *ServerEncoderDecoderMock) MustEncode(ctx context.Context, v any) []byte { + if mock.MustEncodeFunc == nil { + panic("ServerEncoderDecoderMock.MustEncodeFunc: method is nil but ServerEncoderDecoder.MustEncode was just called") + } + callInfo := struct { + Ctx context.Context + V any + }{ + Ctx: ctx, + V: v, + } + mock.lockMustEncode.Lock() + mock.calls.MustEncode = append(mock.calls.MustEncode, callInfo) + mock.lockMustEncode.Unlock() + return mock.MustEncodeFunc(ctx, v) +} + +// MustEncodeCalls gets all the calls that were made to MustEncode. +// Check the length with: +// +// len(mockedServerEncoderDecoder.MustEncodeCalls()) +func (mock *ServerEncoderDecoderMock) MustEncodeCalls() []struct { + Ctx context.Context + V any +} { + var calls []struct { + Ctx context.Context + V any + } + mock.lockMustEncode.RLock() + calls = mock.calls.MustEncode + mock.lockMustEncode.RUnlock() + return calls +} + +// MustEncodeJSON calls MustEncodeJSONFunc. +func (mock *ServerEncoderDecoderMock) MustEncodeJSON(ctx context.Context, v any) []byte { + if mock.MustEncodeJSONFunc == nil { + panic("ServerEncoderDecoderMock.MustEncodeJSONFunc: method is nil but ServerEncoderDecoder.MustEncodeJSON was just called") + } + callInfo := struct { + Ctx context.Context + V any + }{ + Ctx: ctx, + V: v, + } + mock.lockMustEncodeJSON.Lock() + mock.calls.MustEncodeJSON = append(mock.calls.MustEncodeJSON, callInfo) + mock.lockMustEncodeJSON.Unlock() + return mock.MustEncodeJSONFunc(ctx, v) +} + +// MustEncodeJSONCalls gets all the calls that were made to MustEncodeJSON. +// Check the length with: +// +// len(mockedServerEncoderDecoder.MustEncodeJSONCalls()) +func (mock *ServerEncoderDecoderMock) MustEncodeJSONCalls() []struct { + Ctx context.Context + V any +} { + var calls []struct { + Ctx context.Context + V any + } + mock.lockMustEncodeJSON.RLock() + calls = mock.calls.MustEncodeJSON + mock.lockMustEncodeJSON.RUnlock() + return calls +} + +// Ensure, that ClientEncoderMock does implement encoding.ClientEncoder. +// If this is not the case, regenerate this file with moq. +var _ encoding.ClientEncoder = &ClientEncoderMock{} + +// ClientEncoderMock is a mock implementation of encoding.ClientEncoder. +// +// func TestSomethingThatUsesClientEncoder(t *testing.T) { +// +// // make and configure a mocked encoding.ClientEncoder +// mockedClientEncoder := &ClientEncoderMock{ +// ContentTypeFunc: func() string { +// panic("mock out the ContentType method") +// }, +// EncodeFunc: func(ctx context.Context, dest io.Writer, v any) error { +// panic("mock out the Encode method") +// }, +// EncodeReaderFunc: func(ctx context.Context, data any) (io.Reader, error) { +// panic("mock out the EncodeReader method") +// }, +// UnmarshalFunc: func(ctx context.Context, data []byte, v any) error { +// panic("mock out the Unmarshal method") +// }, +// } +// +// // use mockedClientEncoder in code that requires encoding.ClientEncoder +// // and then make assertions. +// +// } +type ClientEncoderMock struct { + // ContentTypeFunc mocks the ContentType method. + ContentTypeFunc func() string + + // EncodeFunc mocks the Encode method. + EncodeFunc func(ctx context.Context, dest io.Writer, v any) error + + // EncodeReaderFunc mocks the EncodeReader method. + EncodeReaderFunc func(ctx context.Context, data any) (io.Reader, error) + + // UnmarshalFunc mocks the Unmarshal method. + UnmarshalFunc func(ctx context.Context, data []byte, v any) error + + // calls tracks calls to the methods. + calls struct { + // ContentType holds details about calls to the ContentType method. + ContentType []struct { + } + // Encode holds details about calls to the Encode method. + Encode []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Dest is the dest argument value. + Dest io.Writer + // V is the v argument value. + V any + } + // EncodeReader holds details about calls to the EncodeReader method. + EncodeReader []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Data is the data argument value. + Data any + } + // Unmarshal holds details about calls to the Unmarshal method. + Unmarshal []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Data is the data argument value. + Data []byte + // V is the v argument value. + V any + } + } + lockContentType sync.RWMutex + lockEncode sync.RWMutex + lockEncodeReader sync.RWMutex + lockUnmarshal sync.RWMutex +} + +// ContentType calls ContentTypeFunc. +func (mock *ClientEncoderMock) ContentType() string { + if mock.ContentTypeFunc == nil { + panic("ClientEncoderMock.ContentTypeFunc: method is nil but ClientEncoder.ContentType was just called") + } + callInfo := struct { + }{} + mock.lockContentType.Lock() + mock.calls.ContentType = append(mock.calls.ContentType, callInfo) + mock.lockContentType.Unlock() + return mock.ContentTypeFunc() +} + +// ContentTypeCalls gets all the calls that were made to ContentType. +// Check the length with: +// +// len(mockedClientEncoder.ContentTypeCalls()) +func (mock *ClientEncoderMock) ContentTypeCalls() []struct { +} { + var calls []struct { + } + mock.lockContentType.RLock() + calls = mock.calls.ContentType + mock.lockContentType.RUnlock() + return calls +} + +// Encode calls EncodeFunc. +func (mock *ClientEncoderMock) Encode(ctx context.Context, dest io.Writer, v any) error { + if mock.EncodeFunc == nil { + panic("ClientEncoderMock.EncodeFunc: method is nil but ClientEncoder.Encode was just called") + } + callInfo := struct { + Ctx context.Context + Dest io.Writer + V any + }{ + Ctx: ctx, + Dest: dest, + V: v, + } + mock.lockEncode.Lock() + mock.calls.Encode = append(mock.calls.Encode, callInfo) + mock.lockEncode.Unlock() + return mock.EncodeFunc(ctx, dest, v) +} + +// EncodeCalls gets all the calls that were made to Encode. +// Check the length with: +// +// len(mockedClientEncoder.EncodeCalls()) +func (mock *ClientEncoderMock) EncodeCalls() []struct { + Ctx context.Context + Dest io.Writer + V any +} { + var calls []struct { + Ctx context.Context + Dest io.Writer + V any + } + mock.lockEncode.RLock() + calls = mock.calls.Encode + mock.lockEncode.RUnlock() + return calls +} + +// EncodeReader calls EncodeReaderFunc. +func (mock *ClientEncoderMock) EncodeReader(ctx context.Context, data any) (io.Reader, error) { + if mock.EncodeReaderFunc == nil { + panic("ClientEncoderMock.EncodeReaderFunc: method is nil but ClientEncoder.EncodeReader was just called") + } + callInfo := struct { + Ctx context.Context + Data any + }{ + Ctx: ctx, + Data: data, + } + mock.lockEncodeReader.Lock() + mock.calls.EncodeReader = append(mock.calls.EncodeReader, callInfo) + mock.lockEncodeReader.Unlock() + return mock.EncodeReaderFunc(ctx, data) +} + +// EncodeReaderCalls gets all the calls that were made to EncodeReader. +// Check the length with: +// +// len(mockedClientEncoder.EncodeReaderCalls()) +func (mock *ClientEncoderMock) EncodeReaderCalls() []struct { + Ctx context.Context + Data any +} { + var calls []struct { + Ctx context.Context + Data any + } + mock.lockEncodeReader.RLock() + calls = mock.calls.EncodeReader + mock.lockEncodeReader.RUnlock() + return calls +} + +// Unmarshal calls UnmarshalFunc. +func (mock *ClientEncoderMock) Unmarshal(ctx context.Context, data []byte, v any) error { + if mock.UnmarshalFunc == nil { + panic("ClientEncoderMock.UnmarshalFunc: method is nil but ClientEncoder.Unmarshal was just called") + } + callInfo := struct { + Ctx context.Context + Data []byte + V any + }{ + Ctx: ctx, + Data: data, + V: v, + } + mock.lockUnmarshal.Lock() + mock.calls.Unmarshal = append(mock.calls.Unmarshal, callInfo) + mock.lockUnmarshal.Unlock() + return mock.UnmarshalFunc(ctx, data, v) +} + +// UnmarshalCalls gets all the calls that were made to Unmarshal. +// Check the length with: +// +// len(mockedClientEncoder.UnmarshalCalls()) +func (mock *ClientEncoderMock) UnmarshalCalls() []struct { + Ctx context.Context + Data []byte + V any +} { + var calls []struct { + Ctx context.Context + Data []byte + V any + } + mock.lockUnmarshal.RLock() + calls = mock.calls.Unmarshal + mock.lockUnmarshal.RUnlock() + return calls +} diff --git a/encoding/mock/mock_client_encoder.go b/encoding/mock/mock_client_encoder.go deleted file mode 100644 index 01c22aa..0000000 --- a/encoding/mock/mock_client_encoder.go +++ /dev/null @@ -1,35 +0,0 @@ -package mockencoding - -import ( - "context" - "io" - - "github.com/stretchr/testify/mock" -) - -// ClientEncoder is a mock ClientEncoder. -type ClientEncoder struct { - mock.Mock -} - -// ContentType satisfies the ClientEncoder interface. -func (m *ClientEncoder) ContentType() string { - return m.Called().String(0) -} - -// Unmarshal satisfies the ClientEncoder interface. -func (m *ClientEncoder) Unmarshal(ctx context.Context, data []byte, v any) error { - return m.Called(ctx, data, v).Error(0) -} - -// Encode satisfies the ClientEncoder interface. -func (m *ClientEncoder) Encode(ctx context.Context, dest io.Writer, v any) error { - return m.Called(ctx, dest, v).Error(0) -} - -// EncodeReader satisfies the ClientEncoder interface. -func (m *ClientEncoder) EncodeReader(ctx context.Context, data any) (io.Reader, error) { - returnValues := m.Called(ctx, data) - - return returnValues.Get(0).(io.Reader), returnValues.Error(1) -} diff --git a/encoding/mock/mock_encoding.go b/encoding/mock/mock_encoding.go deleted file mode 100644 index f36b457..0000000 --- a/encoding/mock/mock_encoding.go +++ /dev/null @@ -1,83 +0,0 @@ -package mockencoding - -import ( - "context" - "net/http" - - "github.com/verygoodsoftwarenotvirus/platform/v5/encoding" - - "github.com/stretchr/testify/mock" -) - -var _ encoding.ServerEncoderDecoder = (*EncoderDecoder)(nil) - -// NewMockEncoderDecoder produces a mock EncoderDecoder. -func NewMockEncoderDecoder() *EncoderDecoder { - return &EncoderDecoder{} -} - -// EncoderDecoder is a mock EncoderDecoder. -type EncoderDecoder struct { - mock.Mock -} - -// MustEncode satisfies our EncoderDecoder interface. -func (m *EncoderDecoder) MustEncode(ctx context.Context, v any) []byte { - return m.Called(ctx, v).Get(0).([]byte) -} - -// MustEncodeJSON satisfies our EncoderDecoder interface. -func (m *EncoderDecoder) MustEncodeJSON(ctx context.Context, v any) []byte { - return m.Called(ctx, v).Get(0).([]byte) -} - -// RespondWithData satisfies our EncoderDecoder interface. -func (m *EncoderDecoder) RespondWithData(ctx context.Context, res http.ResponseWriter, val any) { - m.Called(ctx, res, val) -} - -// EncodeResponseWithStatus satisfies our EncoderDecoder interface. -func (m *EncoderDecoder) EncodeResponseWithStatus(ctx context.Context, res http.ResponseWriter, val any, statusCode int) { - m.Called(ctx, res, val, statusCode) - res.WriteHeader(statusCode) -} - -// EncodeErrorResponse satisfies our EncoderDecoder interface. -func (m *EncoderDecoder) EncodeErrorResponse(ctx context.Context, res http.ResponseWriter, msg string, statusCode int) { - m.Called(ctx, res, msg, statusCode) - res.WriteHeader(statusCode) -} - -// EncodeInvalidInputResponse satisfies our EncoderDecoder interface. -func (m *EncoderDecoder) EncodeInvalidInputResponse(ctx context.Context, res http.ResponseWriter) { - m.Called(ctx, res) - res.WriteHeader(http.StatusBadRequest) -} - -// EncodeNotFoundResponse satisfies our EncoderDecoder interface. -func (m *EncoderDecoder) EncodeNotFoundResponse(ctx context.Context, res http.ResponseWriter) { - m.Called(ctx, res) - res.WriteHeader(http.StatusNotFound) -} - -// EncodeUnspecifiedInternalServerErrorResponse satisfies our EncoderDecoder interface. -func (m *EncoderDecoder) EncodeUnspecifiedInternalServerErrorResponse(ctx context.Context, res http.ResponseWriter) { - m.Called(ctx, res) - res.WriteHeader(http.StatusInternalServerError) -} - -// EncodeUnauthorizedResponse satisfies our EncoderDecoder interface. -func (m *EncoderDecoder) EncodeUnauthorizedResponse(ctx context.Context, res http.ResponseWriter) { - m.Called(ctx, res) - res.WriteHeader(http.StatusUnauthorized) -} - -// DecodeRequest satisfies our EncoderDecoder interface. -func (m *EncoderDecoder) DecodeRequest(ctx context.Context, req *http.Request, v any) error { - return m.Called(ctx, req, v).Error(0) -} - -// DecodeBytes satisfies our EncoderDecoder interface. -func (m *EncoderDecoder) DecodeBytes(ctx context.Context, data []byte, v any) error { - return m.Called(ctx, data, v).Error(0) -} diff --git a/encoding/mock_io_writer_test.go b/encoding/mock_io_writer_test.go index 4f180b7..8fc6d39 100644 --- a/encoding/mock_io_writer_test.go +++ b/encoding/mock_io_writer_test.go @@ -2,19 +2,16 @@ package encoding import ( "io" - - "github.com/stretchr/testify/mock" ) var _ io.Writer = (*mockWriter)(nil) -// mockWriter mocks a io.Writer. +// mockWriter mocks an io.Writer. type mockWriter struct { - mock.Mock + WriteFunc func(p []byte) (int, error) } // Write implements the io.Writer interface. func (m *mockWriter) Write(p []byte) (int, error) { - returnVals := m.Called(p) - return returnVals.Int(0), returnVals.Error(1) + return m.WriteFunc(p) } diff --git a/featureflags/config/config_test.go b/featureflags/config/config_test.go index 4582559..f756dde 100644 --- a/featureflags/config/config_test.go +++ b/featureflags/config/config_test.go @@ -14,8 +14,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/verygoodsoftwarenotvirus/platform/v5/reflection" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" @@ -196,9 +196,12 @@ func TestProvideFeatureFlagManager(T *testing.T) { cbCfg := circuitbreakingcfg.Config{} cbCfg.EnsureDefaults() - i64Counter := &mockmetrics.Int64Counter{} - mp := &mockmetrics.MetricsProvider{} - mp.On(reflection.GetMethodName(mp.NewInt64Counter), fmt.Sprintf("%s_circuit_breaker_tripped", cbCfg.Name), []metric.Int64CounterOption(nil)).Return(i64Counter, errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, fmt.Sprintf("%s_circuit_breaker_tripped", cbCfg.Name), counterName) + return &mockmetrics.Int64CounterMock{}, errors.New("arbitrary") + }, + } cfg := &Config{ Provider: "", @@ -208,5 +211,7 @@ func TestProvideFeatureFlagManager(T *testing.T) { ffm, err := ProvideFeatureFlagManager(ctx, cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, http.DefaultClient) require.Error(t, err) require.Nil(t, ffm) + + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) } diff --git a/featureflags/launchdarkly/feature_flag_manager_test.go b/featureflags/launchdarkly/feature_flag_manager_test.go index dee8e04..487c0ea 100644 --- a/featureflags/launchdarkly/feature_flag_manager_test.go +++ b/featureflags/launchdarkly/feature_flag_manager_test.go @@ -268,10 +268,11 @@ func TestFeatureFlagManager_CanUseFeature(T *testing.T) { t.Parallel() ctx := t.Context() - cb := &mockCircuitBreaker.MockCircuitBreaker{} - cb.On("CanProceed").Return(true) - cb.On("Succeeded").Return() - cb.On("Failed").Return() + cb := &mockCircuitBreaker.CircuitBreakerMock{ + CanProceedFunc: func() bool { return true }, + SucceededFunc: func() {}, + FailedFunc: func() {}, + } ffm := buildTestManager(t, cb) @@ -284,8 +285,9 @@ func TestFeatureFlagManager_CanUseFeature(T *testing.T) { t.Parallel() ctx := t.Context() - cb := &mockCircuitBreaker.MockCircuitBreaker{} - cb.On("CanProceed").Return(false) + cb := &mockCircuitBreaker.CircuitBreakerMock{ + CanProceedFunc: func() bool { return false }, + } ffm := buildTestManager(t, cb) @@ -313,10 +315,11 @@ func TestFeatureFlagManager_GetStringValue(T *testing.T) { t.Parallel() ctx := t.Context() - cb := &mockCircuitBreaker.MockCircuitBreaker{} - cb.On("CanProceed").Return(true) - cb.On("Succeeded").Return() - cb.On("Failed").Return() + cb := &mockCircuitBreaker.CircuitBreakerMock{ + CanProceedFunc: func() bool { return true }, + SucceededFunc: func() {}, + FailedFunc: func() {}, + } ffm := buildTestManager(t, cb) @@ -329,8 +332,9 @@ func TestFeatureFlagManager_GetStringValue(T *testing.T) { t.Parallel() ctx := t.Context() - cb := &mockCircuitBreaker.MockCircuitBreaker{} - cb.On("CanProceed").Return(false) + cb := &mockCircuitBreaker.CircuitBreakerMock{ + CanProceedFunc: func() bool { return false }, + } ffm := buildTestManager(t, cb) @@ -358,10 +362,11 @@ func TestFeatureFlagManager_GetInt64Value(T *testing.T) { t.Parallel() ctx := t.Context() - cb := &mockCircuitBreaker.MockCircuitBreaker{} - cb.On("CanProceed").Return(true) - cb.On("Succeeded").Return() - cb.On("Failed").Return() + cb := &mockCircuitBreaker.CircuitBreakerMock{ + CanProceedFunc: func() bool { return true }, + SucceededFunc: func() {}, + FailedFunc: func() {}, + } ffm := buildTestManager(t, cb) @@ -374,8 +379,9 @@ func TestFeatureFlagManager_GetInt64Value(T *testing.T) { t.Parallel() ctx := t.Context() - cb := &mockCircuitBreaker.MockCircuitBreaker{} - cb.On("CanProceed").Return(false) + cb := &mockCircuitBreaker.CircuitBreakerMock{ + CanProceedFunc: func() bool { return false }, + } ffm := buildTestManager(t, cb) @@ -403,10 +409,11 @@ func TestFeatureFlagManager_GetFloat64Value(T *testing.T) { t.Parallel() ctx := t.Context() - cb := &mockCircuitBreaker.MockCircuitBreaker{} - cb.On("CanProceed").Return(true) - cb.On("Succeeded").Return() - cb.On("Failed").Return() + cb := &mockCircuitBreaker.CircuitBreakerMock{ + CanProceedFunc: func() bool { return true }, + SucceededFunc: func() {}, + FailedFunc: func() {}, + } ffm := buildTestManager(t, cb) @@ -419,8 +426,9 @@ func TestFeatureFlagManager_GetFloat64Value(T *testing.T) { t.Parallel() ctx := t.Context() - cb := &mockCircuitBreaker.MockCircuitBreaker{} - cb.On("CanProceed").Return(false) + cb := &mockCircuitBreaker.CircuitBreakerMock{ + CanProceedFunc: func() bool { return false }, + } ffm := buildTestManager(t, cb) @@ -449,10 +457,11 @@ func TestFeatureFlagManager_GetObjectValue(T *testing.T) { t.Parallel() ctx := t.Context() - cb := &mockCircuitBreaker.MockCircuitBreaker{} - cb.On("CanProceed").Return(true) - cb.On("Succeeded").Return() - cb.On("Failed").Return() + cb := &mockCircuitBreaker.CircuitBreakerMock{ + CanProceedFunc: func() bool { return true }, + SucceededFunc: func() {}, + FailedFunc: func() {}, + } ffm := buildTestManager(t, cb) @@ -466,8 +475,9 @@ func TestFeatureFlagManager_GetObjectValue(T *testing.T) { t.Parallel() ctx := t.Context() - cb := &mockCircuitBreaker.MockCircuitBreaker{} - cb.On("CanProceed").Return(false) + cb := &mockCircuitBreaker.CircuitBreakerMock{ + CanProceedFunc: func() bool { return false }, + } ffm := buildTestManager(t, cb) diff --git a/featureflags/mock/doc.go b/featureflags/mock/doc.go new file mode 100644 index 0000000..6f05bbf --- /dev/null +++ b/featureflags/mock/doc.go @@ -0,0 +1,9 @@ +// Package mock provides mock implementations of the featureflags package's +// interfaces. Both the hand-written testify-based FeatureFlagManager and the +// moq-generated FeatureFlagManagerMock live here during the testify → moq +// migration. New test code should prefer FeatureFlagManagerMock. +package mock + +// Regenerate the moq mocks via `go generate ./featureflags/mock/`. + +//go:generate go tool github.com/matryer/moq -out feature_flag_manager_mock.go -pkg mock -rm -fmt goimports .. FeatureFlagManager:FeatureFlagManagerMock diff --git a/featureflags/mock/feature_flag_manager_mock.go b/featureflags/mock/feature_flag_manager_mock.go new file mode 100644 index 0000000..fe0706d --- /dev/null +++ b/featureflags/mock/feature_flag_manager_mock.go @@ -0,0 +1,374 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package mock + +import ( + "context" + "sync" + + "github.com/verygoodsoftwarenotvirus/platform/v5/featureflags" +) + +// Ensure, that FeatureFlagManagerMock does implement featureflags.FeatureFlagManager. +// If this is not the case, regenerate this file with moq. +var _ featureflags.FeatureFlagManager = &FeatureFlagManagerMock{} + +// FeatureFlagManagerMock is a mock implementation of featureflags.FeatureFlagManager. +// +// func TestSomethingThatUsesFeatureFlagManager(t *testing.T) { +// +// // make and configure a mocked featureflags.FeatureFlagManager +// mockedFeatureFlagManager := &FeatureFlagManagerMock{ +// CanUseFeatureFunc: func(ctx context.Context, feature string, evalCtx featureflags.EvaluationContext) (bool, error) { +// panic("mock out the CanUseFeature method") +// }, +// CloseFunc: func() error { +// panic("mock out the Close method") +// }, +// GetFloat64ValueFunc: func(ctx context.Context, feature string, defaultValue float64, evalCtx featureflags.EvaluationContext) (float64, error) { +// panic("mock out the GetFloat64Value method") +// }, +// GetInt64ValueFunc: func(ctx context.Context, feature string, defaultValue int64, evalCtx featureflags.EvaluationContext) (int64, error) { +// panic("mock out the GetInt64Value method") +// }, +// GetObjectValueFunc: func(ctx context.Context, feature string, defaultValue any, evalCtx featureflags.EvaluationContext) (any, error) { +// panic("mock out the GetObjectValue method") +// }, +// GetStringValueFunc: func(ctx context.Context, feature string, defaultValue string, evalCtx featureflags.EvaluationContext) (string, error) { +// panic("mock out the GetStringValue method") +// }, +// } +// +// // use mockedFeatureFlagManager in code that requires featureflags.FeatureFlagManager +// // and then make assertions. +// +// } +type FeatureFlagManagerMock struct { + // CanUseFeatureFunc mocks the CanUseFeature method. + CanUseFeatureFunc func(ctx context.Context, feature string, evalCtx featureflags.EvaluationContext) (bool, error) + + // CloseFunc mocks the Close method. + CloseFunc func() error + + // GetFloat64ValueFunc mocks the GetFloat64Value method. + GetFloat64ValueFunc func(ctx context.Context, feature string, defaultValue float64, evalCtx featureflags.EvaluationContext) (float64, error) + + // GetInt64ValueFunc mocks the GetInt64Value method. + GetInt64ValueFunc func(ctx context.Context, feature string, defaultValue int64, evalCtx featureflags.EvaluationContext) (int64, error) + + // GetObjectValueFunc mocks the GetObjectValue method. + GetObjectValueFunc func(ctx context.Context, feature string, defaultValue any, evalCtx featureflags.EvaluationContext) (any, error) + + // GetStringValueFunc mocks the GetStringValue method. + GetStringValueFunc func(ctx context.Context, feature string, defaultValue string, evalCtx featureflags.EvaluationContext) (string, error) + + // calls tracks calls to the methods. + calls struct { + // CanUseFeature holds details about calls to the CanUseFeature method. + CanUseFeature []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Feature is the feature argument value. + Feature string + // EvalCtx is the evalCtx argument value. + EvalCtx featureflags.EvaluationContext + } + // Close holds details about calls to the Close method. + Close []struct { + } + // GetFloat64Value holds details about calls to the GetFloat64Value method. + GetFloat64Value []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Feature is the feature argument value. + Feature string + // DefaultValue is the defaultValue argument value. + DefaultValue float64 + // EvalCtx is the evalCtx argument value. + EvalCtx featureflags.EvaluationContext + } + // GetInt64Value holds details about calls to the GetInt64Value method. + GetInt64Value []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Feature is the feature argument value. + Feature string + // DefaultValue is the defaultValue argument value. + DefaultValue int64 + // EvalCtx is the evalCtx argument value. + EvalCtx featureflags.EvaluationContext + } + // GetObjectValue holds details about calls to the GetObjectValue method. + GetObjectValue []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Feature is the feature argument value. + Feature string + // DefaultValue is the defaultValue argument value. + DefaultValue any + // EvalCtx is the evalCtx argument value. + EvalCtx featureflags.EvaluationContext + } + // GetStringValue holds details about calls to the GetStringValue method. + GetStringValue []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Feature is the feature argument value. + Feature string + // DefaultValue is the defaultValue argument value. + DefaultValue string + // EvalCtx is the evalCtx argument value. + EvalCtx featureflags.EvaluationContext + } + } + lockCanUseFeature sync.RWMutex + lockClose sync.RWMutex + lockGetFloat64Value sync.RWMutex + lockGetInt64Value sync.RWMutex + lockGetObjectValue sync.RWMutex + lockGetStringValue sync.RWMutex +} + +// CanUseFeature calls CanUseFeatureFunc. +func (mock *FeatureFlagManagerMock) CanUseFeature(ctx context.Context, feature string, evalCtx featureflags.EvaluationContext) (bool, error) { + if mock.CanUseFeatureFunc == nil { + panic("FeatureFlagManagerMock.CanUseFeatureFunc: method is nil but FeatureFlagManager.CanUseFeature was just called") + } + callInfo := struct { + Ctx context.Context + Feature string + EvalCtx featureflags.EvaluationContext + }{ + Ctx: ctx, + Feature: feature, + EvalCtx: evalCtx, + } + mock.lockCanUseFeature.Lock() + mock.calls.CanUseFeature = append(mock.calls.CanUseFeature, callInfo) + mock.lockCanUseFeature.Unlock() + return mock.CanUseFeatureFunc(ctx, feature, evalCtx) +} + +// CanUseFeatureCalls gets all the calls that were made to CanUseFeature. +// Check the length with: +// +// len(mockedFeatureFlagManager.CanUseFeatureCalls()) +func (mock *FeatureFlagManagerMock) CanUseFeatureCalls() []struct { + Ctx context.Context + Feature string + EvalCtx featureflags.EvaluationContext +} { + var calls []struct { + Ctx context.Context + Feature string + EvalCtx featureflags.EvaluationContext + } + mock.lockCanUseFeature.RLock() + calls = mock.calls.CanUseFeature + mock.lockCanUseFeature.RUnlock() + return calls +} + +// Close calls CloseFunc. +func (mock *FeatureFlagManagerMock) Close() error { + if mock.CloseFunc == nil { + panic("FeatureFlagManagerMock.CloseFunc: method is nil but FeatureFlagManager.Close was just called") + } + callInfo := struct { + }{} + mock.lockClose.Lock() + mock.calls.Close = append(mock.calls.Close, callInfo) + mock.lockClose.Unlock() + return mock.CloseFunc() +} + +// CloseCalls gets all the calls that were made to Close. +// Check the length with: +// +// len(mockedFeatureFlagManager.CloseCalls()) +func (mock *FeatureFlagManagerMock) CloseCalls() []struct { +} { + var calls []struct { + } + mock.lockClose.RLock() + calls = mock.calls.Close + mock.lockClose.RUnlock() + return calls +} + +// GetFloat64Value calls GetFloat64ValueFunc. +func (mock *FeatureFlagManagerMock) GetFloat64Value(ctx context.Context, feature string, defaultValue float64, evalCtx featureflags.EvaluationContext) (float64, error) { + if mock.GetFloat64ValueFunc == nil { + panic("FeatureFlagManagerMock.GetFloat64ValueFunc: method is nil but FeatureFlagManager.GetFloat64Value was just called") + } + callInfo := struct { + Ctx context.Context + Feature string + DefaultValue float64 + EvalCtx featureflags.EvaluationContext + }{ + Ctx: ctx, + Feature: feature, + DefaultValue: defaultValue, + EvalCtx: evalCtx, + } + mock.lockGetFloat64Value.Lock() + mock.calls.GetFloat64Value = append(mock.calls.GetFloat64Value, callInfo) + mock.lockGetFloat64Value.Unlock() + return mock.GetFloat64ValueFunc(ctx, feature, defaultValue, evalCtx) +} + +// GetFloat64ValueCalls gets all the calls that were made to GetFloat64Value. +// Check the length with: +// +// len(mockedFeatureFlagManager.GetFloat64ValueCalls()) +func (mock *FeatureFlagManagerMock) GetFloat64ValueCalls() []struct { + Ctx context.Context + Feature string + DefaultValue float64 + EvalCtx featureflags.EvaluationContext +} { + var calls []struct { + Ctx context.Context + Feature string + DefaultValue float64 + EvalCtx featureflags.EvaluationContext + } + mock.lockGetFloat64Value.RLock() + calls = mock.calls.GetFloat64Value + mock.lockGetFloat64Value.RUnlock() + return calls +} + +// GetInt64Value calls GetInt64ValueFunc. +func (mock *FeatureFlagManagerMock) GetInt64Value(ctx context.Context, feature string, defaultValue int64, evalCtx featureflags.EvaluationContext) (int64, error) { + if mock.GetInt64ValueFunc == nil { + panic("FeatureFlagManagerMock.GetInt64ValueFunc: method is nil but FeatureFlagManager.GetInt64Value was just called") + } + callInfo := struct { + Ctx context.Context + Feature string + DefaultValue int64 + EvalCtx featureflags.EvaluationContext + }{ + Ctx: ctx, + Feature: feature, + DefaultValue: defaultValue, + EvalCtx: evalCtx, + } + mock.lockGetInt64Value.Lock() + mock.calls.GetInt64Value = append(mock.calls.GetInt64Value, callInfo) + mock.lockGetInt64Value.Unlock() + return mock.GetInt64ValueFunc(ctx, feature, defaultValue, evalCtx) +} + +// GetInt64ValueCalls gets all the calls that were made to GetInt64Value. +// Check the length with: +// +// len(mockedFeatureFlagManager.GetInt64ValueCalls()) +func (mock *FeatureFlagManagerMock) GetInt64ValueCalls() []struct { + Ctx context.Context + Feature string + DefaultValue int64 + EvalCtx featureflags.EvaluationContext +} { + var calls []struct { + Ctx context.Context + Feature string + DefaultValue int64 + EvalCtx featureflags.EvaluationContext + } + mock.lockGetInt64Value.RLock() + calls = mock.calls.GetInt64Value + mock.lockGetInt64Value.RUnlock() + return calls +} + +// GetObjectValue calls GetObjectValueFunc. +func (mock *FeatureFlagManagerMock) GetObjectValue(ctx context.Context, feature string, defaultValue any, evalCtx featureflags.EvaluationContext) (any, error) { + if mock.GetObjectValueFunc == nil { + panic("FeatureFlagManagerMock.GetObjectValueFunc: method is nil but FeatureFlagManager.GetObjectValue was just called") + } + callInfo := struct { + Ctx context.Context + Feature string + DefaultValue any + EvalCtx featureflags.EvaluationContext + }{ + Ctx: ctx, + Feature: feature, + DefaultValue: defaultValue, + EvalCtx: evalCtx, + } + mock.lockGetObjectValue.Lock() + mock.calls.GetObjectValue = append(mock.calls.GetObjectValue, callInfo) + mock.lockGetObjectValue.Unlock() + return mock.GetObjectValueFunc(ctx, feature, defaultValue, evalCtx) +} + +// GetObjectValueCalls gets all the calls that were made to GetObjectValue. +// Check the length with: +// +// len(mockedFeatureFlagManager.GetObjectValueCalls()) +func (mock *FeatureFlagManagerMock) GetObjectValueCalls() []struct { + Ctx context.Context + Feature string + DefaultValue any + EvalCtx featureflags.EvaluationContext +} { + var calls []struct { + Ctx context.Context + Feature string + DefaultValue any + EvalCtx featureflags.EvaluationContext + } + mock.lockGetObjectValue.RLock() + calls = mock.calls.GetObjectValue + mock.lockGetObjectValue.RUnlock() + return calls +} + +// GetStringValue calls GetStringValueFunc. +func (mock *FeatureFlagManagerMock) GetStringValue(ctx context.Context, feature string, defaultValue string, evalCtx featureflags.EvaluationContext) (string, error) { + if mock.GetStringValueFunc == nil { + panic("FeatureFlagManagerMock.GetStringValueFunc: method is nil but FeatureFlagManager.GetStringValue was just called") + } + callInfo := struct { + Ctx context.Context + Feature string + DefaultValue string + EvalCtx featureflags.EvaluationContext + }{ + Ctx: ctx, + Feature: feature, + DefaultValue: defaultValue, + EvalCtx: evalCtx, + } + mock.lockGetStringValue.Lock() + mock.calls.GetStringValue = append(mock.calls.GetStringValue, callInfo) + mock.lockGetStringValue.Unlock() + return mock.GetStringValueFunc(ctx, feature, defaultValue, evalCtx) +} + +// GetStringValueCalls gets all the calls that were made to GetStringValue. +// Check the length with: +// +// len(mockedFeatureFlagManager.GetStringValueCalls()) +func (mock *FeatureFlagManagerMock) GetStringValueCalls() []struct { + Ctx context.Context + Feature string + DefaultValue string + EvalCtx featureflags.EvaluationContext +} { + var calls []struct { + Ctx context.Context + Feature string + DefaultValue string + EvalCtx featureflags.EvaluationContext + } + mock.lockGetStringValue.RLock() + calls = mock.calls.GetStringValue + mock.lockGetStringValue.RUnlock() + return calls +} diff --git a/featureflags/mock/mock_feature_flag_manager.go b/featureflags/mock/mock_feature_flag_manager.go deleted file mode 100644 index b787fdf..0000000 --- a/featureflags/mock/mock_feature_flag_manager.go +++ /dev/null @@ -1,50 +0,0 @@ -package mock - -import ( - "context" - - "github.com/verygoodsoftwarenotvirus/platform/v5/featureflags" - - "github.com/stretchr/testify/mock" -) - -var _ featureflags.FeatureFlagManager = (*FeatureFlagManager)(nil) - -type FeatureFlagManager struct { - mock.Mock -} - -// CanUseFeature satisfies the FeatureFlagManager interface. -func (m *FeatureFlagManager) CanUseFeature(ctx context.Context, feature string, evalCtx featureflags.EvaluationContext) (bool, error) { - returnValues := m.Called(ctx, feature, evalCtx) - return returnValues.Bool(0), returnValues.Error(1) -} - -// GetStringValue satisfies the FeatureFlagManager interface. -func (m *FeatureFlagManager) GetStringValue(ctx context.Context, feature, defaultValue string, evalCtx featureflags.EvaluationContext) (string, error) { - returnValues := m.Called(ctx, feature, defaultValue, evalCtx) - return returnValues.String(0), returnValues.Error(1) -} - -// GetInt64Value satisfies the FeatureFlagManager interface. -func (m *FeatureFlagManager) GetInt64Value(ctx context.Context, feature string, defaultValue int64, evalCtx featureflags.EvaluationContext) (int64, error) { - returnValues := m.Called(ctx, feature, defaultValue, evalCtx) - return returnValues.Get(0).(int64), returnValues.Error(1) -} - -// GetFloat64Value satisfies the FeatureFlagManager interface. -func (m *FeatureFlagManager) GetFloat64Value(ctx context.Context, feature string, defaultValue float64, evalCtx featureflags.EvaluationContext) (float64, error) { - returnValues := m.Called(ctx, feature, defaultValue, evalCtx) - return returnValues.Get(0).(float64), returnValues.Error(1) -} - -// GetObjectValue satisfies the FeatureFlagManager interface. -func (m *FeatureFlagManager) GetObjectValue(ctx context.Context, feature string, defaultValue any, evalCtx featureflags.EvaluationContext) (any, error) { - returnValues := m.Called(ctx, feature, defaultValue, evalCtx) - return returnValues.Get(0), returnValues.Error(1) -} - -// Close satisfies the FeatureFlagManager interface. -func (m *FeatureFlagManager) Close() error { - return m.Called().Error(0) -} diff --git a/featureflags/posthog/feature_flag_manager_test.go b/featureflags/posthog/feature_flag_manager_test.go index 38c3072..b8884ca 100644 --- a/featureflags/posthog/feature_flag_manager_test.go +++ b/featureflags/posthog/feature_flag_manager_test.go @@ -256,8 +256,9 @@ func TestFeatureFlagManager_CanUseFeature(T *testing.T) { t.Parallel() ctx := t.Context() - cb := &mockCircuitBreaker.MockCircuitBreaker{} - cb.On("CanProceed").Return(false) + cb := &mockCircuitBreaker.CircuitBreakerMock{ + CanProceedFunc: func() bool { return false }, + } ffm := buildTestManager(t, cb) @@ -296,8 +297,9 @@ func TestFeatureFlagManager_GetStringValue(T *testing.T) { t.Parallel() ctx := t.Context() - cb := &mockCircuitBreaker.MockCircuitBreaker{} - cb.On("CanProceed").Return(false) + cb := &mockCircuitBreaker.CircuitBreakerMock{ + CanProceedFunc: func() bool { return false }, + } ffm := buildTestManager(t, cb) @@ -336,8 +338,9 @@ func TestFeatureFlagManager_GetInt64Value(T *testing.T) { t.Parallel() ctx := t.Context() - cb := &mockCircuitBreaker.MockCircuitBreaker{} - cb.On("CanProceed").Return(false) + cb := &mockCircuitBreaker.CircuitBreakerMock{ + CanProceedFunc: func() bool { return false }, + } ffm := buildTestManager(t, cb) @@ -376,8 +379,9 @@ func TestFeatureFlagManager_GetFloat64Value(T *testing.T) { t.Parallel() ctx := t.Context() - cb := &mockCircuitBreaker.MockCircuitBreaker{} - cb.On("CanProceed").Return(false) + cb := &mockCircuitBreaker.CircuitBreakerMock{ + CanProceedFunc: func() bool { return false }, + } ffm := buildTestManager(t, cb) @@ -418,8 +422,9 @@ func TestFeatureFlagManager_GetObjectValue(T *testing.T) { t.Parallel() ctx := t.Context() - cb := &mockCircuitBreaker.MockCircuitBreaker{} - cb.On("CanProceed").Return(false) + cb := &mockCircuitBreaker.CircuitBreakerMock{ + CanProceedFunc: func() bool { return false }, + } ffm := buildTestManager(t, cb) diff --git a/llm/anthropic/anthropic_test.go b/llm/anthropic/anthropic_test.go index 90e18fe..3b8b363 100644 --- a/llm/anthropic/anthropic_test.go +++ b/llm/anthropic/anthropic_test.go @@ -12,7 +12,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" - "github.com/stretchr/testify/mock" + "github.com/shoenig/test" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" ) @@ -78,28 +78,41 @@ func TestNewProvider(T *testing.T) { T.Run("with error creating request counter", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_requests", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, name+"_requests", counterName) + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + }, + } provider, err := NewProvider(&Config{APIKey: "test-key"}, nil, nil, mp) require.Error(t, err) require.Nil(t, provider) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("with error creating error counter", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_requests", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_errors", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + switch counterName { + case name + "_requests": + return metrics.Int64CounterForTest(t, "x"), nil + case name + "_errors": + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + } + t.Fatalf("unexpected NewInt64Counter call: %q", counterName) + return nil, nil + }, + } provider, err := NewProvider(&Config{APIKey: "test-key"}, nil, nil, mp) require.Error(t, err) require.Nil(t, provider) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) T.Run("with error creating latency histogram", func(t *testing.T) { @@ -109,16 +122,22 @@ func TestNewProvider(T *testing.T) { h, histErr := noopMP.NewFloat64Histogram("test") require.NoError(t, histErr) - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_requests", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_errors", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewFloat64Histogram", name+"_latency_ms", []metric.Float64HistogramOption(nil)).Return(h, errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metrics.Int64CounterForTest(t, "x"), nil + }, + NewFloat64HistogramFunc: func(histName string, _ ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { + test.EqOp(t, name+"_latency_ms", histName) + return h, errors.New("arbitrary") + }, + } provider, err := NewProvider(&Config{APIKey: "test-key"}, nil, nil, mp) require.Error(t, err) require.Nil(t, provider) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) + test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) }) } diff --git a/llm/config/config_test.go b/llm/config/config_test.go index 4cd2f8a..82e886d 100644 --- a/llm/config/config_test.go +++ b/llm/config/config_test.go @@ -11,8 +11,8 @@ import ( mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" ) @@ -159,14 +159,17 @@ func TestConfig_ProvideLLMProvider(T *testing.T) { }, } - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", mock.AnythingOfType("string"), []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + }, + } provider, err := cfg.ProvideLLMProvider(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp) assert.Nil(t, provider) assert.Error(t, err) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("anthropic provider with metrics error", func(t *testing.T) { @@ -180,14 +183,17 @@ func TestConfig_ProvideLLMProvider(T *testing.T) { }, } - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", mock.AnythingOfType("string"), []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + }, + } provider, err := cfg.ProvideLLMProvider(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp) assert.Nil(t, provider) assert.Error(t, err) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) } diff --git a/llm/mock/doc.go b/llm/mock/doc.go new file mode 100644 index 0000000..a012f18 --- /dev/null +++ b/llm/mock/doc.go @@ -0,0 +1,9 @@ +// Package mock provides mock implementations of the llm package's interfaces. +// Both the hand-written testify-based Provider and the moq-generated +// ProviderMock live here during the testify → moq migration. New test code +// should prefer ProviderMock. +package mock + +// Regenerate the moq mocks via `go generate ./llm/mock/`. + +//go:generate go tool github.com/matryer/moq -out provider_mock.go -pkg mock -rm -fmt goimports .. Provider:ProviderMock diff --git a/llm/mock/mock.go b/llm/mock/mock.go deleted file mode 100644 index fdb8664..0000000 --- a/llm/mock/mock.go +++ /dev/null @@ -1,25 +0,0 @@ -package mock - -import ( - "context" - - "github.com/verygoodsoftwarenotvirus/platform/v5/llm" - - "github.com/stretchr/testify/mock" -) - -var _ llm.Provider = (*Provider)(nil) - -// Provider is a mock LLM provider for tests. -type Provider struct { - mock.Mock -} - -// Completion satisfies the llm.Provider interface. -func (m *Provider) Completion(ctx context.Context, params llm.CompletionParams) (*llm.CompletionResult, error) { - args := m.Called(ctx, params) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).(*llm.CompletionResult), args.Error(1) -} diff --git a/llm/mock/mock_test.go b/llm/mock/mock_test.go deleted file mode 100644 index 3523c7f..0000000 --- a/llm/mock/mock_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package mock - -import ( - "testing" - - "github.com/verygoodsoftwarenotvirus/platform/v5/llm" - - "github.com/stretchr/testify/require" -) - -func TestProvider_Completion(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := &Provider{} - m.On("Completion", t.Context(), llm.CompletionParams{ - Model: "test", - Messages: []llm.Message{{Role: "user", Content: "hi"}}, - }).Return(&llm.CompletionResult{Content: "mocked"}, nil) - - ctx := t.Context() - result, err := m.Completion(ctx, llm.CompletionParams{ - Model: "test", - Messages: []llm.Message{{Role: "user", Content: "hi"}}, - }) - - require.NoError(t, err) - require.Equal(t, "mocked", result.Content) - m.AssertExpectations(t) - }) -} diff --git a/llm/mock/provider_mock.go b/llm/mock/provider_mock.go new file mode 100644 index 0000000..81b2dab --- /dev/null +++ b/llm/mock/provider_mock.go @@ -0,0 +1,83 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package mock + +import ( + "context" + "sync" + + "github.com/verygoodsoftwarenotvirus/platform/v5/llm" +) + +// Ensure, that ProviderMock does implement llm.Provider. +// If this is not the case, regenerate this file with moq. +var _ llm.Provider = &ProviderMock{} + +// ProviderMock is a mock implementation of llm.Provider. +// +// func TestSomethingThatUsesProvider(t *testing.T) { +// +// // make and configure a mocked llm.Provider +// mockedProvider := &ProviderMock{ +// CompletionFunc: func(ctx context.Context, params llm.CompletionParams) (*llm.CompletionResult, error) { +// panic("mock out the Completion method") +// }, +// } +// +// // use mockedProvider in code that requires llm.Provider +// // and then make assertions. +// +// } +type ProviderMock struct { + // CompletionFunc mocks the Completion method. + CompletionFunc func(ctx context.Context, params llm.CompletionParams) (*llm.CompletionResult, error) + + // calls tracks calls to the methods. + calls struct { + // Completion holds details about calls to the Completion method. + Completion []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Params is the params argument value. + Params llm.CompletionParams + } + } + lockCompletion sync.RWMutex +} + +// Completion calls CompletionFunc. +func (mock *ProviderMock) Completion(ctx context.Context, params llm.CompletionParams) (*llm.CompletionResult, error) { + if mock.CompletionFunc == nil { + panic("ProviderMock.CompletionFunc: method is nil but Provider.Completion was just called") + } + callInfo := struct { + Ctx context.Context + Params llm.CompletionParams + }{ + Ctx: ctx, + Params: params, + } + mock.lockCompletion.Lock() + mock.calls.Completion = append(mock.calls.Completion, callInfo) + mock.lockCompletion.Unlock() + return mock.CompletionFunc(ctx, params) +} + +// CompletionCalls gets all the calls that were made to Completion. +// Check the length with: +// +// len(mockedProvider.CompletionCalls()) +func (mock *ProviderMock) CompletionCalls() []struct { + Ctx context.Context + Params llm.CompletionParams +} { + var calls []struct { + Ctx context.Context + Params llm.CompletionParams + } + mock.lockCompletion.RLock() + calls = mock.calls.Completion + mock.lockCompletion.RUnlock() + return calls +} diff --git a/llm/openai/openai_test.go b/llm/openai/openai_test.go index bf802e9..1050768 100644 --- a/llm/openai/openai_test.go +++ b/llm/openai/openai_test.go @@ -12,7 +12,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" - "github.com/stretchr/testify/mock" + "github.com/shoenig/test" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" ) @@ -62,28 +62,41 @@ func TestNewProvider(T *testing.T) { T.Run("with error creating request counter", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_requests", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, name+"_requests", counterName) + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + }, + } provider, err := NewProvider(&Config{APIKey: "test-key"}, nil, nil, mp) require.Error(t, err) require.Nil(t, provider) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("with error creating error counter", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_requests", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_errors", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + switch counterName { + case name + "_requests": + return metrics.Int64CounterForTest(t, "x"), nil + case name + "_errors": + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + } + t.Fatalf("unexpected NewInt64Counter call: %q", counterName) + return nil, nil + }, + } provider, err := NewProvider(&Config{APIKey: "test-key"}, nil, nil, mp) require.Error(t, err) require.Nil(t, provider) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) T.Run("with error creating latency histogram", func(t *testing.T) { @@ -93,16 +106,22 @@ func TestNewProvider(T *testing.T) { h, histErr := noopMP.NewFloat64Histogram("test") require.NoError(t, histErr) - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_requests", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_errors", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewFloat64Histogram", name+"_latency_ms", []metric.Float64HistogramOption(nil)).Return(h, errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metrics.Int64CounterForTest(t, "x"), nil + }, + NewFloat64HistogramFunc: func(histName string, _ ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { + test.EqOp(t, name+"_latency_ms", histName) + return h, errors.New("arbitrary") + }, + } provider, err := NewProvider(&Config{APIKey: "test-key"}, nil, nil, mp) require.Error(t, err) require.Nil(t, provider) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) + test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) }) } diff --git a/messagequeue/kafka/consumer_test.go b/messagequeue/kafka/consumer_test.go index fc6b97c..4294182 100644 --- a/messagequeue/kafka/consumer_test.go +++ b/messagequeue/kafka/consumer_test.go @@ -9,30 +9,37 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/verygoodsoftwarenotvirus/platform/v5/testutils" "github.com/segmentio/kafka-go" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" ) type mockKafkaReader struct { - mock.Mock + fetchMessageFunc func(ctx context.Context) (kafka.Message, error) + commitMessagesFunc func(ctx context.Context, msgs ...kafka.Message) error + closeFunc func() error + fetchCalls int + commitCalls int } func (m *mockKafkaReader) FetchMessage(ctx context.Context) (kafka.Message, error) { - args := m.Called(ctx) - return args.Get(0).(kafka.Message), args.Error(1) + m.fetchCalls++ + return m.fetchMessageFunc(ctx) } func (m *mockKafkaReader) CommitMessages(ctx context.Context, msgs ...kafka.Message) error { - return m.Called(ctx, msgs).Error(0) + m.commitCalls++ + return m.commitMessagesFunc(ctx, msgs...) } func (m *mockKafkaReader) Close() error { - return m.Called().Error(0) + if m.closeFunc == nil { + return nil + } + return m.closeFunc() } func Test_kafkaConsumer_Consume(T *testing.T) { @@ -43,8 +50,11 @@ func Test_kafkaConsumer_Consume(T *testing.T) { ctx, cancel := context.WithCancel(t.Context()) - reader := &mockKafkaReader{} - reader.On("FetchMessage", testutils.ContextMatcher).Return(kafka.Message{}, context.Canceled).Maybe() + reader := &mockKafkaReader{ + fetchMessageFunc: func(_ context.Context) (kafka.Message, error) { + return kafka.Message{}, context.Canceled + }, + } c := &kafkaConsumer{ reader: reader, @@ -68,7 +78,11 @@ func Test_kafkaConsumer_Consume(T *testing.T) { ctx := t.Context() - reader := &mockKafkaReader{} + reader := &mockKafkaReader{ + fetchMessageFunc: func(_ context.Context) (kafka.Message, error) { + return kafka.Message{}, context.Canceled + }, + } c := &kafkaConsumer{ reader: reader, @@ -92,8 +106,11 @@ func Test_kafkaConsumer_Consume(T *testing.T) { ctx, cancel := context.WithCancel(t.Context()) - reader := &mockKafkaReader{} - reader.On("FetchMessage", testutils.ContextMatcher).Return(kafka.Message{}, context.Canceled).Maybe() + reader := &mockKafkaReader{ + fetchMessageFunc: func(_ context.Context) (kafka.Message, error) { + return kafka.Message{}, context.Canceled + }, + } c := &kafkaConsumer{ reader: reader, @@ -119,13 +136,15 @@ func Test_kafkaConsumer_Consume(T *testing.T) { fetchErr := errors.New("fetch failed") callCount := 0 - reader := &mockKafkaReader{} - reader.On("FetchMessage", testutils.ContextMatcher).Return(kafka.Message{}, fetchErr).Run(func(args mock.Arguments) { - callCount++ - if callCount >= 2 { - cancel() - } - }) + reader := &mockKafkaReader{ + fetchMessageFunc: func(_ context.Context) (kafka.Message, error) { + callCount++ + if callCount >= 2 { + cancel() + } + return kafka.Message{}, fetchErr + }, + } c := &kafkaConsumer{ reader: reader, @@ -157,10 +176,12 @@ func Test_kafkaConsumer_Consume(T *testing.T) { fetchErr := errors.New("fetch failed") - reader := &mockKafkaReader{} - reader.On("FetchMessage", testutils.ContextMatcher).Return(kafka.Message{}, fetchErr).Run(func(args mock.Arguments) { - cancel() - }) + reader := &mockKafkaReader{ + fetchMessageFunc: func(_ context.Context) (kafka.Message, error) { + cancel() + return kafka.Message{}, fetchErr + }, + } c := &kafkaConsumer{ reader: reader, @@ -184,10 +205,21 @@ func Test_kafkaConsumer_Consume(T *testing.T) { msg := kafka.Message{Value: []byte("test-message")} - reader := &mockKafkaReader{} - reader.On("FetchMessage", testutils.ContextMatcher).Return(msg, nil).Once() - reader.On("CommitMessages", testutils.ContextMatcher, []kafka.Message{msg}).Return(nil).Once() - reader.On("FetchMessage", testutils.ContextMatcher).Return(kafka.Message{}, context.Canceled).Maybe() + var fetchCount int + reader := &mockKafkaReader{ + fetchMessageFunc: func(_ context.Context) (kafka.Message, error) { + fetchCount++ + if fetchCount == 1 { + return msg, nil + } + return kafka.Message{}, context.Canceled + }, + commitMessagesFunc: func(_ context.Context, msgs ...kafka.Message) error { + require.Len(t, msgs, 1) + assert.Equal(t, msg, msgs[0]) + return nil + }, + } handlerCalled := false c := &kafkaConsumer{ @@ -208,8 +240,7 @@ func Test_kafkaConsumer_Consume(T *testing.T) { c.Consume(ctx, stopChan, errs) assert.True(t, handlerCalled) - - mock.AssertExpectationsForObjects(t, reader) + assert.Equal(t, 1, reader.commitCalls) }) T.Run("with handler error", func(t *testing.T) { @@ -220,9 +251,16 @@ func Test_kafkaConsumer_Consume(T *testing.T) { msg := kafka.Message{Value: []byte("test-message")} handlerErr := errors.New("handler failed") - reader := &mockKafkaReader{} - reader.On("FetchMessage", testutils.ContextMatcher).Return(msg, nil).Once() - reader.On("FetchMessage", testutils.ContextMatcher).Return(kafka.Message{}, context.Canceled).Maybe() + var fetchCount int + reader := &mockKafkaReader{ + fetchMessageFunc: func(_ context.Context) (kafka.Message, error) { + fetchCount++ + if fetchCount == 1 { + return msg, nil + } + return kafka.Message{}, context.Canceled + }, + } c := &kafkaConsumer{ reader: reader, @@ -252,9 +290,16 @@ func Test_kafkaConsumer_Consume(T *testing.T) { msg := kafka.Message{Value: []byte("test-message")} - reader := &mockKafkaReader{} - reader.On("FetchMessage", testutils.ContextMatcher).Return(msg, nil).Once() - reader.On("FetchMessage", testutils.ContextMatcher).Return(kafka.Message{}, context.Canceled).Maybe() + var fetchCount int + reader := &mockKafkaReader{ + fetchMessageFunc: func(_ context.Context) (kafka.Message, error) { + fetchCount++ + if fetchCount == 1 { + return msg, nil + } + return kafka.Message{}, context.Canceled + }, + } c := &kafkaConsumer{ reader: reader, @@ -279,10 +324,19 @@ func Test_kafkaConsumer_Consume(T *testing.T) { msg := kafka.Message{Value: []byte("test-message")} - reader := &mockKafkaReader{} - reader.On("FetchMessage", testutils.ContextMatcher).Return(msg, nil).Once() - reader.On("CommitMessages", testutils.ContextMatcher, []kafka.Message{msg}).Return(errors.New("commit failed")).Once() - reader.On("FetchMessage", testutils.ContextMatcher).Return(kafka.Message{}, context.Canceled).Maybe() + var fetchCount int + reader := &mockKafkaReader{ + fetchMessageFunc: func(_ context.Context) (kafka.Message, error) { + fetchCount++ + if fetchCount == 1 { + return msg, nil + } + return kafka.Message{}, context.Canceled + }, + commitMessagesFunc: func(_ context.Context, _ ...kafka.Message) error { + return errors.New("commit failed") + }, + } c := &kafkaConsumer{ reader: reader, @@ -299,8 +353,7 @@ func Test_kafkaConsumer_Consume(T *testing.T) { errs := make(chan error, 10) c.Consume(ctx, stopChan, errs) - - mock.AssertExpectationsForObjects(t, reader) + assert.Equal(t, 1, reader.commitCalls) }) } @@ -382,8 +435,11 @@ func Test_consumerProvider_ProvideConsumer(T *testing.T) { ctx := t.Context() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", mock.Anything, []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("counter error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metrics.Int64CounterForTest(t, "x"), errors.New("counter error") + }, + } cfg := Config{ Brokers: []string{"localhost:9092"}, @@ -403,6 +459,8 @@ func Test_consumerProvider_ProvideConsumer(T *testing.T) { actual, err := provider.ProvideConsumer(ctx, t.Name(), hf) assert.Error(t, err) assert.Nil(t, actual) + + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("with cache hit", func(t *testing.T) { diff --git a/messagequeue/kafka/publisher_test.go b/messagequeue/kafka/publisher_test.go index 47f8e33..0af4214 100644 --- a/messagequeue/kafka/publisher_test.go +++ b/messagequeue/kafka/publisher_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "io" "testing" "github.com/verygoodsoftwarenotvirus/platform/v5/encoding" @@ -13,25 +14,35 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/verygoodsoftwarenotvirus/platform/v5/testutils" "github.com/segmentio/kafka-go" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" ) type mockKafkaWriter struct { - mock.Mock + writeMessagesFunc func(ctx context.Context, msgs ...kafka.Message) error + closeFunc func() error + writeCalls int + closeCalls int } func (m *mockKafkaWriter) WriteMessages(ctx context.Context, msgs ...kafka.Message) error { - return m.Called(ctx, msgs).Error(0) + m.writeCalls++ + if m.writeMessagesFunc == nil { + return nil + } + return m.writeMessagesFunc(ctx, msgs...) } func (m *mockKafkaWriter) Close() error { - return m.Called().Error(0) + m.closeCalls++ + if m.closeFunc == nil { + return nil + } + return m.closeFunc() } func buildTestPublisher(t *testing.T) (*kafkaPublisher, *mockKafkaWriter) { @@ -70,22 +81,22 @@ func Test_kafkaPublisher_Stop(T *testing.T) { t.Parallel() pub, writer := buildTestPublisher(t) - writer.On("Close").Return(nil) + writer.closeFunc = func() error { return nil } pub.Stop() - mock.AssertExpectationsForObjects(t, writer) + assert.Equal(t, 1, writer.closeCalls) }) T.Run("with close error", func(t *testing.T) { t.Parallel() pub, writer := buildTestPublisher(t) - writer.On("Close").Return(errors.New("close failed")) + writer.closeFunc = func() error { return errors.New("close failed") } pub.Stop() - mock.AssertExpectationsForObjects(t, writer) + assert.Equal(t, 1, writer.closeCalls) }) } @@ -104,12 +115,12 @@ func Test_kafkaPublisher_Publish(T *testing.T) { Name: t.Name(), } - writer.On("WriteMessages", testutils.ContextMatcher, mock.AnythingOfType("[]kafka.Message")).Return(nil) + writer.writeMessagesFunc = func(_ context.Context, _ ...kafka.Message) error { return nil } err := pub.Publish(ctx, inputData) assert.NoError(t, err) - mock.AssertExpectationsForObjects(t, writer) + assert.Equal(t, 1, writer.writeCalls) }) T.Run("with encoding error", func(t *testing.T) { @@ -140,12 +151,12 @@ func Test_kafkaPublisher_Publish(T *testing.T) { Name: t.Name(), } - writer.On("WriteMessages", testutils.ContextMatcher, mock.AnythingOfType("[]kafka.Message")).Return(errors.New("write failed")) + writer.writeMessagesFunc = func(_ context.Context, _ ...kafka.Message) error { return errors.New("write failed") } err := pub.Publish(ctx, inputData) assert.Error(t, err) - mock.AssertExpectationsForObjects(t, writer) + assert.Equal(t, 1, writer.writeCalls) }) T.Run("with mock encoder error", func(t *testing.T) { @@ -154,14 +165,17 @@ func Test_kafkaPublisher_Publish(T *testing.T) { ctx := t.Context() pub, _ := buildTestPublisher(t) - enc := &mockencoding.ClientEncoder{} - enc.On("Encode", testutils.ContextMatcher, mock.Anything, mock.Anything).Return(errors.New("encode failed")) + enc := &mockencoding.ClientEncoderMock{ + EncodeFunc: func(_ context.Context, _ io.Writer, _ any) error { + return errors.New("encode failed") + }, + } pub.encoder = enc err := pub.Publish(ctx, "something") assert.Error(t, err) - mock.AssertExpectationsForObjects(t, enc) + test.SliceLen(t, 1, enc.EncodeCalls()) }) } @@ -180,11 +194,11 @@ func Test_kafkaPublisher_PublishAsync(T *testing.T) { Name: t.Name(), } - writer.On("WriteMessages", testutils.ContextMatcher, mock.AnythingOfType("[]kafka.Message")).Return(nil) + writer.writeMessagesFunc = func(_ context.Context, _ ...kafka.Message) error { return nil } pub.PublishAsync(ctx, inputData) - mock.AssertExpectationsForObjects(t, writer) + assert.Equal(t, 1, writer.writeCalls) }) T.Run("with encoding error", func(t *testing.T) { @@ -214,11 +228,11 @@ func Test_kafkaPublisher_PublishAsync(T *testing.T) { Name: t.Name(), } - writer.On("WriteMessages", testutils.ContextMatcher, mock.AnythingOfType("[]kafka.Message")).Return(errors.New("write failed")) + writer.writeMessagesFunc = func(_ context.Context, _ ...kafka.Message) error { return errors.New("write failed") } pub.PublishAsync(ctx, inputData) - mock.AssertExpectationsForObjects(t, writer) + assert.Equal(t, 1, writer.writeCalls) }) } @@ -327,8 +341,11 @@ func Test_publisherProvider_ProvidePublisher(T *testing.T) { ctx := t.Context() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", mock.Anything, []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("counter error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metrics.Int64CounterForTest(t, "x"), errors.New("counter error") + }, + } cfg := Config{ Brokers: []string{"localhost:9092"}, @@ -346,6 +363,8 @@ func Test_publisherProvider_ProvidePublisher(T *testing.T) { actual, err := provider.ProvidePublisher(ctx, t.Name()) assert.Error(t, err) assert.Nil(t, actual) + + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("with error creating publish error counter", func(t *testing.T) { @@ -353,9 +372,13 @@ func Test_publisherProvider_ProvidePublisher(T *testing.T) { ctx := t.Context() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", mock.Anything, []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil).Once() - mp.On("NewInt64Counter", mock.Anything, []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("counter error")).Once() + mp := &mockmetrics.ProviderMock{} + mp.NewInt64CounterFunc = func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + if len(mp.NewInt64CounterCalls()) >= 2 { + return metrics.Int64CounterForTest(t, "x"), errors.New("counter error") + } + return metrics.Int64CounterForTest(t, "x"), nil + } cfg := Config{ Brokers: []string{"localhost:9092"}, @@ -373,6 +396,8 @@ func Test_publisherProvider_ProvidePublisher(T *testing.T) { actual, err := provider.ProvidePublisher(ctx, t.Name()) assert.Error(t, err) assert.Nil(t, actual) + + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) T.Run("with error creating latency histogram", func(t *testing.T) { @@ -380,9 +405,14 @@ func Test_publisherProvider_ProvidePublisher(T *testing.T) { ctx := t.Context() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", mock.Anything, []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewFloat64Histogram", mock.Anything, []metric.Float64HistogramOption(nil)).Return(&metrics.Float64HistogramImpl{}, errors.New("histogram error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metrics.Int64CounterForTest(t, "x"), nil + }, + NewFloat64HistogramFunc: func(_ string, _ ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { + return &metrics.Float64HistogramImpl{}, errors.New("histogram error") + }, + } cfg := Config{ Brokers: []string{"localhost:9092"}, @@ -400,6 +430,8 @@ func Test_publisherProvider_ProvidePublisher(T *testing.T) { actual, err := provider.ProvidePublisher(ctx, t.Name()) assert.Error(t, err) assert.Nil(t, actual) + + test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) }) } @@ -457,8 +489,9 @@ func Test_publisherProvider_Close(T *testing.T) { require.True(t, ok) // Replace cached publisher with one using a mock writer so Close doesn't hit real Kafka - mw := &mockKafkaWriter{} - mw.On("Close").Return(nil) + mw := &mockKafkaWriter{ + closeFunc: func() error { return nil }, + } mp := metrics.NewNoopMetricsProvider() publishedCounter, _ := mp.NewInt64Counter("test_published") @@ -476,7 +509,7 @@ func Test_publisherProvider_Close(T *testing.T) { provider.Close() - mock.AssertExpectationsForObjects(t, mw) + assert.Equal(t, 1, mw.closeCalls) }) T.Run("with empty cache", func(t *testing.T) { diff --git a/messagequeue/mock/consumer.go b/messagequeue/mock/consumer.go deleted file mode 100644 index 62d01a9..0000000 --- a/messagequeue/mock/consumer.go +++ /dev/null @@ -1,35 +0,0 @@ -package mockpublishers - -import ( - "context" - - "github.com/verygoodsoftwarenotvirus/platform/v5/messagequeue" - - "github.com/stretchr/testify/mock" -) - -var _ messagequeue.ConsumerProvider = (*ConsumerProvider)(nil) - -// ConsumerProvider is a mock consumers.ConsumerProvider. -type ConsumerProvider struct { - mock.Mock -} - -// ProvideConsumer implements the interface. -func (m *ConsumerProvider) ProvideConsumer(ctx context.Context, topic string, handlerFunc messagequeue.ConsumerFunc) (messagequeue.Consumer, error) { - args := m.Called(ctx, topic, handlerFunc) - - return args.Get(0).(messagequeue.Consumer), args.Error(1) -} - -var _ messagequeue.Consumer = (*Consumer)(nil) - -// Consumer is a mock consumers.Consumer. -type Consumer struct { - mock.Mock -} - -// Consume implements the interface. -func (m *Consumer) Consume(ctx context.Context, stopChan chan bool, errors chan error) { - m.Called(ctx, stopChan, errors) -} diff --git a/messagequeue/mock/doc.go b/messagequeue/mock/doc.go new file mode 100644 index 0000000..485ce39 --- /dev/null +++ b/messagequeue/mock/doc.go @@ -0,0 +1,10 @@ +/* +Package mockpublishers provides moq-generated mocks for the messagequeue +package's Publisher, PublisherProvider, Consumer, and ConsumerProvider +interfaces. +*/ +package mockpublishers + +// Regenerate the moq mocks via `go generate ./messagequeue/mock/`. + +//go:generate go tool github.com/matryer/moq -out messagequeue_mock.go -pkg mockpublishers -rm -fmt goimports .. Publisher:PublisherMock PublisherProvider:PublisherProviderMock Consumer:ConsumerMock ConsumerProvider:ConsumerProviderMock diff --git a/messagequeue/mock/messagequeue_mock.go b/messagequeue/mock/messagequeue_mock.go new file mode 100644 index 0000000..de5367b --- /dev/null +++ b/messagequeue/mock/messagequeue_mock.go @@ -0,0 +1,479 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package mockpublishers + +import ( + "context" + "sync" + + "github.com/verygoodsoftwarenotvirus/platform/v5/messagequeue" +) + +// Ensure, that PublisherMock does implement messagequeue.Publisher. +// If this is not the case, regenerate this file with moq. +var _ messagequeue.Publisher = &PublisherMock{} + +// PublisherMock is a mock implementation of messagequeue.Publisher. +// +// func TestSomethingThatUsesPublisher(t *testing.T) { +// +// // make and configure a mocked messagequeue.Publisher +// mockedPublisher := &PublisherMock{ +// PublishFunc: func(ctx context.Context, data any) error { +// panic("mock out the Publish method") +// }, +// PublishAsyncFunc: func(ctx context.Context, data any) { +// panic("mock out the PublishAsync method") +// }, +// StopFunc: func() { +// panic("mock out the Stop method") +// }, +// } +// +// // use mockedPublisher in code that requires messagequeue.Publisher +// // and then make assertions. +// +// } +type PublisherMock struct { + // PublishFunc mocks the Publish method. + PublishFunc func(ctx context.Context, data any) error + + // PublishAsyncFunc mocks the PublishAsync method. + PublishAsyncFunc func(ctx context.Context, data any) + + // StopFunc mocks the Stop method. + StopFunc func() + + // calls tracks calls to the methods. + calls struct { + // Publish holds details about calls to the Publish method. + Publish []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Data is the data argument value. + Data any + } + // PublishAsync holds details about calls to the PublishAsync method. + PublishAsync []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Data is the data argument value. + Data any + } + // Stop holds details about calls to the Stop method. + Stop []struct { + } + } + lockPublish sync.RWMutex + lockPublishAsync sync.RWMutex + lockStop sync.RWMutex +} + +// Publish calls PublishFunc. +func (mock *PublisherMock) Publish(ctx context.Context, data any) error { + if mock.PublishFunc == nil { + panic("PublisherMock.PublishFunc: method is nil but Publisher.Publish was just called") + } + callInfo := struct { + Ctx context.Context + Data any + }{ + Ctx: ctx, + Data: data, + } + mock.lockPublish.Lock() + mock.calls.Publish = append(mock.calls.Publish, callInfo) + mock.lockPublish.Unlock() + return mock.PublishFunc(ctx, data) +} + +// PublishCalls gets all the calls that were made to Publish. +// Check the length with: +// +// len(mockedPublisher.PublishCalls()) +func (mock *PublisherMock) PublishCalls() []struct { + Ctx context.Context + Data any +} { + var calls []struct { + Ctx context.Context + Data any + } + mock.lockPublish.RLock() + calls = mock.calls.Publish + mock.lockPublish.RUnlock() + return calls +} + +// PublishAsync calls PublishAsyncFunc. +func (mock *PublisherMock) PublishAsync(ctx context.Context, data any) { + if mock.PublishAsyncFunc == nil { + panic("PublisherMock.PublishAsyncFunc: method is nil but Publisher.PublishAsync was just called") + } + callInfo := struct { + Ctx context.Context + Data any + }{ + Ctx: ctx, + Data: data, + } + mock.lockPublishAsync.Lock() + mock.calls.PublishAsync = append(mock.calls.PublishAsync, callInfo) + mock.lockPublishAsync.Unlock() + mock.PublishAsyncFunc(ctx, data) +} + +// PublishAsyncCalls gets all the calls that were made to PublishAsync. +// Check the length with: +// +// len(mockedPublisher.PublishAsyncCalls()) +func (mock *PublisherMock) PublishAsyncCalls() []struct { + Ctx context.Context + Data any +} { + var calls []struct { + Ctx context.Context + Data any + } + mock.lockPublishAsync.RLock() + calls = mock.calls.PublishAsync + mock.lockPublishAsync.RUnlock() + return calls +} + +// Stop calls StopFunc. +func (mock *PublisherMock) Stop() { + if mock.StopFunc == nil { + panic("PublisherMock.StopFunc: method is nil but Publisher.Stop was just called") + } + callInfo := struct { + }{} + mock.lockStop.Lock() + mock.calls.Stop = append(mock.calls.Stop, callInfo) + mock.lockStop.Unlock() + mock.StopFunc() +} + +// StopCalls gets all the calls that were made to Stop. +// Check the length with: +// +// len(mockedPublisher.StopCalls()) +func (mock *PublisherMock) StopCalls() []struct { +} { + var calls []struct { + } + mock.lockStop.RLock() + calls = mock.calls.Stop + mock.lockStop.RUnlock() + return calls +} + +// Ensure, that PublisherProviderMock does implement messagequeue.PublisherProvider. +// If this is not the case, regenerate this file with moq. +var _ messagequeue.PublisherProvider = &PublisherProviderMock{} + +// PublisherProviderMock is a mock implementation of messagequeue.PublisherProvider. +// +// func TestSomethingThatUsesPublisherProvider(t *testing.T) { +// +// // make and configure a mocked messagequeue.PublisherProvider +// mockedPublisherProvider := &PublisherProviderMock{ +// CloseFunc: func() { +// panic("mock out the Close method") +// }, +// PingFunc: func(ctx context.Context) error { +// panic("mock out the Ping method") +// }, +// ProvidePublisherFunc: func(ctx context.Context, topic string) (messagequeue.Publisher, error) { +// panic("mock out the ProvidePublisher method") +// }, +// } +// +// // use mockedPublisherProvider in code that requires messagequeue.PublisherProvider +// // and then make assertions. +// +// } +type PublisherProviderMock struct { + // CloseFunc mocks the Close method. + CloseFunc func() + + // PingFunc mocks the Ping method. + PingFunc func(ctx context.Context) error + + // ProvidePublisherFunc mocks the ProvidePublisher method. + ProvidePublisherFunc func(ctx context.Context, topic string) (messagequeue.Publisher, error) + + // calls tracks calls to the methods. + calls struct { + // Close holds details about calls to the Close method. + Close []struct { + } + // Ping holds details about calls to the Ping method. + Ping []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + // ProvidePublisher holds details about calls to the ProvidePublisher method. + ProvidePublisher []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Topic is the topic argument value. + Topic string + } + } + lockClose sync.RWMutex + lockPing sync.RWMutex + lockProvidePublisher sync.RWMutex +} + +// Close calls CloseFunc. +func (mock *PublisherProviderMock) Close() { + if mock.CloseFunc == nil { + panic("PublisherProviderMock.CloseFunc: method is nil but PublisherProvider.Close was just called") + } + callInfo := struct { + }{} + mock.lockClose.Lock() + mock.calls.Close = append(mock.calls.Close, callInfo) + mock.lockClose.Unlock() + mock.CloseFunc() +} + +// CloseCalls gets all the calls that were made to Close. +// Check the length with: +// +// len(mockedPublisherProvider.CloseCalls()) +func (mock *PublisherProviderMock) CloseCalls() []struct { +} { + var calls []struct { + } + mock.lockClose.RLock() + calls = mock.calls.Close + mock.lockClose.RUnlock() + return calls +} + +// Ping calls PingFunc. +func (mock *PublisherProviderMock) Ping(ctx context.Context) error { + if mock.PingFunc == nil { + panic("PublisherProviderMock.PingFunc: method is nil but PublisherProvider.Ping was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockPing.Lock() + mock.calls.Ping = append(mock.calls.Ping, callInfo) + mock.lockPing.Unlock() + return mock.PingFunc(ctx) +} + +// PingCalls gets all the calls that were made to Ping. +// Check the length with: +// +// len(mockedPublisherProvider.PingCalls()) +func (mock *PublisherProviderMock) PingCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockPing.RLock() + calls = mock.calls.Ping + mock.lockPing.RUnlock() + return calls +} + +// ProvidePublisher calls ProvidePublisherFunc. +func (mock *PublisherProviderMock) ProvidePublisher(ctx context.Context, topic string) (messagequeue.Publisher, error) { + if mock.ProvidePublisherFunc == nil { + panic("PublisherProviderMock.ProvidePublisherFunc: method is nil but PublisherProvider.ProvidePublisher was just called") + } + callInfo := struct { + Ctx context.Context + Topic string + }{ + Ctx: ctx, + Topic: topic, + } + mock.lockProvidePublisher.Lock() + mock.calls.ProvidePublisher = append(mock.calls.ProvidePublisher, callInfo) + mock.lockProvidePublisher.Unlock() + return mock.ProvidePublisherFunc(ctx, topic) +} + +// ProvidePublisherCalls gets all the calls that were made to ProvidePublisher. +// Check the length with: +// +// len(mockedPublisherProvider.ProvidePublisherCalls()) +func (mock *PublisherProviderMock) ProvidePublisherCalls() []struct { + Ctx context.Context + Topic string +} { + var calls []struct { + Ctx context.Context + Topic string + } + mock.lockProvidePublisher.RLock() + calls = mock.calls.ProvidePublisher + mock.lockProvidePublisher.RUnlock() + return calls +} + +// Ensure, that ConsumerMock does implement messagequeue.Consumer. +// If this is not the case, regenerate this file with moq. +var _ messagequeue.Consumer = &ConsumerMock{} + +// ConsumerMock is a mock implementation of messagequeue.Consumer. +// +// func TestSomethingThatUsesConsumer(t *testing.T) { +// +// // make and configure a mocked messagequeue.Consumer +// mockedConsumer := &ConsumerMock{ +// ConsumeFunc: func(ctx context.Context, stopChan chan bool, errors chan error) { +// panic("mock out the Consume method") +// }, +// } +// +// // use mockedConsumer in code that requires messagequeue.Consumer +// // and then make assertions. +// +// } +type ConsumerMock struct { + // ConsumeFunc mocks the Consume method. + ConsumeFunc func(ctx context.Context, stopChan chan bool, errors chan error) + + // calls tracks calls to the methods. + calls struct { + // Consume holds details about calls to the Consume method. + Consume []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // StopChan is the stopChan argument value. + StopChan chan bool + // Errors is the errors argument value. + Errors chan error + } + } + lockConsume sync.RWMutex +} + +// Consume calls ConsumeFunc. +func (mock *ConsumerMock) Consume(ctx context.Context, stopChan chan bool, errors chan error) { + if mock.ConsumeFunc == nil { + panic("ConsumerMock.ConsumeFunc: method is nil but Consumer.Consume was just called") + } + callInfo := struct { + Ctx context.Context + StopChan chan bool + Errors chan error + }{ + Ctx: ctx, + StopChan: stopChan, + Errors: errors, + } + mock.lockConsume.Lock() + mock.calls.Consume = append(mock.calls.Consume, callInfo) + mock.lockConsume.Unlock() + mock.ConsumeFunc(ctx, stopChan, errors) +} + +// ConsumeCalls gets all the calls that were made to Consume. +// Check the length with: +// +// len(mockedConsumer.ConsumeCalls()) +func (mock *ConsumerMock) ConsumeCalls() []struct { + Ctx context.Context + StopChan chan bool + Errors chan error +} { + var calls []struct { + Ctx context.Context + StopChan chan bool + Errors chan error + } + mock.lockConsume.RLock() + calls = mock.calls.Consume + mock.lockConsume.RUnlock() + return calls +} + +// Ensure, that ConsumerProviderMock does implement messagequeue.ConsumerProvider. +// If this is not the case, regenerate this file with moq. +var _ messagequeue.ConsumerProvider = &ConsumerProviderMock{} + +// ConsumerProviderMock is a mock implementation of messagequeue.ConsumerProvider. +// +// func TestSomethingThatUsesConsumerProvider(t *testing.T) { +// +// // make and configure a mocked messagequeue.ConsumerProvider +// mockedConsumerProvider := &ConsumerProviderMock{ +// ProvideConsumerFunc: func(ctx context.Context, topic string, handlerFunc messagequeue.ConsumerFunc) (messagequeue.Consumer, error) { +// panic("mock out the ProvideConsumer method") +// }, +// } +// +// // use mockedConsumerProvider in code that requires messagequeue.ConsumerProvider +// // and then make assertions. +// +// } +type ConsumerProviderMock struct { + // ProvideConsumerFunc mocks the ProvideConsumer method. + ProvideConsumerFunc func(ctx context.Context, topic string, handlerFunc messagequeue.ConsumerFunc) (messagequeue.Consumer, error) + + // calls tracks calls to the methods. + calls struct { + // ProvideConsumer holds details about calls to the ProvideConsumer method. + ProvideConsumer []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Topic is the topic argument value. + Topic string + // HandlerFunc is the handlerFunc argument value. + HandlerFunc messagequeue.ConsumerFunc + } + } + lockProvideConsumer sync.RWMutex +} + +// ProvideConsumer calls ProvideConsumerFunc. +func (mock *ConsumerProviderMock) ProvideConsumer(ctx context.Context, topic string, handlerFunc messagequeue.ConsumerFunc) (messagequeue.Consumer, error) { + if mock.ProvideConsumerFunc == nil { + panic("ConsumerProviderMock.ProvideConsumerFunc: method is nil but ConsumerProvider.ProvideConsumer was just called") + } + callInfo := struct { + Ctx context.Context + Topic string + HandlerFunc messagequeue.ConsumerFunc + }{ + Ctx: ctx, + Topic: topic, + HandlerFunc: handlerFunc, + } + mock.lockProvideConsumer.Lock() + mock.calls.ProvideConsumer = append(mock.calls.ProvideConsumer, callInfo) + mock.lockProvideConsumer.Unlock() + return mock.ProvideConsumerFunc(ctx, topic, handlerFunc) +} + +// ProvideConsumerCalls gets all the calls that were made to ProvideConsumer. +// Check the length with: +// +// len(mockedConsumerProvider.ProvideConsumerCalls()) +func (mock *ConsumerProviderMock) ProvideConsumerCalls() []struct { + Ctx context.Context + Topic string + HandlerFunc messagequeue.ConsumerFunc +} { + var calls []struct { + Ctx context.Context + Topic string + HandlerFunc messagequeue.ConsumerFunc + } + mock.lockProvideConsumer.RLock() + calls = mock.calls.ProvideConsumer + mock.lockProvideConsumer.RUnlock() + return calls +} diff --git a/messagequeue/mock/mock.go b/messagequeue/mock/mock.go deleted file mode 100644 index 3b29c65..0000000 --- a/messagequeue/mock/mock.go +++ /dev/null @@ -1,53 +0,0 @@ -package mockpublishers - -import ( - "context" - - "github.com/verygoodsoftwarenotvirus/platform/v5/messagequeue" - - "github.com/stretchr/testify/mock" -) - -var _ messagequeue.Publisher = (*Publisher)(nil) - -// Publisher implements our interface. -type Publisher struct { - mock.Mock -} - -// Publish implements our interface. -func (m *Publisher) Publish(ctx context.Context, data any) error { - return m.Called(ctx, data).Error(0) -} - -// PublishAsync implements our interface. -func (m *Publisher) PublishAsync(ctx context.Context, data any) { - m.Called(ctx, data) -} - -// Stop implements our interface. -func (m *Publisher) Stop() { - m.Called() -} - -// PublisherProvider implements our interface. -type PublisherProvider struct { - mock.Mock -} - -// ProvidePublisher implements our interface. -func (m *PublisherProvider) ProvidePublisher(ctx context.Context, topic string) (messagequeue.Publisher, error) { - args := m.Called(topic) - - return args.Get(0).(messagequeue.Publisher), args.Error(1) -} - -// Close implements our interface. -func (m *PublisherProvider) Close() { - m.Called() -} - -// Ping implements our interface. -func (m *PublisherProvider) Ping(ctx context.Context) error { - return m.Called(ctx).Error(0) -} diff --git a/messagequeue/pubsub/consumer_test.go b/messagequeue/pubsub/consumer_test.go index 3b40af8..81357db 100644 --- a/messagequeue/pubsub/consumer_test.go +++ b/messagequeue/pubsub/consumer_test.go @@ -14,15 +14,16 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/messagequeue" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" + mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/verygoodsoftwarenotvirus/platform/v5/random" "cloud.google.com/go/pubsub/v2" "cloud.google.com/go/pubsub/v2/apiv1/pubsubpb" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" tcpubsub "github.com/testcontainers/testcontainers-go/modules/gcloud/pubsub" + "go.opentelemetry.io/otel/metric" metricnoop "go.opentelemetry.io/otel/metric/noop" "google.golang.org/api/option" "google.golang.org/grpc" @@ -132,14 +133,15 @@ func TestBuildPubSubConsumer(T *testing.T) { T.Run("panics when NewInt64Counter fails", func(t *testing.T) { t.Parallel() - mp := &metrics.MockProvider{} - mp.On("NewInt64Counter", "t_consumed", mock.Anything).Return(metricnoop.Int64Counter{}, errors.New("forced error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(string, ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metricnoop.Int64Counter{}, errors.New("forced error") + }, + } assert.Panics(t, func() { buildPubSubConsumer(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, nil, "t", nil) }) - - mock.AssertExpectationsForObjects(t, mp) }) } diff --git a/messagequeue/pubsub/publisher_test.go b/messagequeue/pubsub/publisher_test.go index fcae43e..1331421 100644 --- a/messagequeue/pubsub/publisher_test.go +++ b/messagequeue/pubsub/publisher_test.go @@ -7,11 +7,12 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/messagequeue" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" + mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric" metricnoop "go.opentelemetry.io/otel/metric/noop" ) @@ -28,43 +29,57 @@ func TestBuildPubSubPublisher(T *testing.T) { T.Run("panics when first NewInt64Counter fails", func(t *testing.T) { t.Parallel() - mp := &metrics.MockProvider{} - mp.On("NewInt64Counter", "t_published", mock.Anything).Return(metricnoop.Int64Counter{}, errors.New("forced error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(name string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + if name == "t_published" { + return metricnoop.Int64Counter{}, errors.New("forced error") + } + t.Fatalf("unexpected NewInt64Counter call: %q", name) + return nil, nil + }, + } assert.Panics(t, func() { buildPubSubPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), mp, "t") }) - - mock.AssertExpectationsForObjects(t, mp) }) T.Run("panics when second NewInt64Counter fails", func(t *testing.T) { t.Parallel() - mp := &metrics.MockProvider{} - mp.On("NewInt64Counter", "t_published", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "t_publish_errors", mock.Anything).Return(metricnoop.Int64Counter{}, errors.New("forced error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(name string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + switch name { + case "t_published": + return metricnoop.Int64Counter{}, nil + case "t_publish_errors": + return metricnoop.Int64Counter{}, errors.New("forced error") + } + t.Fatalf("unexpected NewInt64Counter call: %q", name) + return nil, nil + }, + } assert.Panics(t, func() { buildPubSubPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), mp, "t") }) - - mock.AssertExpectationsForObjects(t, mp) }) T.Run("panics when NewFloat64Histogram fails", func(t *testing.T) { t.Parallel() - mp := &metrics.MockProvider{} - mp.On("NewInt64Counter", "t_published", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "t_publish_errors", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewFloat64Histogram", "t_publish_latency_ms", mock.Anything).Return(metricnoop.Float64Histogram{}, errors.New("forced error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(string, ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metricnoop.Int64Counter{}, nil + }, + NewFloat64HistogramFunc: func(string, ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { + return metricnoop.Float64Histogram{}, errors.New("forced error") + }, + } assert.Panics(t, func() { buildPubSubPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), mp, "t") }) - - mock.AssertExpectationsForObjects(t, mp) }) } diff --git a/messagequeue/redis/consumer_test.go b/messagequeue/redis/consumer_test.go index 1a3470f..c25567b 100644 --- a/messagequeue/redis/consumer_test.go +++ b/messagequeue/redis/consumer_test.go @@ -9,11 +9,12 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/messagequeue" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" + mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric" metricnoop "go.opentelemetry.io/otel/metric/noop" ) @@ -174,13 +175,14 @@ func Test_provideRedisConsumer(T *testing.T) { T.Run("panics when NewInt64Counter fails", func(t *testing.T) { t.Parallel() - mp := &metrics.MockProvider{} - mp.On("NewInt64Counter", "t_consumed", mock.Anything).Return(metricnoop.Int64Counter{}, errors.New("forced error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(string, ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metricnoop.Int64Counter{}, errors.New("forced error") + }, + } assert.Panics(t, func() { provideRedisConsumer(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, nil, "t", nil) }) - - mock.AssertExpectationsForObjects(t, mp) }) } diff --git a/messagequeue/redis/publisher_test.go b/messagequeue/redis/publisher_test.go index a4e25f2..a02443d 100644 --- a/messagequeue/redis/publisher_test.go +++ b/messagequeue/redis/publisher_test.go @@ -10,33 +10,40 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/messagequeue" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" + mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/verygoodsoftwarenotvirus/platform/v5/testutils" "github.com/go-redis/redis/v8" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric" metricnoop "go.opentelemetry.io/otel/metric/noop" ) type mockMessagePublisher struct { - mock.Mock + publishFunc func(ctx context.Context, channel string, message any) *redis.IntCmd + closeFunc func() error + pingFunc func(ctx context.Context) *redis.StatusCmd + publishArgs []publishCall +} + +type publishCall struct { + ctx context.Context + message any + channel string } -// Publish implements the interface. func (m *mockMessagePublisher) Publish(ctx context.Context, channel string, message any) *redis.IntCmd { - return m.Called(ctx, channel, message).Get(0).(*redis.IntCmd) + m.publishArgs = append(m.publishArgs, publishCall{ctx: ctx, channel: channel, message: message}) + return m.publishFunc(ctx, channel, message) } -// Close implements the interface. func (m *mockMessagePublisher) Close() error { - return m.Called().Error(0) + return m.closeFunc() } -// Ping implements the interface. func (m *mockMessagePublisher) Ping(ctx context.Context) *redis.StatusCmd { - return m.Called(ctx).Get(0).(*redis.StatusCmd) + return m.pingFunc(ctx) } func buildRedisBackedPublisher(t *testing.T, cfg *Config, topic string) messagequeue.Publisher { @@ -84,20 +91,18 @@ func Test_redisPublisher_Publish(T *testing.T) { Name: t.Name(), } - mmp := &mockMessagePublisher{} - mmp.On( - "Publish", - testutils.ContextMatcher, - actual.topic, - fmt.Appendf(nil, `{"name":%q}%s`, t.Name(), string(byte(10))), - ).Return(&redis.IntCmd{}) + mmp := &mockMessagePublisher{ + publishFunc: func(_ context.Context, _ string, _ any) *redis.IntCmd { return &redis.IntCmd{} }, + } actual.publisher = mmp err = actual.Publish(ctx, inputData) assert.NoError(t, err) - mock.AssertExpectationsForObjects(t, mmp) + require.Len(t, mmp.publishArgs, 1) + assert.Equal(t, actual.topic, mmp.publishArgs[0].channel) + assert.Equal(t, fmt.Appendf(nil, `{"name":%q}%s`, t.Name(), string(byte(10))), mmp.publishArgs[0].message) }) T.Run("with error encoding value", func(t *testing.T) { @@ -158,19 +163,17 @@ func Test_redisPublisher_PublishAsync(T *testing.T) { Name: t.Name(), } - mmp := &mockMessagePublisher{} - mmp.On( - "Publish", - testutils.ContextMatcher, - actual.topic, - fmt.Appendf(nil, `{"name":%q}%s`, t.Name(), string(byte(10))), - ).Return(&redis.IntCmd{}) + mmp := &mockMessagePublisher{ + publishFunc: func(_ context.Context, _ string, _ any) *redis.IntCmd { return &redis.IntCmd{} }, + } actual.publisher = mmp actual.PublishAsync(ctx, inputData) - mock.AssertExpectationsForObjects(t, mmp) + require.Len(t, mmp.publishArgs, 1) + assert.Equal(t, actual.topic, mmp.publishArgs[0].channel) + assert.Equal(t, fmt.Appendf(nil, `{"name":%q}%s`, t.Name(), string(byte(10))), mmp.publishArgs[0].message) }) T.Run("with error encoding value", func(t *testing.T) { @@ -290,42 +293,56 @@ func Test_provideRedisPublisher(T *testing.T) { T.Run("panics when first NewInt64Counter fails", func(t *testing.T) { t.Parallel() - mp := &metrics.MockProvider{} - mp.On("NewInt64Counter", "t_published", mock.Anything).Return(metricnoop.Int64Counter{}, errors.New("forced error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(name string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + if name == "t_published" { + return metricnoop.Int64Counter{}, errors.New("forced error") + } + t.Fatalf("unexpected NewInt64Counter call: %q", name) + return nil, nil + }, + } assert.Panics(t, func() { provideRedisPublisher(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, nil, "t") }) - - mock.AssertExpectationsForObjects(t, mp) }) T.Run("panics when second NewInt64Counter fails", func(t *testing.T) { t.Parallel() - mp := &metrics.MockProvider{} - mp.On("NewInt64Counter", "t_published", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "t_publish_errors", mock.Anything).Return(metricnoop.Int64Counter{}, errors.New("forced error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(name string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + switch name { + case "t_published": + return metricnoop.Int64Counter{}, nil + case "t_publish_errors": + return metricnoop.Int64Counter{}, errors.New("forced error") + } + t.Fatalf("unexpected NewInt64Counter call: %q", name) + return nil, nil + }, + } assert.Panics(t, func() { provideRedisPublisher(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, nil, "t") }) - - mock.AssertExpectationsForObjects(t, mp) }) T.Run("panics when NewFloat64Histogram fails", func(t *testing.T) { t.Parallel() - mp := &metrics.MockProvider{} - mp.On("NewInt64Counter", "t_published", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "t_publish_errors", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewFloat64Histogram", "t_publish_latency_ms", mock.Anything).Return(metricnoop.Float64Histogram{}, errors.New("forced error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(string, ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metricnoop.Int64Counter{}, nil + }, + NewFloat64HistogramFunc: func(string, ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { + return metricnoop.Float64Histogram{}, errors.New("forced error") + }, + } assert.Panics(t, func() { provideRedisPublisher(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, nil, "t") }) - - mock.AssertExpectationsForObjects(t, mp) }) } diff --git a/messagequeue/sqs/consumer_test.go b/messagequeue/sqs/consumer_test.go index e2e3eb5..f8bf957 100644 --- a/messagequeue/sqs/consumer_test.go +++ b/messagequeue/sqs/consumer_test.go @@ -8,38 +8,31 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/messagequeue" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" + mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/verygoodsoftwarenotvirus/platform/v5/testutils" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/aws/aws-sdk-go-v2/service/sqs/types" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric" metricnoop "go.opentelemetry.io/otel/metric/noop" ) type mockMessageReceiver struct { - mock.Mock + receiveMessageFunc func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) + deleteMessageFunc func(ctx context.Context, input *sqs.DeleteMessageInput, optFns ...func(*sqs.Options)) (*sqs.DeleteMessageOutput, error) + deleteMessageCalls int } func (m *mockMessageReceiver) ReceiveMessage(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { - retVals := m.Called(ctx, input, optFns) - out := retVals.Get(0) - if out == nil { - return nil, retVals.Error(1) - } - return out.(*sqs.ReceiveMessageOutput), retVals.Error(1) + return m.receiveMessageFunc(ctx, input, optFns...) } func (m *mockMessageReceiver) DeleteMessage(ctx context.Context, input *sqs.DeleteMessageInput, optFns ...func(*sqs.Options)) (*sqs.DeleteMessageOutput, error) { - retVals := m.Called(ctx, input, optFns) - out := retVals.Get(0) - if out == nil { - return nil, retVals.Error(1) - } - return out.(*sqs.DeleteMessageOutput), retVals.Error(1) + m.deleteMessageCalls++ + return m.deleteMessageFunc(ctx, input, optFns...) } func Test_sqsConsumer_Consume(T *testing.T) { @@ -50,45 +43,36 @@ func Test_sqsConsumer_Consume(T *testing.T) { T.Run("successful message handling and deletion", func(t *testing.T) { t.Parallel() - mmr := &mockMessageReceiver{} - mmr.On( - "ReceiveMessage", - testutils.ContextMatcher, - mock.MatchedBy(func(in *sqs.ReceiveMessageInput) bool { - return aws.ToString(in.QueueUrl) == queueURL && - in.MaxNumberOfMessages == maxNumberOfMessages && - in.WaitTimeSeconds == longPollWaitSeconds - }), - mock.Anything, - ).Return(&sqs.ReceiveMessageOutput{ - Messages: []types.Message{ - { - Body: aws.String("test-payload"), - ReceiptHandle: aws.String("receipt-handle-123"), - }, + deleteCalled := make(chan struct{}, 1) + var receiveCalls int + mmr := &mockMessageReceiver{ + receiveMessageFunc: func(_ context.Context, in *sqs.ReceiveMessageInput, _ ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { + receiveCalls++ + if receiveCalls == 1 { + assert.Equal(t, queueURL, aws.ToString(in.QueueUrl)) + assert.Equal(t, int32(maxNumberOfMessages), in.MaxNumberOfMessages) + assert.Equal(t, int32(longPollWaitSeconds), in.WaitTimeSeconds) + return &sqs.ReceiveMessageOutput{ + Messages: []types.Message{ + { + Body: aws.String("test-payload"), + ReceiptHandle: aws.String("receipt-handle-123"), + }, + }, + }, nil + } + return &sqs.ReceiveMessageOutput{Messages: []types.Message{}}, nil }, - }, nil).Once() - - mmr.On( - "ReceiveMessage", - testutils.ContextMatcher, - mock.Anything, - mock.Anything, - ).Return(&sqs.ReceiveMessageOutput{Messages: []types.Message{}}, nil) - - deleteCalled := make(chan struct{}) - mmr.On( - "DeleteMessage", - testutils.ContextMatcher, - mock.MatchedBy(func(in *sqs.DeleteMessageInput) bool { - return aws.ToString(in.QueueUrl) == queueURL && - aws.ToString(in.ReceiptHandle) == "receipt-handle-123" - }), - mock.Anything, - ).Run(func(args mock.Arguments) { deleteCalled <- struct{}{} }).Return(&sqs.DeleteMessageOutput{}, nil).Once() + deleteMessageFunc: func(_ context.Context, in *sqs.DeleteMessageInput, _ ...func(*sqs.Options)) (*sqs.DeleteMessageOutput, error) { + assert.Equal(t, queueURL, aws.ToString(in.QueueUrl)) + assert.Equal(t, "receipt-handle-123", aws.ToString(in.ReceiptHandle)) + deleteCalled <- struct{}{} + return &sqs.DeleteMessageOutput{}, nil + }, + } handlerDone := make(chan []byte, 1) - handler := func(ctx context.Context, body []byte) error { + handler := func(_ context.Context, body []byte) error { handlerDone <- body return nil } @@ -104,36 +88,35 @@ func Test_sqsConsumer_Consume(T *testing.T) { stopChan <- true assert.Equal(t, []byte("test-payload"), receivedBody) - mock.AssertExpectationsForObjects(t, mmr) }) T.Run("handler error does not delete message", func(t *testing.T) { t.Parallel() anticipatedErr := errors.New("handler failed") - mmr := &mockMessageReceiver{} - mmr.On( - "ReceiveMessage", - testutils.ContextMatcher, - mock.Anything, - mock.Anything, - ).Return(&sqs.ReceiveMessageOutput{ - Messages: []types.Message{ - { - Body: aws.String("fail-payload"), - ReceiptHandle: aws.String("receipt-handle-456"), - }, + var receiveCalls int + mmr := &mockMessageReceiver{ + receiveMessageFunc: func(_ context.Context, _ *sqs.ReceiveMessageInput, _ ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { + receiveCalls++ + if receiveCalls == 1 { + return &sqs.ReceiveMessageOutput{ + Messages: []types.Message{ + { + Body: aws.String("fail-payload"), + ReceiptHandle: aws.String("receipt-handle-456"), + }, + }, + }, nil + } + return &sqs.ReceiveMessageOutput{Messages: []types.Message{}}, nil }, - }, nil).Once() - - mmr.On( - "ReceiveMessage", - testutils.ContextMatcher, - mock.Anything, - mock.Anything, - ).Return(&sqs.ReceiveMessageOutput{Messages: []types.Message{}}, nil) + deleteMessageFunc: func(_ context.Context, _ *sqs.DeleteMessageInput, _ ...func(*sqs.Options)) (*sqs.DeleteMessageOutput, error) { + t.Fatal("DeleteMessage should not be called when handler errors") + return nil, nil + }, + } - handler := func(ctx context.Context, body []byte) error { + handler := func(_ context.Context, _ []byte) error { return anticipatedErr } @@ -149,8 +132,7 @@ func Test_sqsConsumer_Consume(T *testing.T) { stopChan <- true - mmr.AssertNotCalled(t, "DeleteMessage") - mock.AssertExpectationsForObjects(t, mmr) + assert.Zero(t, mmr.deleteMessageCalls) }) } @@ -242,13 +224,14 @@ func Test_provideSQSConsumer(T *testing.T) { T.Run("panics when NewInt64Counter fails", func(t *testing.T) { t.Parallel() - mp := &metrics.MockProvider{} - mp.On("NewInt64Counter", "t_consumed", mock.Anything).Return(metricnoop.Int64Counter{}, errors.New("forced error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(string, ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metricnoop.Int64Counter{}, errors.New("forced error") + }, + } assert.Panics(t, func() { provideSQSConsumer(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, nil, "t", nil) }) - - mock.AssertExpectationsForObjects(t, mp) }) } diff --git a/messagequeue/sqs/publisher_test.go b/messagequeue/sqs/publisher_test.go index b4b12e3..5426c8f 100644 --- a/messagequeue/sqs/publisher_test.go +++ b/messagequeue/sqs/publisher_test.go @@ -9,23 +9,24 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/messagequeue" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" + mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/verygoodsoftwarenotvirus/platform/v5/testutils" "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric" metricnoop "go.opentelemetry.io/otel/metric/noop" ) type mockMessagePublisher struct { - mock.Mock + sendMessageFunc func(ctx context.Context, input *sqs.SendMessageInput, optFns ...func(*sqs.Options)) (*sqs.SendMessageOutput, error) + sendMessageCalls int } func (m *mockMessagePublisher) SendMessage(ctx context.Context, input *sqs.SendMessageInput, optFns ...func(*sqs.Options)) (*sqs.SendMessageOutput, error) { - retVals := m.Called(ctx, input, optFns) - return retVals.Get(0).(*sqs.SendMessageOutput), retVals.Error(1) + m.sendMessageCalls++ + return m.sendMessageFunc(ctx, input, optFns...) } func Test_sqsPublisher_Publish(T *testing.T) { @@ -53,20 +54,17 @@ func Test_sqsPublisher_Publish(T *testing.T) { Name: t.Name(), } - mmp := &mockMessagePublisher{} - mmp.On( - "SendMessage", - testutils.ContextMatcher, - mock.MatchedBy(func(*sqs.SendMessageInput) bool { return true }), - mock.Anything, - ).Return(&sqs.SendMessageOutput{}, nil) + mmp := &mockMessagePublisher{ + sendMessageFunc: func(_ context.Context, _ *sqs.SendMessageInput, _ ...func(*sqs.Options)) (*sqs.SendMessageOutput, error) { + return &sqs.SendMessageOutput{}, nil + }, + } actual.publisher = mmp err = actual.Publish(ctx, inputData) assert.NoError(t, err) - - mock.AssertExpectationsForObjects(t, mmp) + assert.Equal(t, 1, mmp.sendMessageCalls) }) T.Run("with error encoding value", func(t *testing.T) { @@ -121,19 +119,16 @@ func Test_sqsPublisher_PublishAsync(T *testing.T) { Name: t.Name(), } - mmp := &mockMessagePublisher{} - mmp.On( - "SendMessage", - testutils.ContextMatcher, - mock.MatchedBy(func(*sqs.SendMessageInput) bool { return true }), - mock.Anything, - ).Return(&sqs.SendMessageOutput{}, nil) + mmp := &mockMessagePublisher{ + sendMessageFunc: func(_ context.Context, _ *sqs.SendMessageInput, _ ...func(*sqs.Options)) (*sqs.SendMessageOutput, error) { + return &sqs.SendMessageOutput{}, nil + }, + } actual.publisher = mmp actual.PublishAsync(ctx, inputData) - - mock.AssertExpectationsForObjects(t, mmp) + assert.Equal(t, 1, mmp.sendMessageCalls) }) T.Run("with error encoding value", func(t *testing.T) { @@ -183,19 +178,16 @@ func Test_sqsPublisher_PublishAsync(T *testing.T) { Name: t.Name(), } - mmp := &mockMessagePublisher{} - mmp.On( - "SendMessage", - testutils.ContextMatcher, - mock.MatchedBy(func(*sqs.SendMessageInput) bool { return true }), - mock.Anything, - ).Return((*sqs.SendMessageOutput)(nil), errors.New("send failed")) + mmp := &mockMessagePublisher{ + sendMessageFunc: func(_ context.Context, _ *sqs.SendMessageInput, _ ...func(*sqs.Options)) (*sqs.SendMessageOutput, error) { + return nil, errors.New("send failed") + }, + } actual.publisher = mmp actual.PublishAsync(ctx, inputData) - - mock.AssertExpectationsForObjects(t, mmp) + assert.Equal(t, 1, mmp.sendMessageCalls) }) } @@ -276,42 +268,56 @@ func Test_provideSQSPublisher(T *testing.T) { T.Run("panics when first NewInt64Counter fails", func(t *testing.T) { t.Parallel() - mp := &metrics.MockProvider{} - mp.On("NewInt64Counter", "t_published", mock.Anything).Return(metricnoop.Int64Counter{}, errors.New("forced error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(name string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + if name == "t_published" { + return metricnoop.Int64Counter{}, errors.New("forced error") + } + t.Fatalf("unexpected NewInt64Counter call: %q", name) + return nil, nil + }, + } assert.Panics(t, func() { provideSQSPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), mp, "t") }) - - mock.AssertExpectationsForObjects(t, mp) }) T.Run("panics when second NewInt64Counter fails", func(t *testing.T) { t.Parallel() - mp := &metrics.MockProvider{} - mp.On("NewInt64Counter", "t_published", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "t_publish_errors", mock.Anything).Return(metricnoop.Int64Counter{}, errors.New("forced error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(name string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + switch name { + case "t_published": + return metricnoop.Int64Counter{}, nil + case "t_publish_errors": + return metricnoop.Int64Counter{}, errors.New("forced error") + } + t.Fatalf("unexpected NewInt64Counter call: %q", name) + return nil, nil + }, + } assert.Panics(t, func() { provideSQSPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), mp, "t") }) - - mock.AssertExpectationsForObjects(t, mp) }) T.Run("panics when NewFloat64Histogram fails", func(t *testing.T) { t.Parallel() - mp := &metrics.MockProvider{} - mp.On("NewInt64Counter", "t_published", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "t_publish_errors", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewFloat64Histogram", "t_publish_latency_ms", mock.Anything).Return(metricnoop.Float64Histogram{}, errors.New("forced error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(string, ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metricnoop.Int64Counter{}, nil + }, + NewFloat64HistogramFunc: func(string, ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { + return metricnoop.Float64Histogram{}, errors.New("forced error") + }, + } assert.Panics(t, func() { provideSQSPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), mp, "t") }) - - mock.AssertExpectationsForObjects(t, mp) }) } diff --git a/notifications/mobile/apns/apns_sender_test.go b/notifications/mobile/apns/apns_sender_test.go index 0a43ca0..dc31670 100644 --- a/notifications/mobile/apns/apns_sender_test.go +++ b/notifications/mobile/apns/apns_sender_test.go @@ -19,9 +19,10 @@ import ( mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric" ) const validDeviceToken = "a1b2c3d4e5f67890a1b2c3d4e5f67890a1b2c3d4e5f67890a1b2c3d4e5f67890" @@ -208,14 +209,19 @@ func TestNewSender(T *testing.T) { BundleID: "com.example.app", } - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", o11yName+"_sends", mock.Anything). - Return((*metrics.Int64CounterImpl)(nil), errors.New("counter error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + assert.Equal(t, o11yName+"_sends", counterName) + return (*metrics.Int64CounterImpl)(nil), errors.New("counter error") + }, + } sender, err := NewSender(cfg, tracingProvider, logger, mp) assert.Nil(t, sender) require.Error(t, err) assert.Contains(t, err.Error(), "creating send counter") + + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("with error counter creation error", func(t *testing.T) { @@ -229,16 +235,25 @@ func TestNewSender(T *testing.T) { BundleID: "com.example.app", } - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", o11yName+"_sends", mock.Anything). - Return((*metrics.Int64CounterImpl)(nil), nil) - mp.On("NewInt64Counter", o11yName+"_errors", mock.Anything). - Return((*metrics.Int64CounterImpl)(nil), errors.New("counter error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + switch counterName { + case o11yName + "_sends": + return (*metrics.Int64CounterImpl)(nil), nil + case o11yName + "_errors": + return (*metrics.Int64CounterImpl)(nil), errors.New("counter error") + } + t.Fatalf("unexpected NewInt64Counter call: %q", counterName) + return nil, nil + }, + } sender, err := NewSender(cfg, tracingProvider, logger, mp) assert.Nil(t, sender) require.Error(t, err) assert.Contains(t, err.Error(), "creating error counter") + + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) } diff --git a/notifications/mobile/fcm/fcm_sender_test.go b/notifications/mobile/fcm/fcm_sender_test.go index 46dfc6b..b7659aa 100644 --- a/notifications/mobile/fcm/fcm_sender_test.go +++ b/notifications/mobile/fcm/fcm_sender_test.go @@ -15,9 +15,10 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" firebase "firebase.google.com/go/v4" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric" "google.golang.org/api/option" ) @@ -134,15 +135,20 @@ func TestNewSender(T *testing.T) { path := filepath.Join(dir, "creds.json") require.NoError(t, os.WriteFile(path, []byte(fakeServiceAccountJSON), 0o600)) - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", o11yName+"_sends", mock.Anything). - Return((*metrics.Int64CounterImpl)(nil), errors.New("counter error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + assert.Equal(t, o11yName+"_sends", counterName) + return (*metrics.Int64CounterImpl)(nil), errors.New("counter error") + }, + } cfg := &Config{CredentialsPath: path} sender, err := NewSender(ctx, cfg, tracingProvider, logger, mp) assert.Nil(t, sender) require.Error(t, err) assert.Contains(t, err.Error(), "creating send counter") + + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("with error counter creation error", func(t *testing.T) { @@ -152,17 +158,26 @@ func TestNewSender(T *testing.T) { path := filepath.Join(dir, "creds.json") require.NoError(t, os.WriteFile(path, []byte(fakeServiceAccountJSON), 0o600)) - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", o11yName+"_sends", mock.Anything). - Return((*metrics.Int64CounterImpl)(nil), nil) - mp.On("NewInt64Counter", o11yName+"_errors", mock.Anything). - Return((*metrics.Int64CounterImpl)(nil), errors.New("counter error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + switch counterName { + case o11yName + "_sends": + return (*metrics.Int64CounterImpl)(nil), nil + case o11yName + "_errors": + return (*metrics.Int64CounterImpl)(nil), errors.New("counter error") + } + t.Fatalf("unexpected NewInt64Counter call: %q", counterName) + return nil, nil + }, + } cfg := &Config{CredentialsPath: path} sender, err := NewSender(ctx, cfg, tracingProvider, logger, mp) assert.Nil(t, sender) require.Error(t, err) assert.Contains(t, err.Error(), "creating error counter") + + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) } diff --git a/observability/metrics/mock/doc.go b/observability/metrics/mock/doc.go index 9e104a0..b2571d8 100644 --- a/observability/metrics/mock/doc.go +++ b/observability/metrics/mock/doc.go @@ -1,4 +1,8 @@ /* -Package mockmetrics provides metrics-related mocks. +Package mockmetrics provides moq-generated mocks for the metrics package. */ package mockmetrics + +// Regenerate the moq mocks via `go generate ./observability/metrics/mock/`. + +//go:generate go tool github.com/matryer/moq -out provider_mock.go -pkg mockmetrics -rm -fmt goimports .. Provider:ProviderMock Int64Counter:Int64CounterMock diff --git a/observability/metrics/mock/int64_counter.go b/observability/metrics/mock/int64_counter.go deleted file mode 100644 index 5c1e32f..0000000 --- a/observability/metrics/mock/int64_counter.go +++ /dev/null @@ -1,20 +0,0 @@ -package mockmetrics - -import ( - "context" - - "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" - - "github.com/stretchr/testify/mock" - "go.opentelemetry.io/otel/metric" -) - -var _ metrics.Int64Counter = (*Int64Counter)(nil) - -type Int64Counter struct { - mock.Mock -} - -func (m *Int64Counter) Add(ctx context.Context, incr int64, options ...metric.AddOption) { - m.Called(ctx, incr, options) -} diff --git a/observability/metrics/mock/provider.go b/observability/metrics/mock/provider.go deleted file mode 100644 index b3d2e7b..0000000 --- a/observability/metrics/mock/provider.go +++ /dev/null @@ -1,74 +0,0 @@ -package mockmetrics - -import ( - "context" - - "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" - - "github.com/stretchr/testify/mock" - "go.opentelemetry.io/otel/metric" -) - -type MetricsProvider struct { - mock.Mock -} - -// NewFloat64Counter is a mock method. -func (m *MetricsProvider) NewFloat64Counter(name string, options ...metric.Float64CounterOption) (metrics.Float64Counter, error) { - args := m.Called(name, options) - return args.Get(0).(metrics.Float64Counter), args.Error(1) -} - -// NewFloat64Gauge is a mock method. -func (m *MetricsProvider) NewFloat64Gauge(name string, options ...metric.Float64GaugeOption) (metrics.Float64Gauge, error) { - args := m.Called(name, options) - return args.Get(0).(metrics.Float64Gauge), args.Error(1) -} - -// NewFloat64UpDownCounter is a mock method. -func (m *MetricsProvider) NewFloat64UpDownCounter(name string, options ...metric.Float64UpDownCounterOption) (metrics.Float64UpDownCounter, error) { - args := m.Called(name, options) - return args.Get(0).(metrics.Float64UpDownCounter), args.Error(1) -} - -// NewFloat64Histogram is a mock method. -func (m *MetricsProvider) NewFloat64Histogram(name string, options ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { - args := m.Called(name, options) - return args.Get(0).(metrics.Float64Histogram), args.Error(1) -} - -// NewInt64Counter is a mock method. -func (m *MetricsProvider) NewInt64Counter(name string, options ...metric.Int64CounterOption) (metrics.Int64Counter, error) { - args := m.Called(name, options) - return args.Get(0).(metrics.Int64Counter), args.Error(1) -} - -// NewInt64Gauge is a mock method. -func (m *MetricsProvider) NewInt64Gauge(name string, options ...metric.Int64GaugeOption) (metrics.Int64Gauge, error) { - args := m.Called(name, options) - return args.Get(0).(metrics.Int64Gauge), args.Error(1) -} - -// NewInt64UpDownCounter is a mock method. -func (m *MetricsProvider) NewInt64UpDownCounter(name string, options ...metric.Int64UpDownCounterOption) (metrics.Int64UpDownCounter, error) { - args := m.Called(name, options) - return args.Get(0).(metrics.Int64UpDownCounter), args.Error(1) -} - -// NewInt64Histogram is a mock method. -func (m *MetricsProvider) NewInt64Histogram(name string, options ...metric.Int64HistogramOption) (metrics.Int64Histogram, error) { - args := m.Called(name, options) - return args.Get(0).(metrics.Int64Histogram), args.Error(1) -} - -// Shutdown is a mock method. -func (m *MetricsProvider) Shutdown(ctx context.Context) error { - args := m.Called(ctx) - return args.Error(0) -} - -// MeterProvider is a mock method. -func (m *MetricsProvider) MeterProvider() metric.MeterProvider { - args := m.Called() - return args.Get(0).(metric.MeterProvider) -} diff --git a/observability/metrics/mock2/provider_mock.go b/observability/metrics/mock/provider_mock.go similarity index 89% rename from observability/metrics/mock2/provider_mock.go rename to observability/metrics/mock/provider_mock.go index 2272295..18cd8b1 100644 --- a/observability/metrics/mock2/provider_mock.go +++ b/observability/metrics/mock/provider_mock.go @@ -1,13 +1,14 @@ // Code generated by moq; DO NOT EDIT. // github.com/matryer/moq -package mock2 +package mockmetrics import ( "context" "sync" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" + "go.opentelemetry.io/otel/metric" ) @@ -513,3 +514,81 @@ func (mock *ProviderMock) ShutdownCalls() []struct { mock.lockShutdown.RUnlock() return calls } + +// Ensure, that Int64CounterMock does implement metrics.Int64Counter. +// If this is not the case, regenerate this file with moq. +var _ metrics.Int64Counter = &Int64CounterMock{} + +// Int64CounterMock is a mock implementation of metrics.Int64Counter. +// +// func TestSomethingThatUsesInt64Counter(t *testing.T) { +// +// // make and configure a mocked metrics.Int64Counter +// mockedInt64Counter := &Int64CounterMock{ +// AddFunc: func(ctx context.Context, incr int64, options ...metric.AddOption) { +// panic("mock out the Add method") +// }, +// } +// +// // use mockedInt64Counter in code that requires metrics.Int64Counter +// // and then make assertions. +// +// } +type Int64CounterMock struct { + // AddFunc mocks the Add method. + AddFunc func(ctx context.Context, incr int64, options ...metric.AddOption) + + // calls tracks calls to the methods. + calls struct { + // Add holds details about calls to the Add method. + Add []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Incr is the incr argument value. + Incr int64 + // Options is the options argument value. + Options []metric.AddOption + } + } + lockAdd sync.RWMutex +} + +// Add calls AddFunc. +func (mock *Int64CounterMock) Add(ctx context.Context, incr int64, options ...metric.AddOption) { + if mock.AddFunc == nil { + panic("Int64CounterMock.AddFunc: method is nil but Int64Counter.Add was just called") + } + callInfo := struct { + Ctx context.Context + Incr int64 + Options []metric.AddOption + }{ + Ctx: ctx, + Incr: incr, + Options: options, + } + mock.lockAdd.Lock() + mock.calls.Add = append(mock.calls.Add, callInfo) + mock.lockAdd.Unlock() + mock.AddFunc(ctx, incr, options...) +} + +// AddCalls gets all the calls that were made to Add. +// Check the length with: +// +// len(mockedInt64Counter.AddCalls()) +func (mock *Int64CounterMock) AddCalls() []struct { + Ctx context.Context + Incr int64 + Options []metric.AddOption +} { + var calls []struct { + Ctx context.Context + Incr int64 + Options []metric.AddOption + } + mock.lockAdd.RLock() + calls = mock.calls.Add + mock.lockAdd.RUnlock() + return calls +} diff --git a/observability/metrics/mock2/doc.go b/observability/metrics/mock2/doc.go deleted file mode 100644 index f3d3315..0000000 --- a/observability/metrics/mock2/doc.go +++ /dev/null @@ -1,10 +0,0 @@ -// Package mock2 provides moq-generated mock implementations of interfaces in -// the observability/metrics package. It exists alongside the hand-written -// testify-based package observability/metrics/mock and is a pilot of the -// matryer/moq workflow; consumers that want the moq style should import this -// package instead of observability/metrics/mock. -package mock2 - -// Regenerate via `go generate ./observability/metrics/mock2/`. - -//go:generate go tool github.com/matryer/moq -out provider_mock.go -pkg mock2 -rm -fmt goimports .. Provider:ProviderMock diff --git a/observability/metrics/mock_provider.go b/observability/metrics/mock_provider.go deleted file mode 100644 index c5c8fff..0000000 --- a/observability/metrics/mock_provider.go +++ /dev/null @@ -1,72 +0,0 @@ -package metrics - -import ( - "context" - - "github.com/stretchr/testify/mock" - "go.opentelemetry.io/otel/metric" -) - -var _ Provider = (*MockProvider)(nil) - -type MockProvider struct { - mock.Mock -} - -// NewFloat64Counter satisfies our interface. -func (m *MockProvider) NewFloat64Counter(name string, options ...metric.Float64CounterOption) (Float64Counter, error) { - returnValues := m.Called(name, options) - return returnValues.Get(0).(metric.Float64Counter), returnValues.Error(1) -} - -// NewFloat64Gauge satisfies our interface. -func (m *MockProvider) NewFloat64Gauge(name string, options ...metric.Float64GaugeOption) (Float64Gauge, error) { - returnValues := m.Called(name, options) - return returnValues.Get(0).(metric.Float64Gauge), returnValues.Error(1) -} - -// NewFloat64UpDownCounter satisfies our interface. -func (m *MockProvider) NewFloat64UpDownCounter(name string, options ...metric.Float64UpDownCounterOption) (Float64UpDownCounter, error) { - returnValues := m.Called(name, options) - return returnValues.Get(0).(metric.Float64UpDownCounter), returnValues.Error(1) -} - -// NewFloat64Histogram satisfies our interface. -func (m *MockProvider) NewFloat64Histogram(name string, options ...metric.Float64HistogramOption) (Float64Histogram, error) { - returnValues := m.Called(name, options) - return returnValues.Get(0).(metric.Float64Histogram), returnValues.Error(1) -} - -// NewInt64Counter satisfies our interface. -func (m *MockProvider) NewInt64Counter(name string, options ...metric.Int64CounterOption) (Int64Counter, error) { - returnValues := m.Called(name, options) - return returnValues.Get(0).(metric.Int64Counter), returnValues.Error(1) -} - -// NewInt64Gauge satisfies our interface. -func (m *MockProvider) NewInt64Gauge(name string, options ...metric.Int64GaugeOption) (Int64Gauge, error) { - returnValues := m.Called(name, options) - return returnValues.Get(0).(metric.Int64Gauge), returnValues.Error(1) -} - -// NewInt64UpDownCounter satisfies our interface. -func (m *MockProvider) NewInt64UpDownCounter(name string, options ...metric.Int64UpDownCounterOption) (Int64UpDownCounter, error) { - returnValues := m.Called(name, options) - return returnValues.Get(0).(metric.Int64UpDownCounter), returnValues.Error(1) -} - -// NewInt64Histogram satisfies our interface. -func (m *MockProvider) NewInt64Histogram(name string, options ...metric.Int64HistogramOption) (Int64Histogram, error) { - returnValues := m.Called(name, options) - return returnValues.Get(0).(metric.Int64Histogram), returnValues.Error(1) -} - -// Shutdown satisfies our interface. -func (m *MockProvider) Shutdown(ctx context.Context) error { - return m.Called(ctx).Error(0) -} - -// MeterProvider satisfies our interface. -func (m *MockProvider) MeterProvider() metric.MeterProvider { - return m.Called().Get(0).(metric.MeterProvider) -} diff --git a/observability/metrics/noop.go b/observability/metrics/noop.go index 5520e27..cf187d4 100644 --- a/observability/metrics/noop.go +++ b/observability/metrics/noop.go @@ -3,7 +3,6 @@ package metrics import ( "context" - "github.com/stretchr/testify/mock" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/metric" metricnoop "go.opentelemetry.io/otel/metric/noop" @@ -13,9 +12,7 @@ func NewNoopMetricsProvider() Provider { return &noopProvider{} } -type noopProvider struct { - mock.Mock -} +type noopProvider struct{} // NewFloat64Counter is a no-op method. func (m *noopProvider) NewFloat64Counter(name string, options ...metric.Float64CounterOption) (Float64Counter, error) { diff --git a/panicking/mock/doc.go b/panicking/mock/doc.go new file mode 100644 index 0000000..eb00756 --- /dev/null +++ b/panicking/mock/doc.go @@ -0,0 +1,9 @@ +// Package mockpanicking provides mock implementations of the panicking package's +// interfaces. Both the hand-written testify-based Panicker and the moq-generated +// PanickerMock live here during the testify → moq migration. New test code +// should prefer PanickerMock. +package mockpanicking + +// Regenerate the moq mocks via `go generate ./panicking/mock/`. + +//go:generate go tool github.com/matryer/moq -out panicker_mock.go -pkg mockpanicking -rm -fmt goimports .. Panicker:PanickerMock diff --git a/panicking/mock/mock.go b/panicking/mock/mock.go deleted file mode 100644 index a093419..0000000 --- a/panicking/mock/mock.go +++ /dev/null @@ -1,25 +0,0 @@ -package mockpanicking - -import ( - "github.com/stretchr/testify/mock" -) - -// Panicker implements Panicker for tests. -type Panicker struct { - mock.Mock -} - -// NewMockPanicker produces a production-ready panicker that will actually panic when called. -func NewMockPanicker() *Panicker { - return &Panicker{} -} - -// Panic satisfies our interface. -func (p *Panicker) Panic(msg any) { - p.Called(msg) -} - -// Panicf satisfies our interface. -func (p *Panicker) Panicf(format string, args ...any) { - p.Called(format, args) -} diff --git a/panicking/mock/panicker_mock.go b/panicking/mock/panicker_mock.go new file mode 100644 index 0000000..f737cd8 --- /dev/null +++ b/panicking/mock/panicker_mock.go @@ -0,0 +1,126 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package mockpanicking + +import ( + "sync" + + "github.com/verygoodsoftwarenotvirus/platform/v5/panicking" +) + +// Ensure, that PanickerMock does implement panicking.Panicker. +// If this is not the case, regenerate this file with moq. +var _ panicking.Panicker = &PanickerMock{} + +// PanickerMock is a mock implementation of panicking.Panicker. +// +// func TestSomethingThatUsesPanicker(t *testing.T) { +// +// // make and configure a mocked panicking.Panicker +// mockedPanicker := &PanickerMock{ +// PanicFunc: func(v any) { +// panic("mock out the Panic method") +// }, +// PanicfFunc: func(format string, args ...any) { +// panic("mock out the Panicf method") +// }, +// } +// +// // use mockedPanicker in code that requires panicking.Panicker +// // and then make assertions. +// +// } +type PanickerMock struct { + // PanicFunc mocks the Panic method. + PanicFunc func(v any) + + // PanicfFunc mocks the Panicf method. + PanicfFunc func(format string, args ...any) + + // calls tracks calls to the methods. + calls struct { + // Panic holds details about calls to the Panic method. + Panic []struct { + // V is the v argument value. + V any + } + // Panicf holds details about calls to the Panicf method. + Panicf []struct { + // Format is the format argument value. + Format string + // Args is the args argument value. + Args []any + } + } + lockPanic sync.RWMutex + lockPanicf sync.RWMutex +} + +// Panic calls PanicFunc. +func (mock *PanickerMock) Panic(v any) { + if mock.PanicFunc == nil { + panic("PanickerMock.PanicFunc: method is nil but Panicker.Panic was just called") + } + callInfo := struct { + V any + }{ + V: v, + } + mock.lockPanic.Lock() + mock.calls.Panic = append(mock.calls.Panic, callInfo) + mock.lockPanic.Unlock() + mock.PanicFunc(v) +} + +// PanicCalls gets all the calls that were made to Panic. +// Check the length with: +// +// len(mockedPanicker.PanicCalls()) +func (mock *PanickerMock) PanicCalls() []struct { + V any +} { + var calls []struct { + V any + } + mock.lockPanic.RLock() + calls = mock.calls.Panic + mock.lockPanic.RUnlock() + return calls +} + +// Panicf calls PanicfFunc. +func (mock *PanickerMock) Panicf(format string, args ...any) { + if mock.PanicfFunc == nil { + panic("PanickerMock.PanicfFunc: method is nil but Panicker.Panicf was just called") + } + callInfo := struct { + Format string + Args []any + }{ + Format: format, + Args: args, + } + mock.lockPanicf.Lock() + mock.calls.Panicf = append(mock.calls.Panicf, callInfo) + mock.lockPanicf.Unlock() + mock.PanicfFunc(format, args...) +} + +// PanicfCalls gets all the calls that were made to Panicf. +// Check the length with: +// +// len(mockedPanicker.PanicfCalls()) +func (mock *PanickerMock) PanicfCalls() []struct { + Format string + Args []any +} { + var calls []struct { + Format string + Args []any + } + mock.lockPanicf.RLock() + calls = mock.calls.Panicf + mock.lockPanicf.RUnlock() + return calls +} diff --git a/random/mock/doc.go b/random/mock/doc.go new file mode 100644 index 0000000..4bb0bd2 --- /dev/null +++ b/random/mock/doc.go @@ -0,0 +1,9 @@ +// Package randommock provides mock implementations of the random package's +// interfaces. Both the hand-written testify-based Generator and the +// moq-generated GeneratorMock live here during the testify → moq migration. +// New test code should prefer GeneratorMock. +package randommock + +// Regenerate the moq mocks via `go generate ./random/mock/`. + +//go:generate go tool github.com/matryer/moq -out generator_mock.go -pkg randommock -rm -fmt goimports .. Generator:GeneratorMock diff --git a/random/mock/generator_mock.go b/random/mock/generator_mock.go new file mode 100644 index 0000000..1848ffa --- /dev/null +++ b/random/mock/generator_mock.go @@ -0,0 +1,233 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package randommock + +import ( + "context" + "sync" + + "github.com/verygoodsoftwarenotvirus/platform/v5/random" +) + +// Ensure, that GeneratorMock does implement random.Generator. +// If this is not the case, regenerate this file with moq. +var _ random.Generator = &GeneratorMock{} + +// GeneratorMock is a mock implementation of random.Generator. +// +// func TestSomethingThatUsesGenerator(t *testing.T) { +// +// // make and configure a mocked random.Generator +// mockedGenerator := &GeneratorMock{ +// GenerateBase32EncodedStringFunc: func(contextMoqParam context.Context, n int) (string, error) { +// panic("mock out the GenerateBase32EncodedString method") +// }, +// GenerateBase64EncodedStringFunc: func(contextMoqParam context.Context, n int) (string, error) { +// panic("mock out the GenerateBase64EncodedString method") +// }, +// GenerateHexEncodedStringFunc: func(ctx context.Context, length int) (string, error) { +// panic("mock out the GenerateHexEncodedString method") +// }, +// GenerateRawBytesFunc: func(contextMoqParam context.Context, n int) ([]byte, error) { +// panic("mock out the GenerateRawBytes method") +// }, +// } +// +// // use mockedGenerator in code that requires random.Generator +// // and then make assertions. +// +// } +type GeneratorMock struct { + // GenerateBase32EncodedStringFunc mocks the GenerateBase32EncodedString method. + GenerateBase32EncodedStringFunc func(contextMoqParam context.Context, n int) (string, error) + + // GenerateBase64EncodedStringFunc mocks the GenerateBase64EncodedString method. + GenerateBase64EncodedStringFunc func(contextMoqParam context.Context, n int) (string, error) + + // GenerateHexEncodedStringFunc mocks the GenerateHexEncodedString method. + GenerateHexEncodedStringFunc func(ctx context.Context, length int) (string, error) + + // GenerateRawBytesFunc mocks the GenerateRawBytes method. + GenerateRawBytesFunc func(contextMoqParam context.Context, n int) ([]byte, error) + + // calls tracks calls to the methods. + calls struct { + // GenerateBase32EncodedString holds details about calls to the GenerateBase32EncodedString method. + GenerateBase32EncodedString []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // N is the n argument value. + N int + } + // GenerateBase64EncodedString holds details about calls to the GenerateBase64EncodedString method. + GenerateBase64EncodedString []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // N is the n argument value. + N int + } + // GenerateHexEncodedString holds details about calls to the GenerateHexEncodedString method. + GenerateHexEncodedString []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Length is the length argument value. + Length int + } + // GenerateRawBytes holds details about calls to the GenerateRawBytes method. + GenerateRawBytes []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // N is the n argument value. + N int + } + } + lockGenerateBase32EncodedString sync.RWMutex + lockGenerateBase64EncodedString sync.RWMutex + lockGenerateHexEncodedString sync.RWMutex + lockGenerateRawBytes sync.RWMutex +} + +// GenerateBase32EncodedString calls GenerateBase32EncodedStringFunc. +func (mock *GeneratorMock) GenerateBase32EncodedString(contextMoqParam context.Context, n int) (string, error) { + if mock.GenerateBase32EncodedStringFunc == nil { + panic("GeneratorMock.GenerateBase32EncodedStringFunc: method is nil but Generator.GenerateBase32EncodedString was just called") + } + callInfo := struct { + ContextMoqParam context.Context + N int + }{ + ContextMoqParam: contextMoqParam, + N: n, + } + mock.lockGenerateBase32EncodedString.Lock() + mock.calls.GenerateBase32EncodedString = append(mock.calls.GenerateBase32EncodedString, callInfo) + mock.lockGenerateBase32EncodedString.Unlock() + return mock.GenerateBase32EncodedStringFunc(contextMoqParam, n) +} + +// GenerateBase32EncodedStringCalls gets all the calls that were made to GenerateBase32EncodedString. +// Check the length with: +// +// len(mockedGenerator.GenerateBase32EncodedStringCalls()) +func (mock *GeneratorMock) GenerateBase32EncodedStringCalls() []struct { + ContextMoqParam context.Context + N int +} { + var calls []struct { + ContextMoqParam context.Context + N int + } + mock.lockGenerateBase32EncodedString.RLock() + calls = mock.calls.GenerateBase32EncodedString + mock.lockGenerateBase32EncodedString.RUnlock() + return calls +} + +// GenerateBase64EncodedString calls GenerateBase64EncodedStringFunc. +func (mock *GeneratorMock) GenerateBase64EncodedString(contextMoqParam context.Context, n int) (string, error) { + if mock.GenerateBase64EncodedStringFunc == nil { + panic("GeneratorMock.GenerateBase64EncodedStringFunc: method is nil but Generator.GenerateBase64EncodedString was just called") + } + callInfo := struct { + ContextMoqParam context.Context + N int + }{ + ContextMoqParam: contextMoqParam, + N: n, + } + mock.lockGenerateBase64EncodedString.Lock() + mock.calls.GenerateBase64EncodedString = append(mock.calls.GenerateBase64EncodedString, callInfo) + mock.lockGenerateBase64EncodedString.Unlock() + return mock.GenerateBase64EncodedStringFunc(contextMoqParam, n) +} + +// GenerateBase64EncodedStringCalls gets all the calls that were made to GenerateBase64EncodedString. +// Check the length with: +// +// len(mockedGenerator.GenerateBase64EncodedStringCalls()) +func (mock *GeneratorMock) GenerateBase64EncodedStringCalls() []struct { + ContextMoqParam context.Context + N int +} { + var calls []struct { + ContextMoqParam context.Context + N int + } + mock.lockGenerateBase64EncodedString.RLock() + calls = mock.calls.GenerateBase64EncodedString + mock.lockGenerateBase64EncodedString.RUnlock() + return calls +} + +// GenerateHexEncodedString calls GenerateHexEncodedStringFunc. +func (mock *GeneratorMock) GenerateHexEncodedString(ctx context.Context, length int) (string, error) { + if mock.GenerateHexEncodedStringFunc == nil { + panic("GeneratorMock.GenerateHexEncodedStringFunc: method is nil but Generator.GenerateHexEncodedString was just called") + } + callInfo := struct { + Ctx context.Context + Length int + }{ + Ctx: ctx, + Length: length, + } + mock.lockGenerateHexEncodedString.Lock() + mock.calls.GenerateHexEncodedString = append(mock.calls.GenerateHexEncodedString, callInfo) + mock.lockGenerateHexEncodedString.Unlock() + return mock.GenerateHexEncodedStringFunc(ctx, length) +} + +// GenerateHexEncodedStringCalls gets all the calls that were made to GenerateHexEncodedString. +// Check the length with: +// +// len(mockedGenerator.GenerateHexEncodedStringCalls()) +func (mock *GeneratorMock) GenerateHexEncodedStringCalls() []struct { + Ctx context.Context + Length int +} { + var calls []struct { + Ctx context.Context + Length int + } + mock.lockGenerateHexEncodedString.RLock() + calls = mock.calls.GenerateHexEncodedString + mock.lockGenerateHexEncodedString.RUnlock() + return calls +} + +// GenerateRawBytes calls GenerateRawBytesFunc. +func (mock *GeneratorMock) GenerateRawBytes(contextMoqParam context.Context, n int) ([]byte, error) { + if mock.GenerateRawBytesFunc == nil { + panic("GeneratorMock.GenerateRawBytesFunc: method is nil but Generator.GenerateRawBytes was just called") + } + callInfo := struct { + ContextMoqParam context.Context + N int + }{ + ContextMoqParam: contextMoqParam, + N: n, + } + mock.lockGenerateRawBytes.Lock() + mock.calls.GenerateRawBytes = append(mock.calls.GenerateRawBytes, callInfo) + mock.lockGenerateRawBytes.Unlock() + return mock.GenerateRawBytesFunc(contextMoqParam, n) +} + +// GenerateRawBytesCalls gets all the calls that were made to GenerateRawBytes. +// Check the length with: +// +// len(mockedGenerator.GenerateRawBytesCalls()) +func (mock *GeneratorMock) GenerateRawBytesCalls() []struct { + ContextMoqParam context.Context + N int +} { + var calls []struct { + ContextMoqParam context.Context + N int + } + mock.lockGenerateRawBytes.RLock() + calls = mock.calls.GenerateRawBytes + mock.lockGenerateRawBytes.RUnlock() + return calls +} diff --git a/random/mock/mock_random.go b/random/mock/mock_random.go deleted file mode 100644 index 36584f8..0000000 --- a/random/mock/mock_random.go +++ /dev/null @@ -1,42 +0,0 @@ -package randommock - -import ( - "context" - - "github.com/verygoodsoftwarenotvirus/platform/v5/random" - - "github.com/stretchr/testify/mock" -) - -var _ random.Generator = (*Generator)(nil) - -// Generator is a mock Generator. -type Generator struct { - mock.Mock -} - -func (m *Generator) GenerateHexEncodedString(ctx context.Context, length int) (string, error) { - returnVals := m.Called(ctx, length) - return returnVals.String(0), returnVals.Error(1) -} - -// GenerateBase32EncodedString implements our interface. -func (m *Generator) GenerateBase32EncodedString(ctx context.Context, length int) (string, error) { - returnVals := m.Called(ctx, length) - - return returnVals.String(0), returnVals.Error(1) -} - -// GenerateBase64EncodedString implements our interface. -func (m *Generator) GenerateBase64EncodedString(ctx context.Context, length int) (string, error) { - returnVals := m.Called(ctx, length) - - return returnVals.String(0), returnVals.Error(1) -} - -// GenerateRawBytes implements our interface. -func (m *Generator) GenerateRawBytes(ctx context.Context, length int) ([]byte, error) { - returnVals := m.Called(ctx, length) - - return returnVals.Get(0).([]byte), returnVals.Error(1) -} diff --git a/ratelimiting/redis/redis_test.go b/ratelimiting/redis/redis_test.go index f03c73a..7c2ecea 100644 --- a/ratelimiting/redis/redis_test.go +++ b/ratelimiting/redis/redis_test.go @@ -7,25 +7,36 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" - "github.com/verygoodsoftwarenotvirus/platform/v5/testutils" "github.com/go-redis/redis/v8" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" ) +type evalCall struct { + ctx context.Context + script string + keys []string + args []any +} + type mockRedisClient struct { - mock.Mock + evalFunc func(ctx context.Context, script string, keys []string, args ...any) *redis.Cmd + closeFunc func() error + evalCalls []evalCall + closeCalls int } func (m *mockRedisClient) Eval(ctx context.Context, script string, keys []string, args ...any) *redis.Cmd { - return m.Called(ctx, script, keys, args).Get(0).(*redis.Cmd) + m.evalCalls = append(m.evalCalls, evalCall{ctx: ctx, script: script, keys: keys, args: args}) + return m.evalFunc(ctx, script, keys, args...) } func (m *mockRedisClient) Close() error { - return m.Called().Error(0) + m.closeCalls++ + return m.closeFunc() } func buildTestRateLimiter(t *testing.T) (*rateLimiter, *mockRedisClient) { @@ -128,14 +139,18 @@ func TestNewRedisRateLimiter(T *testing.T) { Addresses: []string{"localhost:6379"}, } - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", redisName+"_allowed", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("counter error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, redisName+"_allowed", counterName) + return metrics.Int64CounterForTest(t, "x"), errors.New("counter error") + }, + } rl, err := NewRedisRateLimiter(cfg, mp, 10) assert.Error(t, err) assert.Nil(t, rl) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("with error creating rejected counter", func(t *testing.T) { @@ -145,15 +160,24 @@ func TestNewRedisRateLimiter(T *testing.T) { Addresses: []string{"localhost:6379"}, } - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", redisName+"_allowed", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", redisName+"_rejected", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("counter error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + switch counterName { + case redisName + "_allowed": + return metrics.Int64CounterForTest(t, "x"), nil + case redisName + "_rejected": + return metrics.Int64CounterForTest(t, "x"), errors.New("counter error") + } + t.Fatalf("unexpected NewInt64Counter call: %q", counterName) + return nil, nil + }, + } rl, err := NewRedisRateLimiter(cfg, mp, 10) assert.Error(t, err) assert.Nil(t, rl) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) T.Run("with error creating error counter", func(t *testing.T) { @@ -163,16 +187,24 @@ func TestNewRedisRateLimiter(T *testing.T) { Addresses: []string{"localhost:6379"}, } - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", redisName+"_allowed", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", redisName+"_rejected", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", redisName+"_errors", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("counter error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + switch counterName { + case redisName + "_allowed", redisName + "_rejected": + return metrics.Int64CounterForTest(t, "x"), nil + case redisName + "_errors": + return metrics.Int64CounterForTest(t, "x"), errors.New("counter error") + } + t.Fatalf("unexpected NewInt64Counter call: %q", counterName) + return nil, nil + }, + } rl, err := NewRedisRateLimiter(cfg, mp, 10) assert.Error(t, err) assert.Nil(t, rl) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 3, mp.NewInt64CounterCalls()) }) } @@ -187,13 +219,14 @@ func Test_rateLimiter_Allow(T *testing.T) { cmd := redis.NewCmd(ctx) cmd.SetVal(int64(1)) - client.On("Eval", testutils.ContextMatcher, slidingWindowScript, mock.AnythingOfType("[]string"), mock.AnythingOfType("[]interface {}")).Return(cmd) + client.evalFunc = func(_ context.Context, _ string, _ []string, _ ...any) *redis.Cmd { return cmd } allowed, err := rl.Allow(ctx, "test-key") assert.NoError(t, err) assert.True(t, allowed) - mock.AssertExpectationsForObjects(t, client) + require.Len(t, client.evalCalls, 1) + assert.Equal(t, slidingWindowScript, client.evalCalls[0].script) }) T.Run("rejected", func(t *testing.T) { @@ -204,13 +237,13 @@ func Test_rateLimiter_Allow(T *testing.T) { cmd := redis.NewCmd(ctx) cmd.SetVal(int64(0)) - client.On("Eval", testutils.ContextMatcher, slidingWindowScript, mock.AnythingOfType("[]string"), mock.AnythingOfType("[]interface {}")).Return(cmd) + client.evalFunc = func(_ context.Context, _ string, _ []string, _ ...any) *redis.Cmd { return cmd } allowed, err := rl.Allow(ctx, "test-key") assert.NoError(t, err) assert.False(t, allowed) - mock.AssertExpectationsForObjects(t, client) + require.Len(t, client.evalCalls, 1) }) T.Run("with eval error", func(t *testing.T) { @@ -221,13 +254,13 @@ func Test_rateLimiter_Allow(T *testing.T) { cmd := redis.NewCmd(ctx) cmd.SetErr(errors.New("redis error")) - client.On("Eval", testutils.ContextMatcher, slidingWindowScript, mock.AnythingOfType("[]string"), mock.AnythingOfType("[]interface {}")).Return(cmd) + client.evalFunc = func(_ context.Context, _ string, _ []string, _ ...any) *redis.Cmd { return cmd } allowed, err := rl.Allow(ctx, "test-key") assert.Error(t, err) assert.False(t, allowed) - mock.AssertExpectationsForObjects(t, client) + require.Len(t, client.evalCalls, 1) }) } @@ -238,23 +271,21 @@ func Test_rateLimiter_Close(T *testing.T) { t.Parallel() rl, client := buildTestRateLimiter(t) - client.On("Close").Return(nil) + client.closeFunc = func() error { return nil } err := rl.Close() assert.NoError(t, err) - - mock.AssertExpectationsForObjects(t, client) + assert.Equal(t, 1, client.closeCalls) }) T.Run("with close error", func(t *testing.T) { t.Parallel() rl, client := buildTestRateLimiter(t) - client.On("Close").Return(errors.New("close failed")) + client.closeFunc = func() error { return errors.New("close failed") } err := rl.Close() assert.Error(t, err) - - mock.AssertExpectationsForObjects(t, client) + assert.Equal(t, 1, client.closeCalls) }) } diff --git a/routing/mock/doc.go b/routing/mock/doc.go new file mode 100644 index 0000000..9c91d55 --- /dev/null +++ b/routing/mock/doc.go @@ -0,0 +1,9 @@ +// Package mockrouting provides mock implementations of the routing package's +// interfaces. Both the hand-written testify-based RouteParamManager and the +// moq-generated RouteParamManagerMock live here during the testify → moq +// migration. New test code should prefer RouteParamManagerMock. +package mockrouting + +// Regenerate the moq mocks via `go generate ./routing/mock/`. + +//go:generate go tool github.com/matryer/moq -out route_param_manager_mock.go -pkg mockrouting -rm -fmt goimports .. RouteParamManager:RouteParamManagerMock diff --git a/routing/mock/route_param_manager.go b/routing/mock/route_param_manager.go deleted file mode 100644 index 6276d22..0000000 --- a/routing/mock/route_param_manager.go +++ /dev/null @@ -1,34 +0,0 @@ -package mockrouting - -import ( - "net/http" - - "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" - - "github.com/stretchr/testify/mock" -) - -// NewRouteParamManager returns a new RouteParamManager. -func NewRouteParamManager() *RouteParamManager { - return &RouteParamManager{} -} - -// RouteParamManager is a mock routing.RouteParamManager. -type RouteParamManager struct { - mock.Mock -} - -// UserIDFetcherFromSessionContextData satisfies our interface contract. -func (m *RouteParamManager) UserIDFetcherFromSessionContextData(req *http.Request) uint64 { - return m.Called(req).Get(0).(uint64) -} - -// BuildRouteParamIDFetcher satisfies our interface contract. -func (m *RouteParamManager) BuildRouteParamIDFetcher(logger logging.Logger, key, logDescription string) func(*http.Request) uint64 { - return m.Called(logger, key, logDescription).Get(0).(func(*http.Request) uint64) -} - -// BuildRouteParamStringIDFetcher satisfies our interface contract. -func (m *RouteParamManager) BuildRouteParamStringIDFetcher(key string) func(req *http.Request) string { - return m.Called(key).Get(0).(func(*http.Request) string) -} diff --git a/routing/mock/route_param_manager_mock.go b/routing/mock/route_param_manager_mock.go new file mode 100644 index 0000000..82e6fb2 --- /dev/null +++ b/routing/mock/route_param_manager_mock.go @@ -0,0 +1,134 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package mockrouting + +import ( + "net/http" + "sync" + + "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" + "github.com/verygoodsoftwarenotvirus/platform/v5/routing" +) + +// Ensure, that RouteParamManagerMock does implement routing.RouteParamManager. +// If this is not the case, regenerate this file with moq. +var _ routing.RouteParamManager = &RouteParamManagerMock{} + +// RouteParamManagerMock is a mock implementation of routing.RouteParamManager. +// +// func TestSomethingThatUsesRouteParamManager(t *testing.T) { +// +// // make and configure a mocked routing.RouteParamManager +// mockedRouteParamManager := &RouteParamManagerMock{ +// BuildRouteParamIDFetcherFunc: func(logger logging.Logger, key string, logDescription string) func(req *http.Request) uint64 { +// panic("mock out the BuildRouteParamIDFetcher method") +// }, +// BuildRouteParamStringIDFetcherFunc: func(key string) func(req *http.Request) string { +// panic("mock out the BuildRouteParamStringIDFetcher method") +// }, +// } +// +// // use mockedRouteParamManager in code that requires routing.RouteParamManager +// // and then make assertions. +// +// } +type RouteParamManagerMock struct { + // BuildRouteParamIDFetcherFunc mocks the BuildRouteParamIDFetcher method. + BuildRouteParamIDFetcherFunc func(logger logging.Logger, key string, logDescription string) func(req *http.Request) uint64 + + // BuildRouteParamStringIDFetcherFunc mocks the BuildRouteParamStringIDFetcher method. + BuildRouteParamStringIDFetcherFunc func(key string) func(req *http.Request) string + + // calls tracks calls to the methods. + calls struct { + // BuildRouteParamIDFetcher holds details about calls to the BuildRouteParamIDFetcher method. + BuildRouteParamIDFetcher []struct { + // Logger is the logger argument value. + Logger logging.Logger + // Key is the key argument value. + Key string + // LogDescription is the logDescription argument value. + LogDescription string + } + // BuildRouteParamStringIDFetcher holds details about calls to the BuildRouteParamStringIDFetcher method. + BuildRouteParamStringIDFetcher []struct { + // Key is the key argument value. + Key string + } + } + lockBuildRouteParamIDFetcher sync.RWMutex + lockBuildRouteParamStringIDFetcher sync.RWMutex +} + +// BuildRouteParamIDFetcher calls BuildRouteParamIDFetcherFunc. +func (mock *RouteParamManagerMock) BuildRouteParamIDFetcher(logger logging.Logger, key string, logDescription string) func(req *http.Request) uint64 { + if mock.BuildRouteParamIDFetcherFunc == nil { + panic("RouteParamManagerMock.BuildRouteParamIDFetcherFunc: method is nil but RouteParamManager.BuildRouteParamIDFetcher was just called") + } + callInfo := struct { + Logger logging.Logger + Key string + LogDescription string + }{ + Logger: logger, + Key: key, + LogDescription: logDescription, + } + mock.lockBuildRouteParamIDFetcher.Lock() + mock.calls.BuildRouteParamIDFetcher = append(mock.calls.BuildRouteParamIDFetcher, callInfo) + mock.lockBuildRouteParamIDFetcher.Unlock() + return mock.BuildRouteParamIDFetcherFunc(logger, key, logDescription) +} + +// BuildRouteParamIDFetcherCalls gets all the calls that were made to BuildRouteParamIDFetcher. +// Check the length with: +// +// len(mockedRouteParamManager.BuildRouteParamIDFetcherCalls()) +func (mock *RouteParamManagerMock) BuildRouteParamIDFetcherCalls() []struct { + Logger logging.Logger + Key string + LogDescription string +} { + var calls []struct { + Logger logging.Logger + Key string + LogDescription string + } + mock.lockBuildRouteParamIDFetcher.RLock() + calls = mock.calls.BuildRouteParamIDFetcher + mock.lockBuildRouteParamIDFetcher.RUnlock() + return calls +} + +// BuildRouteParamStringIDFetcher calls BuildRouteParamStringIDFetcherFunc. +func (mock *RouteParamManagerMock) BuildRouteParamStringIDFetcher(key string) func(req *http.Request) string { + if mock.BuildRouteParamStringIDFetcherFunc == nil { + panic("RouteParamManagerMock.BuildRouteParamStringIDFetcherFunc: method is nil but RouteParamManager.BuildRouteParamStringIDFetcher was just called") + } + callInfo := struct { + Key string + }{ + Key: key, + } + mock.lockBuildRouteParamStringIDFetcher.Lock() + mock.calls.BuildRouteParamStringIDFetcher = append(mock.calls.BuildRouteParamStringIDFetcher, callInfo) + mock.lockBuildRouteParamStringIDFetcher.Unlock() + return mock.BuildRouteParamStringIDFetcherFunc(key) +} + +// BuildRouteParamStringIDFetcherCalls gets all the calls that were made to BuildRouteParamStringIDFetcher. +// Check the length with: +// +// len(mockedRouteParamManager.BuildRouteParamStringIDFetcherCalls()) +func (mock *RouteParamManagerMock) BuildRouteParamStringIDFetcherCalls() []struct { + Key string +} { + var calls []struct { + Key string + } + mock.lockBuildRouteParamStringIDFetcher.RLock() + calls = mock.calls.BuildRouteParamStringIDFetcher + mock.lockBuildRouteParamStringIDFetcher.RUnlock() + return calls +} diff --git a/search/text/algolia/index_test.go b/search/text/algolia/index_test.go index 41b3ad4..05aada2 100644 --- a/search/text/algolia/index_test.go +++ b/search/text/algolia/index_test.go @@ -16,7 +16,6 @@ import ( algoliasearch "github.com/algolia/algoliasearch-client-go/v3/algolia/search" algoliatransport "github.com/algolia/algoliasearch-client-go/v3/algolia/transport" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" ) var _ algoliatransport.Requester = (*testRequester)(nil) @@ -116,16 +115,15 @@ func TestIndexManager_Index(T *testing.T) { T.Run("with broken circuit breaker", func(t *testing.T) { t.Parallel() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true) + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return true }, + } im := buildTestIndexManagerWithCircuitBreaker(t, cb) err := im.Index(context.Background(), "id", map[string]string{"id": "test"}) assert.Error(t, err) assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with unmarshalable value", func(t *testing.T) { @@ -163,15 +161,14 @@ func TestIndexManager_Index(T *testing.T) { _, _ = w.Write([]byte(`{"createdAt":"2021-01-01T00:00:00Z","objectID":"123","taskID":123}`)) }) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + } im := buildTestIndexManagerWithMockServer(t, handler, cb) err := im.Index(context.Background(), "123", map[string]string{"id": "123", "name": "example"}) assert.NoError(t, err) - - mock.AssertExpectationsForObjects(t, cb) }) } @@ -181,8 +178,9 @@ func TestIndexManager_Search(T *testing.T) { T.Run("with broken circuit breaker", func(t *testing.T) { t.Parallel() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true) + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return true }, + } im := buildTestIndexManagerWithCircuitBreaker(t, cb) @@ -190,15 +188,14 @@ func TestIndexManager_Search(T *testing.T) { assert.Error(t, err) assert.Nil(t, results) assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with empty query", func(t *testing.T) { t.Parallel() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + } im := buildTestIndexManagerWithCircuitBreaker(t, cb) @@ -206,24 +203,21 @@ func TestIndexManager_Search(T *testing.T) { assert.Error(t, err) assert.Nil(t, results) assert.Equal(t, ErrEmptyQueryProvided, err) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with valid query but invalid credentials", func(t *testing.T) { t.Parallel() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + FailedFunc: func() {}, + } im := buildTestIndexManagerWithCircuitBreaker(t, cb) results, err := im.Search(context.Background(), "test query") assert.Error(t, err) assert.Nil(t, results) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with successful search results", func(t *testing.T) { @@ -234,9 +228,10 @@ func TestIndexManager_Search(T *testing.T) { _, _ = w.Write([]byte(`{"hits":[{"objectID":"123"}],"nbHits":1,"page":0,"nbPages":1,"hitsPerPage":20,"processingTimeMS":1}`)) }) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + SucceededFunc: func() {}, + } im := buildTestIndexManagerWithMockServer(t, handler, cb) @@ -244,8 +239,6 @@ func TestIndexManager_Search(T *testing.T) { assert.NoError(t, err) assert.NotNil(t, results) assert.Len(t, results, 1) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with empty search results", func(t *testing.T) { @@ -256,9 +249,10 @@ func TestIndexManager_Search(T *testing.T) { _, _ = w.Write([]byte(`{"hits":[],"nbHits":0,"page":0,"nbPages":0,"hitsPerPage":20,"processingTimeMS":1}`)) }) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + SucceededFunc: func() {}, + } im := buildTestIndexManagerWithMockServer(t, handler, cb) @@ -266,8 +260,6 @@ func TestIndexManager_Search(T *testing.T) { assert.NoError(t, err) assert.NotNil(t, results) assert.Empty(t, results) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with multiple search results", func(t *testing.T) { @@ -278,9 +270,10 @@ func TestIndexManager_Search(T *testing.T) { _, _ = w.Write([]byte(`{"hits":[{"objectID":"abc","name":"first"},{"objectID":"def","name":"second"},{"objectID":"ghi","name":"third"}],"nbHits":3,"page":0,"nbPages":1,"hitsPerPage":20,"processingTimeMS":1}`)) }) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + SucceededFunc: func() {}, + } im := buildTestIndexManagerWithMockServer(t, handler, cb) @@ -293,8 +286,6 @@ func TestIndexManager_Search(T *testing.T) { assert.Equal(t, "second", results[1].Name) assert.Equal(t, "ghi", results[2].ID) assert.Equal(t, "third", results[2].Name) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("when unmarshalling search result fails", func(t *testing.T) { @@ -305,16 +296,15 @@ func TestIndexManager_Search(T *testing.T) { _, _ = w.Write([]byte(`{"hits":[{"objectID":"123","name":["not","a","string"]}],"nbHits":1,"page":0,"nbPages":1,"hitsPerPage":20,"processingTimeMS":1}`)) }) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + } im := buildTestIndexManagerWithMockServer(t, handler, cb) results, err := im.Search(context.Background(), "test query") assert.Error(t, err) assert.Nil(t, results) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with successful search results without objectID", func(t *testing.T) { @@ -325,9 +315,10 @@ func TestIndexManager_Search(T *testing.T) { _, _ = w.Write([]byte(`{"hits":[{"name":"example"}],"nbHits":1,"page":0,"nbPages":1,"hitsPerPage":20,"processingTimeMS":1}`)) }) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + SucceededFunc: func() {}, + } im := buildTestIndexManagerWithMockServer(t, handler, cb) @@ -335,8 +326,6 @@ func TestIndexManager_Search(T *testing.T) { assert.NoError(t, err) assert.NotNil(t, results) assert.Len(t, results, 1) - - mock.AssertExpectationsForObjects(t, cb) }) } @@ -346,31 +335,29 @@ func TestIndexManager_Delete(T *testing.T) { T.Run("with broken circuit breaker", func(t *testing.T) { t.Parallel() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true) + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return true }, + } im := buildTestIndexManagerWithCircuitBreaker(t, cb) err := im.Delete(context.Background(), "id") assert.Error(t, err) assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with invalid credentials", func(t *testing.T) { t.Parallel() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + FailedFunc: func() {}, + } im := buildTestIndexManagerWithCircuitBreaker(t, cb) err := im.Delete(context.Background(), "some-id") assert.Error(t, err) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with successful deletion", func(t *testing.T) { @@ -381,16 +368,15 @@ func TestIndexManager_Delete(T *testing.T) { _, _ = w.Write([]byte(`{"deletedAt":"2021-01-01T00:00:00Z","taskID":123}`)) }) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + SucceededFunc: func() {}, + } im := buildTestIndexManagerWithMockServer(t, handler, cb) err := im.Delete(context.Background(), "some-id") assert.NoError(t, err) - - mock.AssertExpectationsForObjects(t, cb) }) } @@ -400,31 +386,29 @@ func TestIndexManager_Wipe(T *testing.T) { T.Run("with broken circuit breaker", func(t *testing.T) { t.Parallel() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true) + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return true }, + } im := buildTestIndexManagerWithCircuitBreaker(t, cb) err := im.Wipe(context.Background()) assert.Error(t, err) assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with invalid credentials", func(t *testing.T) { t.Parallel() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + FailedFunc: func() {}, + } im := buildTestIndexManagerWithCircuitBreaker(t, cb) err := im.Wipe(context.Background()) assert.Error(t, err) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with successful wipe", func(t *testing.T) { @@ -435,15 +419,14 @@ func TestIndexManager_Wipe(T *testing.T) { _, _ = w.Write([]byte(`{"updatedAt":"2021-01-01T00:00:00Z","taskID":123}`)) }) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + SucceededFunc: func() {}, + } im := buildTestIndexManagerWithMockServer(t, handler, cb) err := im.Wipe(context.Background()) assert.NoError(t, err) - - mock.AssertExpectationsForObjects(t, cb) }) } diff --git a/search/text/config/config_test.go b/search/text/config/config_test.go index 03f9317..61221a7 100644 --- a/search/text/config/config_test.go +++ b/search/text/config/config_test.go @@ -13,6 +13,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/search/text/algolia" "github.com/verygoodsoftwarenotvirus/platform/v5/search/text/elasticsearch" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" "go.opentelemetry.io/otel/metric" ) @@ -291,11 +292,14 @@ func TestConfig_ProvideIndex(T *testing.T) { }, } - mp := &mockmetrics.MetricsProvider{} // Force the very first counter creation to fail so ProvideCircuitBreaker // returns an error, which is wrapped by ProvideIndex. - mp.On("NewInt64Counter", "test-breaker_circuit_breaker_tripped", []metric.Int64CounterOption(nil)). - Return(&mockmetrics.Int64Counter{}, errors.New("counter init failure")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, "test-breaker_circuit_breaker_tripped", counterName) + return &mockmetrics.Int64CounterMock{}, errors.New("counter init failure") + }, + } logger := logging.NewNoopLogger() tracerProvider := tracing.NewNoopTracerProvider() @@ -303,7 +307,8 @@ func TestConfig_ProvideIndex(T *testing.T) { assert.Error(t, err) assert.Nil(t, index) assert.Contains(t, err.Error(), "circuit breaker") - mp.AssertExpectations(t) + + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) } diff --git a/search/text/elasticsearch/elasticsearch_test.go b/search/text/elasticsearch/elasticsearch_test.go index a3a5473..6aa0577 100644 --- a/search/text/elasticsearch/elasticsearch_test.go +++ b/search/text/elasticsearch/elasticsearch_test.go @@ -18,7 +18,6 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" elasticsearchcontainers "github.com/testcontainers/testcontainers-go/modules/elasticsearch" ) @@ -328,31 +327,29 @@ func TestIndexManager_ensureIndices_CircuitBroken(T *testing.T) { T.Run("with broken circuit breaker", func(t *testing.T) { t.Parallel() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true) + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return true }, + } im := buildTestIndexManagerForUnit(t, cb) err := im.ensureIndices(context.Background()) assert.Error(t, err) assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with unreachable server", func(t *testing.T) { t.Parallel() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + FailedFunc: func() {}, + } im := buildTestIndexManagerForUnit(t, cb) err := im.ensureIndices(context.Background()) assert.Error(t, err) - - mock.AssertExpectationsForObjects(t, cb) }) } @@ -372,16 +369,15 @@ func TestIndexManager_ensureIndices_Unit(T *testing.T) { })) t.Cleanup(server.Close) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + SucceededFunc: func() {}, + } im := buildTestIndexManagerWithServer(t, server, cb) err := im.ensureIndices(context.Background()) assert.NoError(t, err) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("index does not exist and create succeeds", func(t *testing.T) { @@ -402,16 +398,15 @@ func TestIndexManager_ensureIndices_Unit(T *testing.T) { })) t.Cleanup(server.Close) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + SucceededFunc: func() {}, + } im := buildTestIndexManagerWithServer(t, server, cb) err := im.ensureIndices(context.Background()) assert.NoError(t, err) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("index does not exist and create fails", func(t *testing.T) { @@ -436,16 +431,15 @@ func TestIndexManager_ensureIndices_Unit(T *testing.T) { })) t.Cleanup(server.Close) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + FailedFunc: func() {}, + } im := buildTestIndexManagerWithServer(t, server, cb) err := im.ensureIndices(context.Background()) assert.Error(t, err) - - mock.AssertExpectationsForObjects(t, cb) }) } @@ -553,15 +547,14 @@ func TestProvideIndexManager_Unit(T *testing.T) { logger := logging.NewNoopLogger() tracerProvider := tracing.NewNoopTracerProvider() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + SucceededFunc: func() {}, + } im, err := ProvideIndexManager[example](context.Background(), logger, tracerProvider, cfg, "test", cb) assert.NoError(t, err) assert.NotNil(t, im) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("fails when ensureIndices fails", func(t *testing.T) { @@ -604,14 +597,13 @@ func TestProvideIndexManager_Unit(T *testing.T) { logger := logging.NewNoopLogger() tracerProvider := tracing.NewNoopTracerProvider() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + FailedFunc: func() {}, + } im, err := ProvideIndexManager[example](context.Background(), logger, tracerProvider, cfg, "test", cb) assert.Error(t, err) assert.Nil(t, im) - - mock.AssertExpectationsForObjects(t, cb) }) } diff --git a/search/text/elasticsearch/index_test.go b/search/text/elasticsearch/index_test.go index 316589c..cfb1983 100644 --- a/search/text/elasticsearch/index_test.go +++ b/search/text/elasticsearch/index_test.go @@ -14,7 +14,6 @@ import ( "github.com/elastic/go-elasticsearch/v8" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -71,45 +70,42 @@ func TestIndexManager_Index_CircuitBroken(T *testing.T) { T.Run("with broken circuit breaker", func(t *testing.T) { t.Parallel() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true) + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return true }, + } im := buildTestIndexManagerForUnit(t, cb) err := im.Index(context.Background(), "id", map[string]string{"id": "test"}) assert.Error(t, err) assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with unmarshalable value", func(t *testing.T) { t.Parallel() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + } im := buildTestIndexManagerForUnit(t, cb) err := im.Index(context.Background(), "id", make(chan int)) assert.Error(t, err) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with unreachable server", func(t *testing.T) { t.Parallel() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + FailedFunc: func() {}, + } im := buildTestIndexManagerForUnit(t, cb) err := im.Index(context.Background(), "id", map[string]string{"id": "test"}) assert.Error(t, err) - - mock.AssertExpectationsForObjects(t, cb) }) } @@ -127,16 +123,15 @@ func TestIndexManager_Index_Unit(T *testing.T) { })) t.Cleanup(server.Close) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + SucceededFunc: func() {}, + } im := buildTestIndexManagerWithServer(t, server, cb) err := im.Index(context.Background(), "123", &example{ID: "123", Name: "test"}) assert.NoError(t, err) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with non-success status code", func(t *testing.T) { @@ -150,16 +145,15 @@ func TestIndexManager_Index_Unit(T *testing.T) { })) t.Cleanup(server.Close) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + FailedFunc: func() {}, + } im := buildTestIndexManagerWithServer(t, server, cb) err := im.Index(context.Background(), "123", &example{ID: "123", Name: "test"}) assert.Error(t, err) - - mock.AssertExpectationsForObjects(t, cb) }) } @@ -169,8 +163,9 @@ func TestIndexManager_Search_CircuitBroken(T *testing.T) { T.Run("with broken circuit breaker", func(t *testing.T) { t.Parallel() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true) + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return true }, + } im := buildTestIndexManagerForUnit(t, cb) @@ -178,15 +173,14 @@ func TestIndexManager_Search_CircuitBroken(T *testing.T) { assert.Error(t, err) assert.Nil(t, results) assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with empty query", func(t *testing.T) { t.Parallel() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + } im := buildTestIndexManagerForUnit(t, cb) @@ -194,24 +188,21 @@ func TestIndexManager_Search_CircuitBroken(T *testing.T) { assert.Error(t, err) assert.Nil(t, results) assert.Equal(t, ErrEmptyQueryProvided, err) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with unreachable server", func(t *testing.T) { t.Parallel() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + FailedFunc: func() {}, + } im := buildTestIndexManagerForUnit(t, cb) results, err := im.Search(context.Background(), "test query") assert.Error(t, err) assert.Nil(t, results) - - mock.AssertExpectationsForObjects(t, cb) }) } @@ -229,9 +220,10 @@ func TestIndexManager_Search_Unit(T *testing.T) { })) t.Cleanup(server.Close) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + SucceededFunc: func() {}, + } im := buildTestIndexManagerWithServer(t, server, cb) @@ -240,8 +232,6 @@ func TestIndexManager_Search_Unit(T *testing.T) { require.Len(t, results, 1) assert.Equal(t, "123", results[0].ID) assert.Equal(t, "test", results[0].Name) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with error response", func(t *testing.T) { @@ -255,9 +245,10 @@ func TestIndexManager_Search_Unit(T *testing.T) { })) t.Cleanup(server.Close) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + FailedFunc: func() {}, + } im := buildTestIndexManagerWithServer(t, server, cb) @@ -268,8 +259,6 @@ func TestIndexManager_Search_Unit(T *testing.T) { results, err := im.Search(context.Background(), "test") assert.NoError(t, err) assert.Nil(t, results) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with invalid JSON in success response", func(t *testing.T) { @@ -283,9 +272,10 @@ func TestIndexManager_Search_Unit(T *testing.T) { })) t.Cleanup(server.Close) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + FailedFunc: func() {}, + } im := buildTestIndexManagerWithServer(t, server, cb) @@ -294,8 +284,6 @@ func TestIndexManager_Search_Unit(T *testing.T) { results, err := im.Search(context.Background(), "test") assert.NoError(t, err) assert.Nil(t, results) - - mock.AssertExpectationsForObjects(t, cb) }) } @@ -313,9 +301,10 @@ func TestIndexManager_Search_ErrorResponseDecodeFailure_Unit(T *testing.T) { })) t.Cleanup(server.Close) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + FailedFunc: func() {}, + } im := buildTestIndexManagerWithServer(t, server, cb) @@ -324,8 +313,6 @@ func TestIndexManager_Search_ErrorResponseDecodeFailure_Unit(T *testing.T) { results, err := im.Search(context.Background(), "test") assert.NoError(t, err) assert.Nil(t, results) - - mock.AssertExpectationsForObjects(t, cb) }) } @@ -343,9 +330,10 @@ func TestIndexManager_Search_SourceUnmarshalError_Unit(T *testing.T) { })) t.Cleanup(server.Close) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + FailedFunc: func() {}, + } im := buildTestIndexManagerWithServer(t, server, cb) @@ -354,8 +342,6 @@ func TestIndexManager_Search_SourceUnmarshalError_Unit(T *testing.T) { results, err := im.Search(context.Background(), "test") assert.NoError(t, err) assert.Nil(t, results) - - mock.AssertExpectationsForObjects(t, cb) }) } @@ -365,31 +351,29 @@ func TestIndexManager_Delete_CircuitBroken(T *testing.T) { T.Run("with broken circuit breaker", func(t *testing.T) { t.Parallel() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true) + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return true }, + } im := buildTestIndexManagerForUnit(t, cb) err := im.Delete(context.Background(), "id") assert.Error(t, err) assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) - - mock.AssertExpectationsForObjects(t, cb) }) T.Run("with unreachable server", func(t *testing.T) { t.Parallel() - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + FailedFunc: func() {}, + } im := buildTestIndexManagerForUnit(t, cb) err := im.Delete(context.Background(), "some-id") assert.Error(t, err) - - mock.AssertExpectationsForObjects(t, cb) }) } @@ -407,16 +391,15 @@ func TestIndexManager_Delete_Unit(T *testing.T) { })) t.Cleanup(server.Close) - cb := &mockcircuitbreaking.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() + cb := &mockcircuitbreaking.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + SucceededFunc: func() {}, + } im := buildTestIndexManagerWithServer(t, server, cb) err := im.Delete(context.Background(), "123") assert.NoError(t, err) - - mock.AssertExpectationsForObjects(t, cb) }) } diff --git a/search/text/indexing/do.go b/search/text/indexing/do.go index bb0309d..deb8487 100644 --- a/search/text/indexing/do.go +++ b/search/text/indexing/do.go @@ -16,7 +16,7 @@ import ( // Prerequisites: map[string]Function and *msgconfig.QueuesConfig must be // registered in the injector before calling this. func RegisterIndexScheduler(i do.Injector) { - do.Provide[*IndexScheduler](i, func(i do.Injector) (*IndexScheduler, error) { + do.Provide(i, func(i do.Injector) (*IndexScheduler, error) { return NewIndexScheduler( do.MustInvoke[context.Context](i), do.MustInvoke[logging.Logger](i), diff --git a/search/text/indexing/do_test.go b/search/text/indexing/do_test.go index f37f302..6df1cf7 100644 --- a/search/text/indexing/do_test.go +++ b/search/text/indexing/do_test.go @@ -11,11 +11,10 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/verygoodsoftwarenotvirus/platform/v5/reflection" "github.com/samber/do/v2" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" otelmetric "go.opentelemetry.io/otel/metric" ) @@ -26,13 +25,20 @@ func TestRegisterIndexScheduler(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - metricsProvider := &mockmetrics.MetricsProvider{} - int64Counter := &mockmetrics.Int64Counter{} - metricsProvider.On(reflection.GetMethodName(metricsProvider.NewInt64Counter), "indexer.handled_records", []otelmetric.Int64CounterOption(nil)).Return(int64Counter, nil) + int64Counter := &mockmetrics.Int64CounterMock{} + metricsProvider := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...otelmetric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, "indexer.handled_records", counterName) + return int64Counter, nil + }, + } - messageQueueProvider := &mockpublishers.PublisherProvider{} - publisher := &mockpublishers.Publisher{} - messageQueueProvider.On(reflection.GetMethodName(messageQueueProvider.ProvidePublisher), "test_topic").Return(publisher, nil) + publisher := &mockpublishers.PublisherMock{} + messageQueueProvider := &mockpublishers.PublisherProviderMock{ + ProvidePublisherFunc: func(_ context.Context, _ string) (messagequeue.Publisher, error) { + return publisher, nil + }, + } i := do.New() do.ProvideValue(i, t.Context()) @@ -53,6 +59,8 @@ func TestRegisterIndexScheduler(T *testing.T) { require.NoError(t, err) assert.NotNil(t, scheduler) - mock.AssertExpectationsForObjects(t, metricsProvider, messageQueueProvider) + test.SliceLen(t, 1, metricsProvider.NewInt64CounterCalls()) + test.SliceLen(t, 1, messageQueueProvider.ProvidePublisherCalls()) + test.EqOp(t, "test_topic", messageQueueProvider.ProvidePublisherCalls()[0].Topic) }) } diff --git a/search/text/indexing/indexer_test.go b/search/text/indexing/indexer_test.go index 464fec0..41b54bd 100644 --- a/search/text/indexing/indexer_test.go +++ b/search/text/indexing/indexer_test.go @@ -6,16 +6,17 @@ import ( "errors" "testing" + "github.com/verygoodsoftwarenotvirus/platform/v5/messagequeue" msgconfig "github.com/verygoodsoftwarenotvirus/platform/v5/messagequeue/config" mockpublishers "github.com/verygoodsoftwarenotvirus/platform/v5/messagequeue/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" + "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/verygoodsoftwarenotvirus/platform/v5/reflection" textsearch "github.com/verygoodsoftwarenotvirus/platform/v5/search/text" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" ) @@ -31,16 +32,25 @@ func TestNewIndexScheduler(T *testing.T) { ctx := t.Context() logger := logging.NewNoopLogger() tracerProvider := tracing.NewNoopTracerProvider() - metricsProvider := &mockmetrics.MetricsProvider{} - messageQueueProvider := &mockpublishers.PublisherProvider{} // Mock metrics provider - int64Counter := &mockmetrics.Int64Counter{} - metricsProvider.On(reflection.GetMethodName(metricsProvider.NewInt64Counter), "indexer.handled_records", []metric.Int64CounterOption(nil)).Return(int64Counter, nil) + int64Counter := &mockmetrics.Int64CounterMock{ + AddFunc: func(_ context.Context, _ int64, _ ...metric.AddOption) {}, + } + metricsProvider := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, "indexer.handled_records", counterName) + return int64Counter, nil + }, + } // Mock message queue provider - publisher := &mockpublishers.Publisher{} - messageQueueProvider.On(reflection.GetMethodName(messageQueueProvider.ProvidePublisher), testQueuesConfig.SearchIndexRequestsTopicName).Return(publisher, nil) + publisher := &mockpublishers.PublisherMock{} + messageQueueProvider := &mockpublishers.PublisherProviderMock{ + ProvidePublisherFunc: func(_ context.Context, _ string) (messagequeue.Publisher, error) { + return publisher, nil + }, + } indexFunctions := map[string]Function{ "test_type": func(ctx context.Context) ([]string, error) { @@ -54,7 +64,9 @@ func TestNewIndexScheduler(T *testing.T) { assert.Equal(t, []string{"test_type"}, scheduler.allIndexTypes) assert.Len(t, scheduler.indexFunctions, 1) - mock.AssertExpectationsForObjects(t, metricsProvider, messageQueueProvider) + test.SliceLen(t, 1, metricsProvider.NewInt64CounterCalls()) + test.SliceLen(t, 1, messageQueueProvider.ProvidePublisherCalls()) + test.EqOp(t, testQueuesConfig.SearchIndexRequestsTopicName, messageQueueProvider.ProvidePublisherCalls()[0].Topic) }) T.Run("with nil index functions", func(t *testing.T) { @@ -63,16 +75,25 @@ func TestNewIndexScheduler(T *testing.T) { ctx := t.Context() logger := logging.NewNoopLogger() tracerProvider := tracing.NewNoopTracerProvider() - metricsProvider := &mockmetrics.MetricsProvider{} - messageQueueProvider := &mockpublishers.PublisherProvider{} // Mock metrics provider - int64Counter := &mockmetrics.Int64Counter{} - metricsProvider.On(reflection.GetMethodName(metricsProvider.NewInt64Counter), "indexer.handled_records", []metric.Int64CounterOption(nil)).Return(int64Counter, nil) + int64Counter := &mockmetrics.Int64CounterMock{ + AddFunc: func(_ context.Context, _ int64, _ ...metric.AddOption) {}, + } + metricsProvider := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, "indexer.handled_records", counterName) + return int64Counter, nil + }, + } // Mock message queue provider - publisher := &mockpublishers.Publisher{} - messageQueueProvider.On(reflection.GetMethodName(messageQueueProvider.ProvidePublisher), testQueuesConfig.SearchIndexRequestsTopicName).Return(publisher, nil) + publisher := &mockpublishers.PublisherMock{} + messageQueueProvider := &mockpublishers.PublisherProviderMock{ + ProvidePublisherFunc: func(_ context.Context, _ string) (messagequeue.Publisher, error) { + return publisher, nil + }, + } scheduler, err := NewIndexScheduler(ctx, logger, tracerProvider, metricsProvider, messageQueueProvider, testQueuesConfig, nil) assert.NoError(t, err) @@ -81,7 +102,8 @@ func TestNewIndexScheduler(T *testing.T) { assert.NotNil(t, scheduler.indexFunctions) assert.Len(t, scheduler.indexFunctions, 0) - mock.AssertExpectationsForObjects(t, metricsProvider, messageQueueProvider) + test.SliceLen(t, 1, metricsProvider.NewInt64CounterCalls()) + test.SliceLen(t, 1, messageQueueProvider.ProvidePublisherCalls()) }) T.Run("metrics provider error", func(t *testing.T) { @@ -90,19 +112,23 @@ func TestNewIndexScheduler(T *testing.T) { ctx := t.Context() logger := logging.NewNoopLogger() tracerProvider := tracing.NewNoopTracerProvider() - metricsProvider := &mockmetrics.MetricsProvider{} - messageQueueProvider := &mockpublishers.PublisherProvider{} + messageQueueProvider := &mockpublishers.PublisherProviderMock{} // Mock metrics provider to return error - need to return a valid interface and error - int64Counter := &mockmetrics.Int64Counter{} - metricsProvider.On(reflection.GetMethodName(metricsProvider.NewInt64Counter), "indexer.handled_records", []metric.Int64CounterOption(nil)).Return(int64Counter, errors.New("metrics error")) + metricsProvider := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, "indexer.handled_records", counterName) + return &mockmetrics.Int64CounterMock{}, errors.New("metrics error") + }, + } scheduler, err := NewIndexScheduler(ctx, logger, tracerProvider, metricsProvider, messageQueueProvider, testQueuesConfig, nil) assert.Error(t, err) assert.Nil(t, scheduler) assert.Contains(t, err.Error(), "metrics error") - mock.AssertExpectationsForObjects(t, metricsProvider) + test.SliceLen(t, 1, metricsProvider.NewInt64CounterCalls()) + test.SliceLen(t, 0, messageQueueProvider.ProvidePublisherCalls()) }) T.Run("message queue provider error", func(t *testing.T) { @@ -111,23 +137,32 @@ func TestNewIndexScheduler(T *testing.T) { ctx := t.Context() logger := logging.NewNoopLogger() tracerProvider := tracing.NewNoopTracerProvider() - metricsProvider := &mockmetrics.MetricsProvider{} - messageQueueProvider := &mockpublishers.PublisherProvider{} // Mock metrics provider - int64Counter := &mockmetrics.Int64Counter{} - metricsProvider.On(reflection.GetMethodName(metricsProvider.NewInt64Counter), "indexer.handled_records", []metric.Int64CounterOption(nil)).Return(int64Counter, nil) + int64Counter := &mockmetrics.Int64CounterMock{ + AddFunc: func(_ context.Context, _ int64, _ ...metric.AddOption) {}, + } + metricsProvider := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, "indexer.handled_records", counterName) + return int64Counter, nil + }, + } // Mock message queue provider to return error - need to return a valid interface and error - publisher := &mockpublishers.Publisher{} - messageQueueProvider.On(reflection.GetMethodName(messageQueueProvider.ProvidePublisher), testQueuesConfig.SearchIndexRequestsTopicName).Return(publisher, errors.New("message queue error")) + messageQueueProvider := &mockpublishers.PublisherProviderMock{ + ProvidePublisherFunc: func(_ context.Context, _ string) (messagequeue.Publisher, error) { + return &mockpublishers.PublisherMock{}, errors.New("message queue error") + }, + } scheduler, err := NewIndexScheduler(ctx, logger, tracerProvider, metricsProvider, messageQueueProvider, testQueuesConfig, nil) assert.Error(t, err) assert.Nil(t, scheduler) assert.Contains(t, err.Error(), "message queue error") - mock.AssertExpectationsForObjects(t, metricsProvider, messageQueueProvider) + test.SliceLen(t, 1, metricsProvider.NewInt64CounterCalls()) + test.SliceLen(t, 1, messageQueueProvider.ProvidePublisherCalls()) }) } @@ -140,16 +175,32 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { ctx := t.Context() logger := logging.NewNoopLogger() tracerProvider := tracing.NewNoopTracerProvider() - metricsProvider := &mockmetrics.MetricsProvider{} - messageQueueProvider := &mockpublishers.PublisherProvider{} // Mock metrics provider - int64Counter := &mockmetrics.Int64Counter{} - metricsProvider.On(reflection.GetMethodName(metricsProvider.NewInt64Counter), "indexer.handled_records", []metric.Int64CounterOption(nil)).Return(int64Counter, nil) + int64Counter := &mockmetrics.Int64CounterMock{ + AddFunc: func(_ context.Context, _ int64, _ ...metric.AddOption) {}, + } + metricsProvider := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, "indexer.handled_records", counterName) + return int64Counter, nil + }, + } - // Mock message queue provider - publisher := &mockpublishers.Publisher{} - messageQueueProvider.On(reflection.GetMethodName(messageQueueProvider.ProvidePublisher), testQueuesConfig.SearchIndexRequestsTopicName).Return(publisher, nil) + // Mock message queue provider - all publishes succeed + publisher := &mockpublishers.PublisherMock{ + PublishFunc: func(_ context.Context, data any) error { + req, ok := data.(*textsearch.IndexRequest) + require.True(t, ok) + test.EqOp(t, "test_type", req.IndexType) + return nil + }, + } + messageQueueProvider := &mockpublishers.PublisherProviderMock{ + ProvidePublisherFunc: func(_ context.Context, _ string) (messagequeue.Publisher, error) { + return publisher, nil + }, + } // Mock index function indexFunctions := map[string]Function{ @@ -161,19 +212,16 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { scheduler, err := NewIndexScheduler(ctx, logger, tracerProvider, metricsProvider, messageQueueProvider, testQueuesConfig, indexFunctions) require.NoError(t, err) - // Mock publisher calls - publisher.On(reflection.GetMethodName(publisher.Publish), mock.Anything, &textsearch.IndexRequest{RowID: "id1", IndexType: "test_type"}).Return(nil) - publisher.On(reflection.GetMethodName(publisher.Publish), mock.Anything, &textsearch.IndexRequest{RowID: "id2", IndexType: "test_type"}).Return(nil) - publisher.On(reflection.GetMethodName(publisher.Publish), mock.Anything, &textsearch.IndexRequest{RowID: "id3", IndexType: "test_type"}).Return(nil) - - // Mock metrics counter - int64Counter.On(reflection.GetMethodName(int64Counter.Add), mock.Anything, int64(3), mock.Anything).Return() - // Since we only have one index type, it will always be chosen err = scheduler.IndexTypes(ctx) assert.NoError(t, err) - mock.AssertExpectationsForObjects(t, publisher, int64Counter) + publishedIDs := collectPublishedRowIDs(t, publisher.PublishCalls()) + test.SliceContainsAll(t, publishedIDs, []string{"id1", "id2", "id3"}) + + addCalls := int64Counter.AddCalls() + test.SliceLen(t, 1, addCalls) + test.EqOp(t, int64(3), addCalls[0].Incr) }) T.Run("successful execution with empty results", func(t *testing.T) { @@ -182,16 +230,25 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { ctx := t.Context() logger := logging.NewNoopLogger() tracerProvider := tracing.NewNoopTracerProvider() - metricsProvider := &mockmetrics.MetricsProvider{} - messageQueueProvider := &mockpublishers.PublisherProvider{} // Mock metrics provider - int64Counter := &mockmetrics.Int64Counter{} - metricsProvider.On(reflection.GetMethodName(metricsProvider.NewInt64Counter), "indexer.handled_records", []metric.Int64CounterOption(nil)).Return(int64Counter, nil) + int64Counter := &mockmetrics.Int64CounterMock{ + AddFunc: func(_ context.Context, _ int64, _ ...metric.AddOption) {}, + } + metricsProvider := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, "indexer.handled_records", counterName) + return int64Counter, nil + }, + } - // Mock message queue provider - publisher := &mockpublishers.Publisher{} - messageQueueProvider.On(reflection.GetMethodName(messageQueueProvider.ProvidePublisher), testQueuesConfig.SearchIndexRequestsTopicName).Return(publisher, nil) + // Mock message queue provider - no Publish calls expected + publisher := &mockpublishers.PublisherMock{} + messageQueueProvider := &mockpublishers.PublisherProviderMock{ + ProvidePublisherFunc: func(_ context.Context, _ string) (messagequeue.Publisher, error) { + return publisher, nil + }, + } // Mock index function that returns empty results indexFunctions := map[string]Function{ @@ -205,12 +262,14 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { // No publisher calls expected for empty results // But metrics counter is still called with 0 - int64Counter.On(reflection.GetMethodName(int64Counter.Add), mock.Anything, int64(0), mock.Anything).Return() - err = scheduler.IndexTypes(ctx) assert.NoError(t, err) - mock.AssertExpectationsForObjects(t, publisher, int64Counter) + test.SliceLen(t, 0, publisher.PublishCalls()) + + addCalls := int64Counter.AddCalls() + test.SliceLen(t, 1, addCalls) + test.EqOp(t, int64(0), addCalls[0].Incr) }) T.Run("function returns sql.ErrNoRows", func(t *testing.T) { @@ -219,16 +278,25 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { ctx := t.Context() logger := logging.NewNoopLogger() tracerProvider := tracing.NewNoopTracerProvider() - metricsProvider := &mockmetrics.MetricsProvider{} - messageQueueProvider := &mockpublishers.PublisherProvider{} // Mock metrics provider - int64Counter := &mockmetrics.Int64Counter{} - metricsProvider.On(reflection.GetMethodName(metricsProvider.NewInt64Counter), "indexer.handled_records", []metric.Int64CounterOption(nil)).Return(int64Counter, nil) + int64Counter := &mockmetrics.Int64CounterMock{ + AddFunc: func(_ context.Context, _ int64, _ ...metric.AddOption) {}, + } + metricsProvider := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, "indexer.handled_records", counterName) + return int64Counter, nil + }, + } - // Mock message queue provider - publisher := &mockpublishers.Publisher{} - messageQueueProvider.On(reflection.GetMethodName(messageQueueProvider.ProvidePublisher), testQueuesConfig.SearchIndexRequestsTopicName).Return(publisher, nil) + // Mock message queue provider - no Publish calls expected + publisher := &mockpublishers.PublisherMock{} + messageQueueProvider := &mockpublishers.PublisherProviderMock{ + ProvidePublisherFunc: func(_ context.Context, _ string) (messagequeue.Publisher, error) { + return publisher, nil + }, + } // Mock index function that returns sql.ErrNoRows indexFunctions := map[string]Function{ @@ -244,7 +312,7 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { err = scheduler.IndexTypes(ctx) assert.NoError(t, err) - mock.AssertExpectationsForObjects(t, publisher, int64Counter) + test.SliceLen(t, 0, publisher.PublishCalls()) }) T.Run("function returns other error", func(t *testing.T) { @@ -253,16 +321,25 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { ctx := t.Context() logger := logging.NewNoopLogger() tracerProvider := tracing.NewNoopTracerProvider() - metricsProvider := &mockmetrics.MetricsProvider{} - messageQueueProvider := &mockpublishers.PublisherProvider{} // Mock metrics provider - int64Counter := &mockmetrics.Int64Counter{} - metricsProvider.On(reflection.GetMethodName(metricsProvider.NewInt64Counter), "indexer.handled_records", []metric.Int64CounterOption(nil)).Return(int64Counter, nil) + int64Counter := &mockmetrics.Int64CounterMock{ + AddFunc: func(_ context.Context, _ int64, _ ...metric.AddOption) {}, + } + metricsProvider := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, "indexer.handled_records", counterName) + return int64Counter, nil + }, + } - // Mock message queue provider - publisher := &mockpublishers.Publisher{} - messageQueueProvider.On(reflection.GetMethodName(messageQueueProvider.ProvidePublisher), testQueuesConfig.SearchIndexRequestsTopicName).Return(publisher, nil) + // Mock message queue provider - no Publish calls expected + publisher := &mockpublishers.PublisherMock{} + messageQueueProvider := &mockpublishers.PublisherProviderMock{ + ProvidePublisherFunc: func(_ context.Context, _ string) (messagequeue.Publisher, error) { + return publisher, nil + }, + } // Mock index function that returns an error indexFunctions := map[string]Function{ @@ -278,7 +355,7 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "database connection failed") - mock.AssertExpectationsForObjects(t, publisher, int64Counter) + test.SliceLen(t, 0, publisher.PublishCalls()) }) T.Run("unknown index type", func(t *testing.T) { @@ -287,16 +364,25 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { ctx := t.Context() logger := logging.NewNoopLogger() tracerProvider := tracing.NewNoopTracerProvider() - metricsProvider := &mockmetrics.MetricsProvider{} - messageQueueProvider := &mockpublishers.PublisherProvider{} // Mock metrics provider - int64Counter := &mockmetrics.Int64Counter{} - metricsProvider.On(reflection.GetMethodName(metricsProvider.NewInt64Counter), "indexer.handled_records", []metric.Int64CounterOption(nil)).Return(int64Counter, nil) + int64Counter := &mockmetrics.Int64CounterMock{ + AddFunc: func(_ context.Context, _ int64, _ ...metric.AddOption) {}, + } + metricsProvider := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, "indexer.handled_records", counterName) + return int64Counter, nil + }, + } - // Mock message queue provider - publisher := &mockpublishers.Publisher{} - messageQueueProvider.On(reflection.GetMethodName(messageQueueProvider.ProvidePublisher), testQueuesConfig.SearchIndexRequestsTopicName).Return(publisher, nil) + // Mock message queue provider - no Publish calls expected + publisher := &mockpublishers.PublisherMock{} + messageQueueProvider := &mockpublishers.PublisherProviderMock{ + ProvidePublisherFunc: func(_ context.Context, _ string) (messagequeue.Publisher, error) { + return publisher, nil + }, + } // Create scheduler with empty index functions scheduler, err := NewIndexScheduler(ctx, logger, tracerProvider, metricsProvider, messageQueueProvider, testQueuesConfig, map[string]Function{}) @@ -310,7 +396,7 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "unknown index type unknown_type") - mock.AssertExpectationsForObjects(t, publisher, int64Counter) + test.SliceLen(t, 0, publisher.PublishCalls()) }) T.Run("partial publish failures", func(t *testing.T) { @@ -319,16 +405,36 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { ctx := t.Context() logger := logging.NewNoopLogger() tracerProvider := tracing.NewNoopTracerProvider() - metricsProvider := &mockmetrics.MetricsProvider{} - messageQueueProvider := &mockpublishers.PublisherProvider{} // Mock metrics provider - int64Counter := &mockmetrics.Int64Counter{} - metricsProvider.On(reflection.GetMethodName(metricsProvider.NewInt64Counter), "indexer.handled_records", []metric.Int64CounterOption(nil)).Return(int64Counter, nil) + int64Counter := &mockmetrics.Int64CounterMock{ + AddFunc: func(_ context.Context, _ int64, _ ...metric.AddOption) {}, + } + metricsProvider := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, "indexer.handled_records", counterName) + return int64Counter, nil + }, + } - // Mock message queue provider - publisher := &mockpublishers.Publisher{} - messageQueueProvider.On(reflection.GetMethodName(messageQueueProvider.ProvidePublisher), testQueuesConfig.SearchIndexRequestsTopicName).Return(publisher, nil) + // Mock message queue provider - id2 fails, id1 and id3 succeed + publishResults := map[string]error{ + "id1": nil, + "id2": errors.New("publish failed"), + "id3": nil, + } + publisher := &mockpublishers.PublisherMock{ + PublishFunc: func(_ context.Context, data any) error { + req, ok := data.(*textsearch.IndexRequest) + require.True(t, ok) + return publishResults[req.RowID] + }, + } + messageQueueProvider := &mockpublishers.PublisherProviderMock{ + ProvidePublisherFunc: func(_ context.Context, _ string) (messagequeue.Publisher, error) { + return publisher, nil + }, + } // Mock index function indexFunctions := map[string]Function{ @@ -340,18 +446,16 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { scheduler, err := NewIndexScheduler(ctx, logger, tracerProvider, metricsProvider, messageQueueProvider, testQueuesConfig, indexFunctions) require.NoError(t, err) - // Mock publisher calls - some succeed, some fail - publisher.On(reflection.GetMethodName(publisher.Publish), mock.Anything, &textsearch.IndexRequest{RowID: "id1", IndexType: "test_type"}).Return(nil) - publisher.On(reflection.GetMethodName(publisher.Publish), mock.Anything, &textsearch.IndexRequest{RowID: "id2", IndexType: "test_type"}).Return(errors.New("publish failed")) - publisher.On(reflection.GetMethodName(publisher.Publish), mock.Anything, &textsearch.IndexRequest{RowID: "id3", IndexType: "test_type"}).Return(nil) - - // Mock metrics counter - should only count successful publishes - int64Counter.On(reflection.GetMethodName(int64Counter.Add), mock.Anything, int64(2), mock.Anything).Return() - err = scheduler.IndexTypes(ctx) assert.NoError(t, err) // Partial failures don't cause the method to return an error - mock.AssertExpectationsForObjects(t, publisher, int64Counter) + publishedIDs := collectPublishedRowIDs(t, publisher.PublishCalls()) + test.SliceContainsAll(t, publishedIDs, []string{"id1", "id2", "id3"}) + + // Metrics counter should only count successful publishes + addCalls := int64Counter.AddCalls() + test.SliceLen(t, 1, addCalls) + test.EqOp(t, int64(2), addCalls[0].Incr) }) T.Run("all publish failures", func(t *testing.T) { @@ -360,16 +464,29 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { ctx := t.Context() logger := logging.NewNoopLogger() tracerProvider := tracing.NewNoopTracerProvider() - metricsProvider := &mockmetrics.MetricsProvider{} - messageQueueProvider := &mockpublishers.PublisherProvider{} // Mock metrics provider - int64Counter := &mockmetrics.Int64Counter{} - metricsProvider.On(reflection.GetMethodName(metricsProvider.NewInt64Counter), "indexer.handled_records", []metric.Int64CounterOption(nil)).Return(int64Counter, nil) + int64Counter := &mockmetrics.Int64CounterMock{ + AddFunc: func(_ context.Context, _ int64, _ ...metric.AddOption) {}, + } + metricsProvider := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, "indexer.handled_records", counterName) + return int64Counter, nil + }, + } - // Mock message queue provider - publisher := &mockpublishers.Publisher{} - messageQueueProvider.On(reflection.GetMethodName(messageQueueProvider.ProvidePublisher), testQueuesConfig.SearchIndexRequestsTopicName).Return(publisher, nil) + // Mock message queue provider - all publishes fail + publisher := &mockpublishers.PublisherMock{ + PublishFunc: func(_ context.Context, _ any) error { + return errors.New("publish failed") + }, + } + messageQueueProvider := &mockpublishers.PublisherProviderMock{ + ProvidePublisherFunc: func(_ context.Context, _ string) (messagequeue.Publisher, error) { + return publisher, nil + }, + } // Mock index function indexFunctions := map[string]Function{ @@ -381,16 +498,29 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { scheduler, err := NewIndexScheduler(ctx, logger, tracerProvider, metricsProvider, messageQueueProvider, testQueuesConfig, indexFunctions) require.NoError(t, err) - // Mock publisher calls - all fail - publisher.On(reflection.GetMethodName(publisher.Publish), mock.Anything, &textsearch.IndexRequest{RowID: "id1", IndexType: "test_type"}).Return(errors.New("publish failed")) - publisher.On(reflection.GetMethodName(publisher.Publish), mock.Anything, &textsearch.IndexRequest{RowID: "id2", IndexType: "test_type"}).Return(errors.New("publish failed")) - - // Mock metrics counter - should count 0 successful publishes - int64Counter.On(reflection.GetMethodName(int64Counter.Add), mock.Anything, int64(0), mock.Anything).Return() - err = scheduler.IndexTypes(ctx) assert.NoError(t, err) // Even all failures don't cause the method to return an error - mock.AssertExpectationsForObjects(t, publisher, int64Counter) + test.SliceLen(t, 2, publisher.PublishCalls()) + + // Metrics counter should count 0 successful publishes + addCalls := int64Counter.AddCalls() + test.SliceLen(t, 1, addCalls) + test.EqOp(t, int64(0), addCalls[0].Incr) }) } + +func collectPublishedRowIDs(t *testing.T, calls []struct { + Ctx context.Context + Data any +}, +) []string { + t.Helper() + ids := make([]string, 0, len(calls)) + for i := range calls { + req, ok := calls[i].Data.(*textsearch.IndexRequest) + require.True(t, ok) + ids = append(ids, req.RowID) + } + return ids +} diff --git a/search/text/mock/doc.go b/search/text/mock/doc.go index a79d82f..5aeb912 100644 --- a/search/text/mock/doc.go +++ b/search/text/mock/doc.go @@ -1,4 +1,8 @@ /* -Package mocksearch provides an interface-compatible search index mock +Package mocksearch provides moq-generated mocks for the search/text package. */ package mocksearch + +// Regenerate the moq mocks via `go generate ./search/text/mock/`. + +//go:generate go tool github.com/matryer/moq -out index_mock.go -pkg mocksearch -rm -fmt goimports .. Index:IndexMock diff --git a/search/text/mock/index_manager.go b/search/text/mock/index_manager.go deleted file mode 100644 index 0b0fd96..0000000 --- a/search/text/mock/index_manager.go +++ /dev/null @@ -1,38 +0,0 @@ -package mocksearch - -import ( - "context" - - textsearch "github.com/verygoodsoftwarenotvirus/platform/v5/search/text" - - "github.com/stretchr/testify/mock" -) - -var ( - _ textsearch.Index[any] = (*IndexManager[any])(nil) -) - -// IndexManager is a mock IndexManager. -type IndexManager[T any] struct { - mock.Mock -} - -func (m *IndexManager[T]) Wipe(ctx context.Context) error { - return m.Called(ctx).Error(0) -} - -// Index implements our interface. -func (m *IndexManager[T]) Index(ctx context.Context, id string, value any) error { - return m.Called(ctx, id, value).Error(0) -} - -// Search implements our interface. -func (m *IndexManager[T]) Search(ctx context.Context, query string) (results []*T, err error) { - args := m.Called(ctx, query) - return args.Get(0).([]*T), args.Error(1) -} - -// Delete implements our interface. -func (m *IndexManager[T]) Delete(ctx context.Context, id string) error { - return m.Called(ctx, id).Error(0) -} diff --git a/search/text/mock/index_manager_test.go b/search/text/mock/index_manager_test.go deleted file mode 100644 index 3a4742a..0000000 --- a/search/text/mock/index_manager_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package mocksearch - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestIndexManager_Index(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := &IndexManager[string]{} - m.On("Index", mock.Anything, "id", "value").Return(nil) - - err := m.Index(context.Background(), "id", "value") - assert.NoError(t, err) - - mock.AssertExpectationsForObjects(t, m) - }) -} - -func TestIndexManager_Search(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - expected := []*string{new(string), new(string)} - *expected[0] = "result1" - *expected[1] = "result2" - - m := &IndexManager[string]{} - m.On("Search", mock.Anything, "query").Return(expected, nil) - - results, err := m.Search(context.Background(), "query") - assert.NoError(t, err) - assert.Equal(t, expected, results) - - mock.AssertExpectationsForObjects(t, m) - }) -} - -func TestIndexManager_Delete(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := &IndexManager[string]{} - m.On("Delete", mock.Anything, "id").Return(nil) - - err := m.Delete(context.Background(), "id") - assert.NoError(t, err) - - mock.AssertExpectationsForObjects(t, m) - }) -} - -func TestIndexManager_Wipe(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - m := &IndexManager[string]{} - m.On("Wipe", mock.Anything).Return(nil) - - err := m.Wipe(context.Background()) - assert.NoError(t, err) - - mock.AssertExpectationsForObjects(t, m) - }) -} diff --git a/search/text/mock/index_mock.go b/search/text/mock/index_mock.go new file mode 100644 index 0000000..31fa1b3 --- /dev/null +++ b/search/text/mock/index_mock.go @@ -0,0 +1,233 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package mocksearch + +import ( + "context" + "sync" + + textsearch "github.com/verygoodsoftwarenotvirus/platform/v5/search/text" +) + +// Ensure, that IndexMock does implement textsearch.Index. +// If this is not the case, regenerate this file with moq. +var _ textsearch.Index[any] = &IndexMock[any]{} + +// IndexMock is a mock implementation of textsearch.Index. +// +// func TestSomethingThatUsesIndex(t *testing.T) { +// +// // make and configure a mocked textsearch.Index +// mockedIndex := &IndexMock{ +// DeleteFunc: func(ctx context.Context, id string) error { +// panic("mock out the Delete method") +// }, +// IndexFunc: func(ctx context.Context, id string, value any) error { +// panic("mock out the Index method") +// }, +// SearchFunc: func(ctx context.Context, query string) ([]*T, error) { +// panic("mock out the Search method") +// }, +// WipeFunc: func(ctx context.Context) error { +// panic("mock out the Wipe method") +// }, +// } +// +// // use mockedIndex in code that requires textsearch.Index +// // and then make assertions. +// +// } +type IndexMock[T any] struct { + // DeleteFunc mocks the Delete method. + DeleteFunc func(ctx context.Context, id string) error + + // IndexFunc mocks the Index method. + IndexFunc func(ctx context.Context, id string, value any) error + + // SearchFunc mocks the Search method. + SearchFunc func(ctx context.Context, query string) ([]*T, error) + + // WipeFunc mocks the Wipe method. + WipeFunc func(ctx context.Context) error + + // calls tracks calls to the methods. + calls struct { + // Delete holds details about calls to the Delete method. + Delete []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // ID is the id argument value. + ID string + } + // Index holds details about calls to the Index method. + Index []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // ID is the id argument value. + ID string + // Value is the value argument value. + Value any + } + // Search holds details about calls to the Search method. + Search []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Query is the query argument value. + Query string + } + // Wipe holds details about calls to the Wipe method. + Wipe []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + } + lockDelete sync.RWMutex + lockIndex sync.RWMutex + lockSearch sync.RWMutex + lockWipe sync.RWMutex +} + +// Delete calls DeleteFunc. +func (mock *IndexMock[T]) Delete(ctx context.Context, id string) error { + if mock.DeleteFunc == nil { + panic("IndexMock.DeleteFunc: method is nil but Index.Delete was just called") + } + callInfo := struct { + Ctx context.Context + ID string + }{ + Ctx: ctx, + ID: id, + } + mock.lockDelete.Lock() + mock.calls.Delete = append(mock.calls.Delete, callInfo) + mock.lockDelete.Unlock() + return mock.DeleteFunc(ctx, id) +} + +// DeleteCalls gets all the calls that were made to Delete. +// Check the length with: +// +// len(mockedIndex.DeleteCalls()) +func (mock *IndexMock[T]) DeleteCalls() []struct { + Ctx context.Context + ID string +} { + var calls []struct { + Ctx context.Context + ID string + } + mock.lockDelete.RLock() + calls = mock.calls.Delete + mock.lockDelete.RUnlock() + return calls +} + +// Index calls IndexFunc. +func (mock *IndexMock[T]) Index(ctx context.Context, id string, value any) error { + if mock.IndexFunc == nil { + panic("IndexMock.IndexFunc: method is nil but Index.Index was just called") + } + callInfo := struct { + Ctx context.Context + ID string + Value any + }{ + Ctx: ctx, + ID: id, + Value: value, + } + mock.lockIndex.Lock() + mock.calls.Index = append(mock.calls.Index, callInfo) + mock.lockIndex.Unlock() + return mock.IndexFunc(ctx, id, value) +} + +// IndexCalls gets all the calls that were made to Index. +// Check the length with: +// +// len(mockedIndex.IndexCalls()) +func (mock *IndexMock[T]) IndexCalls() []struct { + Ctx context.Context + ID string + Value any +} { + var calls []struct { + Ctx context.Context + ID string + Value any + } + mock.lockIndex.RLock() + calls = mock.calls.Index + mock.lockIndex.RUnlock() + return calls +} + +// Search calls SearchFunc. +func (mock *IndexMock[T]) Search(ctx context.Context, query string) ([]*T, error) { + if mock.SearchFunc == nil { + panic("IndexMock.SearchFunc: method is nil but Index.Search was just called") + } + callInfo := struct { + Ctx context.Context + Query string + }{ + Ctx: ctx, + Query: query, + } + mock.lockSearch.Lock() + mock.calls.Search = append(mock.calls.Search, callInfo) + mock.lockSearch.Unlock() + return mock.SearchFunc(ctx, query) +} + +// SearchCalls gets all the calls that were made to Search. +// Check the length with: +// +// len(mockedIndex.SearchCalls()) +func (mock *IndexMock[T]) SearchCalls() []struct { + Ctx context.Context + Query string +} { + var calls []struct { + Ctx context.Context + Query string + } + mock.lockSearch.RLock() + calls = mock.calls.Search + mock.lockSearch.RUnlock() + return calls +} + +// Wipe calls WipeFunc. +func (mock *IndexMock[T]) Wipe(ctx context.Context) error { + if mock.WipeFunc == nil { + panic("IndexMock.WipeFunc: method is nil but Index.Wipe was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockWipe.Lock() + mock.calls.Wipe = append(mock.calls.Wipe, callInfo) + mock.lockWipe.Unlock() + return mock.WipeFunc(ctx) +} + +// WipeCalls gets all the calls that were made to Wipe. +// Check the length with: +// +// len(mockedIndex.WipeCalls()) +func (mock *IndexMock[T]) WipeCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockWipe.RLock() + calls = mock.calls.Wipe + mock.lockWipe.RUnlock() + return calls +} diff --git a/search/vector/config/config_test.go b/search/vector/config/config_test.go index 89c0545..88205ea 100644 --- a/search/vector/config/config_test.go +++ b/search/vector/config/config_test.go @@ -18,6 +18,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/search/vector/pgvector" "github.com/verygoodsoftwarenotvirus/platform/v5/search/vector/qdrant" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" @@ -234,9 +235,12 @@ func TestConfig_ProvideIndex(T *testing.T) { }, } - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", "test-breaker_circuit_breaker_tripped", []metric.Int64CounterOption(nil)). - Return(&mockmetrics.Int64Counter{}, errors.New("counter init failure")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, "test-breaker_circuit_breaker_tripped", counterName) + return &mockmetrics.Int64CounterMock{}, errors.New("counter init failure") + }, + } idx, err := ProvideIndex[testStruct]( ctx, @@ -250,6 +254,7 @@ func TestConfig_ProvideIndex(T *testing.T) { assert.Error(t, err) assert.Nil(t, idx) assert.Contains(t, err.Error(), "circuit breaker") - mp.AssertExpectations(t) + + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) } diff --git a/search/vector/mock/doc.go b/search/vector/mock/doc.go new file mode 100644 index 0000000..46710b9 --- /dev/null +++ b/search/vector/mock/doc.go @@ -0,0 +1,6 @@ +// Package mock provides moq-generated mocks for the search/vector package. +package mock + +// Regenerate the moq mocks via `go generate ./search/vector/mock/`. + +//go:generate go tool github.com/matryer/moq -out index_mock.go -pkg mock -rm -fmt goimports .. Index:IndexMock diff --git a/search/vector/mock/index.go b/search/vector/mock/index.go deleted file mode 100644 index f590bd3..0000000 --- a/search/vector/mock/index.go +++ /dev/null @@ -1,40 +0,0 @@ -package mock - -import ( - "context" - - vectorsearch "github.com/verygoodsoftwarenotvirus/platform/v5/search/vector" - - "github.com/stretchr/testify/mock" -) - -var _ vectorsearch.Index[any] = (*Index[any])(nil) - -// Index is a testify-backed mock of vectorsearch.Index. -type Index[T any] struct { - mock.Mock -} - -// Upsert implements the vectorsearch.Index interface. -func (m *Index[T]) Upsert(ctx context.Context, vectors ...vectorsearch.Vector[T]) error { - return m.Called(ctx, vectors).Error(0) -} - -// Delete implements the vectorsearch.Index interface. -func (m *Index[T]) Delete(ctx context.Context, ids ...string) error { - return m.Called(ctx, ids).Error(0) -} - -// Wipe implements the vectorsearch.Index interface. -func (m *Index[T]) Wipe(ctx context.Context) error { - return m.Called(ctx).Error(0) -} - -// Query implements the vectorsearch.Index interface. -func (m *Index[T]) Query(ctx context.Context, req vectorsearch.QueryRequest) ([]vectorsearch.QueryResult[T], error) { - args := m.Called(ctx, req) - if v := args.Get(0); v != nil { - return v.([]vectorsearch.QueryResult[T]), args.Error(1) - } - return nil, args.Error(1) -} diff --git a/search/vector/mock/index_mock.go b/search/vector/mock/index_mock.go new file mode 100644 index 0000000..3c2abcb --- /dev/null +++ b/search/vector/mock/index_mock.go @@ -0,0 +1,227 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package mock + +import ( + "context" + "sync" + + vectorsearch "github.com/verygoodsoftwarenotvirus/platform/v5/search/vector" +) + +// Ensure, that IndexMock does implement vectorsearch.Index. +// If this is not the case, regenerate this file with moq. +var _ vectorsearch.Index[any] = &IndexMock[any]{} + +// IndexMock is a mock implementation of vectorsearch.Index. +// +// func TestSomethingThatUsesIndex(t *testing.T) { +// +// // make and configure a mocked vectorsearch.Index +// mockedIndex := &IndexMock{ +// DeleteFunc: func(ctx context.Context, ids ...string) error { +// panic("mock out the Delete method") +// }, +// QueryFunc: func(ctx context.Context, req vectorsearch.QueryRequest) ([]vectorsearch.QueryResult[T], error) { +// panic("mock out the Query method") +// }, +// UpsertFunc: func(ctx context.Context, vectors ...vectorsearch.Vector[T]) error { +// panic("mock out the Upsert method") +// }, +// WipeFunc: func(ctx context.Context) error { +// panic("mock out the Wipe method") +// }, +// } +// +// // use mockedIndex in code that requires vectorsearch.Index +// // and then make assertions. +// +// } +type IndexMock[T any] struct { + // DeleteFunc mocks the Delete method. + DeleteFunc func(ctx context.Context, ids ...string) error + + // QueryFunc mocks the Query method. + QueryFunc func(ctx context.Context, req vectorsearch.QueryRequest) ([]vectorsearch.QueryResult[T], error) + + // UpsertFunc mocks the Upsert method. + UpsertFunc func(ctx context.Context, vectors ...vectorsearch.Vector[T]) error + + // WipeFunc mocks the Wipe method. + WipeFunc func(ctx context.Context) error + + // calls tracks calls to the methods. + calls struct { + // Delete holds details about calls to the Delete method. + Delete []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Ids is the ids argument value. + Ids []string + } + // Query holds details about calls to the Query method. + Query []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Req is the req argument value. + Req vectorsearch.QueryRequest + } + // Upsert holds details about calls to the Upsert method. + Upsert []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Vectors is the vectors argument value. + Vectors []vectorsearch.Vector[T] + } + // Wipe holds details about calls to the Wipe method. + Wipe []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } + } + lockDelete sync.RWMutex + lockQuery sync.RWMutex + lockUpsert sync.RWMutex + lockWipe sync.RWMutex +} + +// Delete calls DeleteFunc. +func (mock *IndexMock[T]) Delete(ctx context.Context, ids ...string) error { + if mock.DeleteFunc == nil { + panic("IndexMock.DeleteFunc: method is nil but Index.Delete was just called") + } + callInfo := struct { + Ctx context.Context + Ids []string + }{ + Ctx: ctx, + Ids: ids, + } + mock.lockDelete.Lock() + mock.calls.Delete = append(mock.calls.Delete, callInfo) + mock.lockDelete.Unlock() + return mock.DeleteFunc(ctx, ids...) +} + +// DeleteCalls gets all the calls that were made to Delete. +// Check the length with: +// +// len(mockedIndex.DeleteCalls()) +func (mock *IndexMock[T]) DeleteCalls() []struct { + Ctx context.Context + Ids []string +} { + var calls []struct { + Ctx context.Context + Ids []string + } + mock.lockDelete.RLock() + calls = mock.calls.Delete + mock.lockDelete.RUnlock() + return calls +} + +// Query calls QueryFunc. +func (mock *IndexMock[T]) Query(ctx context.Context, req vectorsearch.QueryRequest) ([]vectorsearch.QueryResult[T], error) { + if mock.QueryFunc == nil { + panic("IndexMock.QueryFunc: method is nil but Index.Query was just called") + } + callInfo := struct { + Ctx context.Context + Req vectorsearch.QueryRequest + }{ + Ctx: ctx, + Req: req, + } + mock.lockQuery.Lock() + mock.calls.Query = append(mock.calls.Query, callInfo) + mock.lockQuery.Unlock() + return mock.QueryFunc(ctx, req) +} + +// QueryCalls gets all the calls that were made to Query. +// Check the length with: +// +// len(mockedIndex.QueryCalls()) +func (mock *IndexMock[T]) QueryCalls() []struct { + Ctx context.Context + Req vectorsearch.QueryRequest +} { + var calls []struct { + Ctx context.Context + Req vectorsearch.QueryRequest + } + mock.lockQuery.RLock() + calls = mock.calls.Query + mock.lockQuery.RUnlock() + return calls +} + +// Upsert calls UpsertFunc. +func (mock *IndexMock[T]) Upsert(ctx context.Context, vectors ...vectorsearch.Vector[T]) error { + if mock.UpsertFunc == nil { + panic("IndexMock.UpsertFunc: method is nil but Index.Upsert was just called") + } + callInfo := struct { + Ctx context.Context + Vectors []vectorsearch.Vector[T] + }{ + Ctx: ctx, + Vectors: vectors, + } + mock.lockUpsert.Lock() + mock.calls.Upsert = append(mock.calls.Upsert, callInfo) + mock.lockUpsert.Unlock() + return mock.UpsertFunc(ctx, vectors...) +} + +// UpsertCalls gets all the calls that were made to Upsert. +// Check the length with: +// +// len(mockedIndex.UpsertCalls()) +func (mock *IndexMock[T]) UpsertCalls() []struct { + Ctx context.Context + Vectors []vectorsearch.Vector[T] +} { + var calls []struct { + Ctx context.Context + Vectors []vectorsearch.Vector[T] + } + mock.lockUpsert.RLock() + calls = mock.calls.Upsert + mock.lockUpsert.RUnlock() + return calls +} + +// Wipe calls WipeFunc. +func (mock *IndexMock[T]) Wipe(ctx context.Context) error { + if mock.WipeFunc == nil { + panic("IndexMock.WipeFunc: method is nil but Index.Wipe was just called") + } + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockWipe.Lock() + mock.calls.Wipe = append(mock.calls.Wipe, callInfo) + mock.lockWipe.Unlock() + return mock.WipeFunc(ctx) +} + +// WipeCalls gets all the calls that were made to Wipe. +// Check the length with: +// +// len(mockedIndex.WipeCalls()) +func (mock *IndexMock[T]) WipeCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockWipe.RLock() + calls = mock.calls.Wipe + mock.lockWipe.RUnlock() + return calls +} diff --git a/search/vector/pgvector/pgvector_test.go b/search/vector/pgvector/pgvector_test.go index 017d7fa..244e67f 100644 --- a/search/vector/pgvector/pgvector_test.go +++ b/search/vector/pgvector/pgvector_test.go @@ -13,18 +13,41 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/database" "github.com/verygoodsoftwarenotvirus/platform/v5/identifiers" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" + mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" vectorsearch "github.com/verygoodsoftwarenotvirus/platform/v5/search/vector" _ "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/testcontainers/testcontainers-go" postgrescontainer "github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/wait" + "go.opentelemetry.io/otel/metric" metricnoop "go.opentelemetry.io/otel/metric/noop" ) +// counterResult bundles the values a mocked NewInt64Counter call returns. +type counterResult struct { + counter metrics.Int64Counter + err error +} + +// newCounterProviderMock returns a metrics.Provider mock whose NewInt64Counter +// implementation looks up the result keyed on the counter name. Unknown names +// fail the test. +func newCounterProviderMock(t *testing.T, results map[string]counterResult) *mockmetrics.ProviderMock { + t.Helper() + return &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(name string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + res, ok := results[name] + if !ok { + t.Fatalf("unexpected NewInt64Counter call: %q", name) + } + return res.counter, res.err + }, + } +} + const pgvectorImage = "pgvector/pgvector:pg16" var runningContainerTests = strings.ToLower(os.Getenv("RUN_CONTAINER_TESTS")) == "true" @@ -78,12 +101,12 @@ type doc struct { Title string `json:"title"` } -func provideTestIndex(t *testing.T, client database.Client, indexName string, dim int, metric vectorsearch.DistanceMetric) vectorsearch.Index[doc] { +func provideTestIndex(t *testing.T, client database.Client, indexName string, dim int, distanceMetric vectorsearch.DistanceMetric) vectorsearch.Index[doc] { t.Helper() cfg := &Config{ Dimension: dim, - Metric: metric, + Metric: distanceMetric, } im, err := ProvideIndex[doc](t.Context(), nil, nil, nil, cfg, client, indexName, cbnoop.NewCircuitBreaker()) require.NoError(t, err) @@ -139,88 +162,82 @@ func TestProvideIndex(T *testing.T) { T.Run("error creating upsert counter", func(t *testing.T) { t.Parallel() - mp := &metrics.MockProvider{} - mp.On("NewInt64Counter", "pgvector_index_upserts", mock.Anything).Return(metricnoop.Int64Counter{}, errors.New("forced error")) + mp := newCounterProviderMock(t, map[string]counterResult{ + "pgvector_index_upserts": {counter: metricnoop.Int64Counter{}, err: errors.New("forced error")}, + }) _, err := ProvideIndex[doc](t.Context(), nil, nil, mp, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) require.Error(t, err) - - mock.AssertExpectationsForObjects(t, mp) }) T.Run("error creating delete counter", func(t *testing.T) { t.Parallel() - mp := &metrics.MockProvider{} - mp.On("NewInt64Counter", "pgvector_index_upserts", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "pgvector_index_deletes", mock.Anything).Return(metricnoop.Int64Counter{}, errors.New("forced error")) + mp := newCounterProviderMock(t, map[string]counterResult{ + "pgvector_index_upserts": {counter: metricnoop.Int64Counter{}}, + "pgvector_index_deletes": {counter: metricnoop.Int64Counter{}, err: errors.New("forced error")}, + }) _, err := ProvideIndex[doc](t.Context(), nil, nil, mp, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) require.Error(t, err) - - mock.AssertExpectationsForObjects(t, mp) }) T.Run("error creating wipe counter", func(t *testing.T) { t.Parallel() - mp := &metrics.MockProvider{} - mp.On("NewInt64Counter", "pgvector_index_upserts", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "pgvector_index_deletes", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "pgvector_index_wipes", mock.Anything).Return(metricnoop.Int64Counter{}, errors.New("forced error")) + mp := newCounterProviderMock(t, map[string]counterResult{ + "pgvector_index_upserts": {counter: metricnoop.Int64Counter{}}, + "pgvector_index_deletes": {counter: metricnoop.Int64Counter{}}, + "pgvector_index_wipes": {counter: metricnoop.Int64Counter{}, err: errors.New("forced error")}, + }) _, err := ProvideIndex[doc](t.Context(), nil, nil, mp, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) require.Error(t, err) - - mock.AssertExpectationsForObjects(t, mp) }) T.Run("error creating query counter", func(t *testing.T) { t.Parallel() - mp := &metrics.MockProvider{} - mp.On("NewInt64Counter", "pgvector_index_upserts", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "pgvector_index_deletes", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "pgvector_index_wipes", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "pgvector_index_queries", mock.Anything).Return(metricnoop.Int64Counter{}, errors.New("forced error")) + mp := newCounterProviderMock(t, map[string]counterResult{ + "pgvector_index_upserts": {counter: metricnoop.Int64Counter{}}, + "pgvector_index_deletes": {counter: metricnoop.Int64Counter{}}, + "pgvector_index_wipes": {counter: metricnoop.Int64Counter{}}, + "pgvector_index_queries": {counter: metricnoop.Int64Counter{}, err: errors.New("forced error")}, + }) _, err := ProvideIndex[doc](t.Context(), nil, nil, mp, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) require.Error(t, err) - - mock.AssertExpectationsForObjects(t, mp) }) T.Run("error creating error counter", func(t *testing.T) { t.Parallel() - mp := &metrics.MockProvider{} - mp.On("NewInt64Counter", "pgvector_index_upserts", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "pgvector_index_deletes", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "pgvector_index_wipes", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "pgvector_index_queries", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "pgvector_index_errors", mock.Anything).Return(metricnoop.Int64Counter{}, errors.New("forced error")) + mp := newCounterProviderMock(t, map[string]counterResult{ + "pgvector_index_upserts": {counter: metricnoop.Int64Counter{}}, + "pgvector_index_deletes": {counter: metricnoop.Int64Counter{}}, + "pgvector_index_wipes": {counter: metricnoop.Int64Counter{}}, + "pgvector_index_queries": {counter: metricnoop.Int64Counter{}}, + "pgvector_index_errors": {counter: metricnoop.Int64Counter{}, err: errors.New("forced error")}, + }) _, err := ProvideIndex[doc](t.Context(), nil, nil, mp, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) require.Error(t, err) - - mock.AssertExpectationsForObjects(t, mp) }) T.Run("error creating latency histogram", func(t *testing.T) { t.Parallel() - mp := &metrics.MockProvider{} - mp.On("NewInt64Counter", "pgvector_index_upserts", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "pgvector_index_deletes", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "pgvector_index_wipes", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "pgvector_index_queries", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewInt64Counter", "pgvector_index_errors", mock.Anything).Return(metricnoop.Int64Counter{}, nil) - mp.On("NewFloat64Histogram", "pgvector_index_latency_ms", mock.Anything).Return(metricnoop.Float64Histogram{}, errors.New("forced error")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(string, ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metricnoop.Int64Counter{}, nil + }, + NewFloat64HistogramFunc: func(string, ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { + return metricnoop.Float64Histogram{}, errors.New("forced error") + }, + } _, err := ProvideIndex[doc](t.Context(), nil, nil, mp, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) require.Error(t, err) - - mock.AssertExpectationsForObjects(t, mp) }) } diff --git a/search/vector/qdrant/qdrant_test.go b/search/vector/qdrant/qdrant_test.go index 6a02d30..b583394 100644 --- a/search/vector/qdrant/qdrant_test.go +++ b/search/vector/qdrant/qdrant_test.go @@ -20,7 +20,6 @@ import ( vectorsearch "github.com/verygoodsoftwarenotvirus/platform/v5/search/vector" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" @@ -488,15 +487,16 @@ func TestUpsert(T *testing.T) { T.Run("circuit breaker broken", func(t *testing.T) { t.Parallel() - cb := &cbmock.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true).Once() + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return true }, + } idx := buildStubIndex(t, &qdrantStub{}, cb) err := idx.Upsert(t.Context(), vectorsearch.Vector[doc]{ID: "a", Embedding: []float32{1, 0, 0}}) require.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - mock.AssertExpectationsForObjects(t, cb) + require.Len(t, cb.CannotProceedCalls(), 1) }) T.Run("rejects empty ID", func(t *testing.T) { @@ -568,15 +568,16 @@ func TestDelete(T *testing.T) { T.Run("circuit breaker broken", func(t *testing.T) { t.Parallel() - cb := &cbmock.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true).Once() + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return true }, + } idx := buildStubIndex(t, &qdrantStub{}, cb) err := idx.Delete(t.Context(), "some-id") require.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - mock.AssertExpectationsForObjects(t, cb) + require.Len(t, cb.CannotProceedCalls(), 1) }) T.Run("successful delete", func(t *testing.T) { @@ -616,15 +617,16 @@ func TestWipe(T *testing.T) { T.Run("circuit breaker broken", func(t *testing.T) { t.Parallel() - cb := &cbmock.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true).Once() + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return true }, + } idx := buildStubIndex(t, &qdrantStub{}, cb) err := idx.Wipe(t.Context()) require.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - mock.AssertExpectationsForObjects(t, cb) + require.Len(t, cb.CannotProceedCalls(), 1) }) T.Run("successful wipe", func(t *testing.T) { @@ -725,15 +727,16 @@ func TestQuery(T *testing.T) { T.Run("circuit breaker broken", func(t *testing.T) { t.Parallel() - cb := &cbmock.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true).Once() + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return true }, + } idx := buildStubIndex(t, &qdrantStub{}, cb) _, err := idx.Query(t.Context(), vectorsearch.QueryRequest{Embedding: []float32{1, 0, 0}, TopK: 5}) require.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - mock.AssertExpectationsForObjects(t, cb) + require.Len(t, cb.CannotProceedCalls(), 1) }) T.Run("defaults TopK to 10", func(t *testing.T) { diff --git a/secrets/config/config_test.go b/secrets/config/config_test.go index 6e9bc36..eae819f 100644 --- a/secrets/config/config_test.go +++ b/secrets/config/config_test.go @@ -16,8 +16,8 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" awsssm "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/aws/aws-sdk-go-v2/service/ssm/types" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" corev1 "k8s.io/api/core/v1" @@ -284,36 +284,45 @@ func TestConfig_ProvideSecretSource(T *testing.T) { T.Run("nil config with metrics error", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", mock.AnythingOfType("string"), []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + }, + } var cfg *Config source, err := cfg.ProvideSecretSource(context.Background(), nil, nil, mp) require.Error(t, err) assert.Nil(t, source) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("env provider with metrics error", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", mock.AnythingOfType("string"), []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + }, + } cfg := &Config{Provider: ProviderEnv} source, err := cfg.ProvideSecretSource(context.Background(), nil, nil, mp) require.Error(t, err) assert.Nil(t, source) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("gcp provider with metrics error", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", mock.AnythingOfType("string"), []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + }, + } cfg := &Config{ Provider: ProviderGCP, @@ -324,14 +333,17 @@ func TestConfig_ProvideSecretSource(T *testing.T) { require.Error(t, err) assert.Nil(t, source) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("ssm provider with metrics error", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", mock.AnythingOfType("string"), []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + }, + } cfg := &Config{ Provider: ProviderSSM, @@ -342,14 +354,17 @@ func TestConfig_ProvideSecretSource(T *testing.T) { require.Error(t, err) assert.Nil(t, source) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("kubectl provider with metrics error", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", mock.AnythingOfType("string"), []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + }, + } cfg := &Config{ Provider: ProviderKubectl, @@ -360,6 +375,6 @@ func TestConfig_ProvideSecretSource(T *testing.T) { require.Error(t, err) assert.Nil(t, source) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) } diff --git a/secrets/env/env_test.go b/secrets/env/env_test.go index 3c326ab..d7cbdee 100644 --- a/secrets/env/env_test.go +++ b/secrets/env/env_test.go @@ -10,8 +10,8 @@ import ( mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/secrets" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" ) @@ -24,14 +24,18 @@ func TestNewEnvSecretSource(T *testing.T) { T.Run("with error creating lookup counter", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_lookups", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, name+"_lookups", counterName) + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + }, + } source, err := NewEnvSecretSource(nil, nil, mp) require.Error(t, err) assert.Nil(t, source) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("with error creating latency histogram", func(t *testing.T) { @@ -41,15 +45,22 @@ func TestNewEnvSecretSource(T *testing.T) { h, histErr := noopMP.NewFloat64Histogram("test") require.NoError(t, histErr) - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_lookups", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewFloat64Histogram", name+"_latency_ms", []metric.Float64HistogramOption(nil)).Return(h, errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metrics.Int64CounterForTest(t, "x"), nil + }, + NewFloat64HistogramFunc: func(histName string, _ ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { + test.EqOp(t, name+"_latency_ms", histName) + return h, errors.New("arbitrary") + }, + } source, err := NewEnvSecretSource(nil, nil, mp) require.Error(t, err) assert.Nil(t, source) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) + test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) }) } diff --git a/secrets/gcp/gcp_test.go b/secrets/gcp/gcp_test.go index 20ff603..a2f0be0 100644 --- a/secrets/gcp/gcp_test.go +++ b/secrets/gcp/gcp_test.go @@ -9,8 +9,8 @@ import ( mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "cloud.google.com/go/secretmanager/apiv1/secretmanagerpb" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" ) @@ -47,30 +47,43 @@ func TestNewGCPSecretSource(T *testing.T) { T.Run("with error creating lookup counter", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_lookups", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, name+"_lookups", counterName) + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + }, + } cfg := &Config{ProjectID: "test-project"} source, err := NewGCPSecretSource(context.Background(), cfg, &mockGCPClient{}, nil, nil, mp) require.Error(t, err) assert.Nil(t, source) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("with error creating error counter", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_lookups", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_errors", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + switch counterName { + case name + "_lookups": + return metrics.Int64CounterForTest(t, "x"), nil + case name + "_errors": + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + } + t.Fatalf("unexpected NewInt64Counter call: %q", counterName) + return nil, nil + }, + } cfg := &Config{ProjectID: "test-project"} source, err := NewGCPSecretSource(context.Background(), cfg, &mockGCPClient{}, nil, nil, mp) require.Error(t, err) assert.Nil(t, source) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) T.Run("with error creating latency histogram", func(t *testing.T) { @@ -80,17 +93,23 @@ func TestNewGCPSecretSource(T *testing.T) { h, histErr := noopMP.NewFloat64Histogram("test") require.NoError(t, histErr) - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_lookups", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_errors", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewFloat64Histogram", name+"_latency_ms", []metric.Float64HistogramOption(nil)).Return(h, errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metrics.Int64CounterForTest(t, "x"), nil + }, + NewFloat64HistogramFunc: func(histName string, _ ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { + test.EqOp(t, name+"_latency_ms", histName) + return h, errors.New("arbitrary") + }, + } cfg := &Config{ProjectID: "test-project"} source, err := NewGCPSecretSource(context.Background(), cfg, &mockGCPClient{}, nil, nil, mp) require.Error(t, err) assert.Nil(t, source) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) + test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) }) } diff --git a/secrets/kubectl/kubectl_test.go b/secrets/kubectl/kubectl_test.go index 2f83c21..91c6ef3 100644 --- a/secrets/kubectl/kubectl_test.go +++ b/secrets/kubectl/kubectl_test.go @@ -8,8 +8,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" corev1 "k8s.io/api/core/v1" @@ -47,30 +47,43 @@ func TestNewKubectlSecretSource(T *testing.T) { T.Run("with error creating lookup counter", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_lookups", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, name+"_lookups", counterName) + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + }, + } cfg := &Config{Namespace: "default"} source, err := NewKubectlSecretSource(context.Background(), cfg, &mockSecretGetter{}, nil, nil, mp) require.Error(t, err) assert.Nil(t, source) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("with error creating error counter", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_lookups", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_errors", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + switch counterName { + case name + "_lookups": + return metrics.Int64CounterForTest(t, "x"), nil + case name + "_errors": + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + } + t.Fatalf("unexpected NewInt64Counter call: %q", counterName) + return nil, nil + }, + } cfg := &Config{Namespace: "default"} source, err := NewKubectlSecretSource(context.Background(), cfg, &mockSecretGetter{}, nil, nil, mp) require.Error(t, err) assert.Nil(t, source) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) T.Run("with error creating latency histogram", func(t *testing.T) { @@ -80,17 +93,23 @@ func TestNewKubectlSecretSource(T *testing.T) { h, histErr := noopMP.NewFloat64Histogram("test") require.NoError(t, histErr) - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_lookups", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_errors", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewFloat64Histogram", name+"_latency_ms", []metric.Float64HistogramOption(nil)).Return(h, errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metrics.Int64CounterForTest(t, "x"), nil + }, + NewFloat64HistogramFunc: func(histName string, _ ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { + test.EqOp(t, name+"_latency_ms", histName) + return h, errors.New("arbitrary") + }, + } cfg := &Config{Namespace: "default"} source, err := NewKubectlSecretSource(context.Background(), cfg, &mockSecretGetter{}, nil, nil, mp) require.Error(t, err) assert.Nil(t, source) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) + test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) }) } diff --git a/secrets/ssm/ssm_test.go b/secrets/ssm/ssm_test.go index eb08e6c..1fb16e2 100644 --- a/secrets/ssm/ssm_test.go +++ b/secrets/ssm/ssm_test.go @@ -11,8 +11,8 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/aws/aws-sdk-go-v2/service/ssm/types" + "github.com/shoenig/test" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric" ) @@ -48,30 +48,43 @@ func TestNewSSMSecretSource(T *testing.T) { T.Run("with error creating lookup counter", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_lookups", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + test.EqOp(t, name+"_lookups", counterName) + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + }, + } cfg := &Config{Region: "us-east-1"} source, err := NewSSMSecretSource(context.Background(), cfg, &mockSSMClient{}, nil, nil, mp) require.Error(t, err) assert.Nil(t, source) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("with error creating error counter", func(t *testing.T) { t.Parallel() - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_lookups", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_errors", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + switch counterName { + case name + "_lookups": + return metrics.Int64CounterForTest(t, "x"), nil + case name + "_errors": + return metrics.Int64CounterForTest(t, "x"), errors.New("arbitrary") + } + t.Fatalf("unexpected NewInt64Counter call: %q", counterName) + return nil, nil + }, + } cfg := &Config{Region: "us-east-1"} source, err := NewSSMSecretSource(context.Background(), cfg, &mockSSMClient{}, nil, nil, mp) require.Error(t, err) assert.Nil(t, source) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) T.Run("with error creating latency histogram", func(t *testing.T) { @@ -81,17 +94,23 @@ func TestNewSSMSecretSource(T *testing.T) { h, histErr := noopMP.NewFloat64Histogram("test") require.NoError(t, histErr) - mp := &mockmetrics.MetricsProvider{} - mp.On("NewInt64Counter", name+"_lookups", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewInt64Counter", name+"_errors", []metric.Int64CounterOption(nil)).Return(metrics.Int64CounterForTest(t, "x"), nil) - mp.On("NewFloat64Histogram", name+"_latency_ms", []metric.Float64HistogramOption(nil)).Return(h, errors.New("arbitrary")) + mp := &mockmetrics.ProviderMock{ + NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { + return metrics.Int64CounterForTest(t, "x"), nil + }, + NewFloat64HistogramFunc: func(histName string, _ ...metric.Float64HistogramOption) (metrics.Float64Histogram, error) { + test.EqOp(t, name+"_latency_ms", histName) + return h, errors.New("arbitrary") + }, + } cfg := &Config{Region: "us-east-1"} source, err := NewSSMSecretSource(context.Background(), cfg, &mockSSMClient{}, nil, nil, mp) require.Error(t, err) assert.Nil(t, source) - mock.AssertExpectationsForObjects(t, mp) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) + test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) }) } diff --git a/server/grpc/server_test.go b/server/grpc/server_test.go index 173e560..bcf4a5b 100644 --- a/server/grpc/server_test.go +++ b/server/grpc/server_test.go @@ -19,7 +19,6 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace/noop" @@ -28,7 +27,8 @@ import ( type mockTracerProvider struct { noop.TracerProvider - mock.Mock + forceFlushFunc func(ctx context.Context) error + forceFlushCalls int } func (m *mockTracerProvider) Tracer(name string, opts ...trace.TracerOption) trace.Tracer { @@ -36,7 +36,11 @@ func (m *mockTracerProvider) Tracer(name string, opts ...trace.TracerOption) tra } func (m *mockTracerProvider) ForceFlush(ctx context.Context) error { - return m.Called(ctx).Error(0) + m.forceFlushCalls++ + if m.forceFlushFunc == nil { + return nil + } + return m.forceFlushFunc(ctx) } func generateTestTLSCerts(t *testing.T) (certFile, keyFile string) { @@ -215,8 +219,9 @@ func TestServer_Shutdown(T *testing.T) { T.Run("logs error when ForceFlush fails", func(t *testing.T) { t.Parallel() - mtp := &mockTracerProvider{} - mtp.On("ForceFlush", mock.Anything).Return(errors.New("flush failed")) + mtp := &mockTracerProvider{ + forceFlushFunc: func(_ context.Context) error { return errors.New("flush failed") }, + } cfg := &Config{Port: 0} srv, err := NewGRPCServer(cfg, logging.NewNoopLogger(), mtp, nil, nil) @@ -224,7 +229,7 @@ func TestServer_Shutdown(T *testing.T) { srv.Shutdown(context.Background()) - mock.AssertExpectationsForObjects(t, mtp) + assert.Equal(t, 1, mtp.forceFlushCalls) }) } diff --git a/server/http/http_server_test.go b/server/http/http_server_test.go index d7a307d..82e31f6 100644 --- a/server/http/http_server_test.go +++ b/server/http/http_server_test.go @@ -25,7 +25,6 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/routing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace/noop" @@ -33,7 +32,8 @@ import ( type mockTracerProvider struct { noop.TracerProvider - mock.Mock + forceFlushFunc func(ctx context.Context) error + forceFlushCalls int } func (m *mockTracerProvider) Tracer(name string, opts ...trace.TracerOption) trace.Tracer { @@ -41,7 +41,11 @@ func (m *mockTracerProvider) Tracer(name string, opts ...trace.TracerOption) tra } func (m *mockTracerProvider) ForceFlush(ctx context.Context) error { - return m.Called(ctx).Error(0) + m.forceFlushCalls++ + if m.forceFlushFunc == nil { + return nil + } + return m.forceFlushFunc(ctx) } // stubRouter satisfies routing.Router for testing Serve(). @@ -209,8 +213,9 @@ func TestServer_Shutdown(T *testing.T) { T.Run("logs error when ForceFlush fails", func(t *testing.T) { t.Parallel() - mtp := &mockTracerProvider{} - mtp.On("ForceFlush", mock.Anything).Return(errors.New("flush failed")) + mtp := &mockTracerProvider{ + forceFlushFunc: func(_ context.Context) error { return errors.New("flush failed") }, + } s, err := ProvideHTTPServer(Config{Port: 0}, logging.NewNoopLogger(), nil, mtp, "") require.NoError(t, err) @@ -220,7 +225,7 @@ func TestServer_Shutdown(T *testing.T) { assert.NoError(t, s.Shutdown(ctx)) - mock.AssertExpectationsForObjects(t, mtp) + assert.Equal(t, 1, mtp.forceFlushCalls) }) } diff --git a/testutils/mock_handler.go b/testutils/mock_handler.go deleted file mode 100644 index 5b112d5..0000000 --- a/testutils/mock_handler.go +++ /dev/null @@ -1,19 +0,0 @@ -package testutils - -import ( - "net/http" - - "github.com/stretchr/testify/mock" -) - -var _ http.Handler = (*MockHTTPHandler)(nil) - -// MockHTTPHandler is a mocked http.Handler. -type MockHTTPHandler struct { - mock.Mock -} - -// ServeHTTP satisfies our interface requirements. -func (m *MockHTTPHandler) ServeHTTP(res http.ResponseWriter, req *http.Request) { - m.Called(res, req) -} diff --git a/testutils/mock_http_response_writer.go b/testutils/mock_http_response_writer.go deleted file mode 100644 index 040bbf6..0000000 --- a/testutils/mock_http_response_writer.go +++ /dev/null @@ -1,31 +0,0 @@ -package testutils - -import ( - "net/http" - - "github.com/stretchr/testify/mock" -) - -var _ http.ResponseWriter = (*MockHTTPResponseWriter)(nil) - -// MockHTTPResponseWriter is a mock http.ResponseWriter. -type MockHTTPResponseWriter struct { - mock.Mock -} - -// Header satisfies our interface requirements. -func (m *MockHTTPResponseWriter) Header() http.Header { - return m.Called().Get(0).(http.Header) -} - -// Write satisfies our interface requirements. -func (m *MockHTTPResponseWriter) Write(in []byte) (int, error) { - returnValues := m.Called(in) - - return returnValues.Int(0), returnValues.Error(1) -} - -// WriteHeader satisfies our interface requirements. -func (m *MockHTTPResponseWriter) WriteHeader(statusCode int) { - m.Called(statusCode) -} diff --git a/testutils/mock_io_read_closer.go b/testutils/mock_io_read_closer.go deleted file mode 100644 index 6f6de46..0000000 --- a/testutils/mock_io_read_closer.go +++ /dev/null @@ -1,26 +0,0 @@ -package testutils - -import ( - "io" - - "github.com/stretchr/testify/mock" -) - -var _ io.ReadCloser = (*MockReadCloser)(nil) - -// MockReadCloser mocks a io.ReadCloser. -type MockReadCloser struct { - mock.Mock -} - -// Read implements the io.ReadCloser interface. -func (m *MockReadCloser) Read(p []byte) (int, error) { - returnValues := m.Called(p) - - return returnValues.Int(0), returnValues.Error(1) -} - -// Close implements the io.ReadCloser interface. -func (m *MockReadCloser) Close() error { - return m.Called().Error(0) -} diff --git a/testutils/mock_io_writer.go b/testutils/mock_io_writer.go deleted file mode 100644 index c45d5a4..0000000 --- a/testutils/mock_io_writer.go +++ /dev/null @@ -1,20 +0,0 @@ -package testutils - -import ( - "io" - - "github.com/stretchr/testify/mock" -) - -var _ io.Writer = (*MockWriter)(nil) - -// MockWriter mocks a io.Writer. -type MockWriter struct { - mock.Mock -} - -// Write implements the io.Writer interface. -func (m *MockWriter) Write(p []byte) (int, error) { - returnVals := m.Called(p) - return returnVals.Int(0), returnVals.Error(1) -} diff --git a/testutils/mock_matchers.go b/testutils/mock_matchers.go deleted file mode 100644 index d63ca28..0000000 --- a/testutils/mock_matchers.go +++ /dev/null @@ -1,24 +0,0 @@ -package testutils - -import ( - "context" - "net/http" - - "github.com/verygoodsoftwarenotvirus/platform/v5/database/filtering" - - "github.com/stretchr/testify/mock" -) - -// MatchType is a matcher for use with testify/mock. -func MatchType[T any]() any { - return mock.MatchedBy(func(T) bool { - return true - }) -} - -var ( - ContextMatcher = MatchType[context.Context]() - QueryFilterMatcher = MatchType[*filtering.QueryFilter]() - HTTPRequestMatcher = MatchType[*http.Request]() - HTTPResponseWriterMatcher = MatchType[http.ResponseWriter]() -) diff --git a/uploads/images/images_test.go b/uploads/images/images_test.go index bb2687f..58be1ba 100644 --- a/uploads/images/images_test.go +++ b/uploads/images/images_test.go @@ -17,14 +17,26 @@ import ( "testing" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/verygoodsoftwarenotvirus/platform/v5/reflection" "github.com/verygoodsoftwarenotvirus/platform/v5/testutils" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) +// errorWriter is an http.ResponseWriter whose Write always returns an error. +type errorWriter struct { + header http.Header +} + +func (e *errorWriter) Header() http.Header { + if e.header == nil { + e.header = http.Header{} + } + return e.header +} +func (e *errorWriter) Write([]byte) (int, error) { return 0, errors.New("blah") } +func (e *errorWriter) WriteHeader(int) {} + func newAvatarUploadRequest(t *testing.T, filename string, avatar io.Reader) *http.Request { t.Helper() @@ -238,9 +250,7 @@ func TestImage_Write(T *testing.T) { Size: 12345, } - res := &testutils.MockHTTPResponseWriter{} - res.On(reflection.GetMethodName(res.Header)).Return(http.Header{}) - res.On(reflection.GetMethodName(res.Write), mock.IsType([]byte(nil))).Return(0, errors.New("blah")) + res := &errorWriter{} assert.Error(t, i.Write(res)) }) diff --git a/uploads/images/mock.go b/uploads/images/mock.go deleted file mode 100644 index 729568f..0000000 --- a/uploads/images/mock.go +++ /dev/null @@ -1,28 +0,0 @@ -package images - -import ( - "context" - "net/http" - - "github.com/stretchr/testify/mock" -) - -var _ MediaUploadProcessor = (*MockImageUploadProcessor)(nil) - -// MockImageUploadProcessor is a mock MediaUploadProcessor. -type MockImageUploadProcessor struct { - mock.Mock -} - -// ProcessFile satisfies the MediaUploadProcessor interface. -func (m *MockImageUploadProcessor) ProcessFile(ctx context.Context, req *http.Request, filename string) (*Upload, error) { - args := m.Called(ctx, req, filename) - - return args.Get(0).(*Upload), args.Error(1) -} - -func (m *MockImageUploadProcessor) ProcessFiles(ctx context.Context, req *http.Request, filenamePrefix string) ([]*Upload, error) { - args := m.Called(ctx, req, filenamePrefix) - - return args.Get(0).([]*Upload), args.Error(1) -} diff --git a/uploads/mock/doc.go b/uploads/mock/doc.go new file mode 100644 index 0000000..10e4fa0 --- /dev/null +++ b/uploads/mock/doc.go @@ -0,0 +1,9 @@ +// Package mockuploads provides mock implementations of the uploads package's +// interfaces. Both the hand-written testify-based MockUploadManager and the +// moq-generated UploadManagerMock live here during the testify → moq migration. +// New test code should prefer UploadManagerMock. +package mockuploads + +// Regenerate the moq mocks via `go generate ./uploads/mock/`. + +//go:generate go tool github.com/matryer/moq -out upload_manager_mock.go -pkg mockuploads -rm -fmt goimports .. UploadManager:UploadManagerMock diff --git a/uploads/mock/mock.go b/uploads/mock/mock.go deleted file mode 100644 index ea80292..0000000 --- a/uploads/mock/mock.go +++ /dev/null @@ -1,28 +0,0 @@ -package mockuploads - -import ( - "context" - - "github.com/verygoodsoftwarenotvirus/platform/v5/uploads" - - "github.com/stretchr/testify/mock" -) - -var _ uploads.UploadManager = (*MockUploadManager)(nil) - -// MockUploadManager is a mock MockUploadManager. -type MockUploadManager struct { - mock.Mock -} - -// SaveFile satisfies the MockUploadManager interface. -func (m *MockUploadManager) SaveFile(ctx context.Context, path string, content []byte) error { - return m.Called(ctx, path, content).Error(0) -} - -// ReadFile satisfies the MockUploadManager interface. -func (m *MockUploadManager) ReadFile(ctx context.Context, path string) ([]byte, error) { - args := m.Called(ctx, path) - - return args.Get(0).([]byte), args.Error(1) -} diff --git a/uploads/mock/upload_manager_mock.go b/uploads/mock/upload_manager_mock.go new file mode 100644 index 0000000..df87430 --- /dev/null +++ b/uploads/mock/upload_manager_mock.go @@ -0,0 +1,139 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package mockuploads + +import ( + "context" + "sync" + + "github.com/verygoodsoftwarenotvirus/platform/v5/uploads" +) + +// Ensure, that UploadManagerMock does implement uploads.UploadManager. +// If this is not the case, regenerate this file with moq. +var _ uploads.UploadManager = &UploadManagerMock{} + +// UploadManagerMock is a mock implementation of uploads.UploadManager. +// +// func TestSomethingThatUsesUploadManager(t *testing.T) { +// +// // make and configure a mocked uploads.UploadManager +// mockedUploadManager := &UploadManagerMock{ +// ReadFileFunc: func(ctx context.Context, path string) ([]byte, error) { +// panic("mock out the ReadFile method") +// }, +// SaveFileFunc: func(ctx context.Context, path string, content []byte) error { +// panic("mock out the SaveFile method") +// }, +// } +// +// // use mockedUploadManager in code that requires uploads.UploadManager +// // and then make assertions. +// +// } +type UploadManagerMock struct { + // ReadFileFunc mocks the ReadFile method. + ReadFileFunc func(ctx context.Context, path string) ([]byte, error) + + // SaveFileFunc mocks the SaveFile method. + SaveFileFunc func(ctx context.Context, path string, content []byte) error + + // calls tracks calls to the methods. + calls struct { + // ReadFile holds details about calls to the ReadFile method. + ReadFile []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Path is the path argument value. + Path string + } + // SaveFile holds details about calls to the SaveFile method. + SaveFile []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Path is the path argument value. + Path string + // Content is the content argument value. + Content []byte + } + } + lockReadFile sync.RWMutex + lockSaveFile sync.RWMutex +} + +// ReadFile calls ReadFileFunc. +func (mock *UploadManagerMock) ReadFile(ctx context.Context, path string) ([]byte, error) { + if mock.ReadFileFunc == nil { + panic("UploadManagerMock.ReadFileFunc: method is nil but UploadManager.ReadFile was just called") + } + callInfo := struct { + Ctx context.Context + Path string + }{ + Ctx: ctx, + Path: path, + } + mock.lockReadFile.Lock() + mock.calls.ReadFile = append(mock.calls.ReadFile, callInfo) + mock.lockReadFile.Unlock() + return mock.ReadFileFunc(ctx, path) +} + +// ReadFileCalls gets all the calls that were made to ReadFile. +// Check the length with: +// +// len(mockedUploadManager.ReadFileCalls()) +func (mock *UploadManagerMock) ReadFileCalls() []struct { + Ctx context.Context + Path string +} { + var calls []struct { + Ctx context.Context + Path string + } + mock.lockReadFile.RLock() + calls = mock.calls.ReadFile + mock.lockReadFile.RUnlock() + return calls +} + +// SaveFile calls SaveFileFunc. +func (mock *UploadManagerMock) SaveFile(ctx context.Context, path string, content []byte) error { + if mock.SaveFileFunc == nil { + panic("UploadManagerMock.SaveFileFunc: method is nil but UploadManager.SaveFile was just called") + } + callInfo := struct { + Ctx context.Context + Path string + Content []byte + }{ + Ctx: ctx, + Path: path, + Content: content, + } + mock.lockSaveFile.Lock() + mock.calls.SaveFile = append(mock.calls.SaveFile, callInfo) + mock.lockSaveFile.Unlock() + return mock.SaveFileFunc(ctx, path, content) +} + +// SaveFileCalls gets all the calls that were made to SaveFile. +// Check the length with: +// +// len(mockedUploadManager.SaveFileCalls()) +func (mock *UploadManagerMock) SaveFileCalls() []struct { + Ctx context.Context + Path string + Content []byte +} { + var calls []struct { + Ctx context.Context + Path string + Content []byte + } + mock.lockSaveFile.RLock() + calls = mock.calls.SaveFile + mock.lockSaveFile.RUnlock() + return calls +} diff --git a/uploads/objectstorage/files_test.go b/uploads/objectstorage/files_test.go index e183050..60edeb2 100644 --- a/uploads/objectstorage/files_test.go +++ b/uploads/objectstorage/files_test.go @@ -97,8 +97,9 @@ func TestUploader_ReadFile(T *testing.T) { ctx := t.Context() - cb := &cbmock.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true) + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return true }, + } saveCounter, readCounter, saveErrCounter, readErrCounter, latencyHist := noopUploaderMetrics(t) u := &Uploader{ @@ -128,9 +129,10 @@ func TestUploader_ReadFile(T *testing.T) { b := memblob.OpenBucket(&memblob.Options{}) require.NoError(t, b.WriteAll(ctx, exampleFilename, expectedContent, nil)) - cb := &cbmock.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Succeeded").Return() + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + SucceededFunc: func() {}, + } saveCounter, readCounter, saveErrCounter, readErrCounter, latencyHist := noopUploaderMetrics(t) u := &Uploader{ @@ -179,8 +181,9 @@ func TestUploader_SaveFile(T *testing.T) { ctx := t.Context() - cb := &cbmock.MockCircuitBreaker{} - cb.On("CannotProceed").Return(true) + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return true }, + } saveCounter, readCounter, saveErrCounter, readErrCounter, latencyHist := noopUploaderMetrics(t) u := &Uploader{ @@ -203,9 +206,10 @@ func TestUploader_SaveFile(T *testing.T) { ctx := t.Context() - cb := &cbmock.MockCircuitBreaker{} - cb.On("CannotProceed").Return(false) - cb.On("Failed").Return() + cb := &cbmock.CircuitBreakerMock{ + CannotProceedFunc: func() bool { return false }, + FailedFunc: func() {}, + } b := memblob.OpenBucket(&memblob.Options{}) require.NoError(t, b.Close()) diff --git a/uploads/objectstorage/mock_uploader.go b/uploads/objectstorage/mock_uploader.go deleted file mode 100644 index 05b12f8..0000000 --- a/uploads/objectstorage/mock_uploader.go +++ /dev/null @@ -1,30 +0,0 @@ -package objectstorage - -import ( - "context" - - "github.com/verygoodsoftwarenotvirus/platform/v5/uploads" - - "github.com/stretchr/testify/mock" -) - -var _ uploads.UploadManager = (*MockUploader)(nil) - -type ( - // MockUploader is a mock uploads.UploadManager. - MockUploader struct { - mock.Mock - } -) - -// SaveFile is a mock function. -func (m *MockUploader) SaveFile(ctx context.Context, path string, content []byte) error { - return m.Called(ctx, path, content).Error(0) -} - -// ReadFile is a mock function. -func (m *MockUploader) ReadFile(ctx context.Context, path string) ([]byte, error) { - returnValues := m.Called(ctx, path) - - return returnValues.Get(0).([]byte), returnValues.Error(1) -} diff --git a/uploads/objectstorage/mock_uploader_test.go b/uploads/objectstorage/mock_uploader_test.go deleted file mode 100644 index 3a1c479..0000000 --- a/uploads/objectstorage/mock_uploader_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package objectstorage - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestMockUploader_SaveFile(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - ctx := t.Context() - m := &MockUploader{} - m.On("SaveFile", mock.Anything, "test.txt", []byte("content")).Return(nil) - - assert.NoError(t, m.SaveFile(ctx, "test.txt", []byte("content"))) - mock.AssertExpectationsForObjects(t, m) - }) -} - -func TestMockUploader_ReadFile(T *testing.T) { - T.Parallel() - - T.Run("standard", func(t *testing.T) { - t.Parallel() - - ctx := t.Context() - expected := []byte("content") - m := &MockUploader{} - m.On("ReadFile", mock.Anything, "test.txt").Return(expected, nil) - - actual, err := m.ReadFile(ctx, "test.txt") - assert.NoError(t, err) - assert.Equal(t, expected, actual) - mock.AssertExpectationsForObjects(t, m) - }) -} From 801a7c3f24ce6acc5d88a5c9ea94871489202c76 Mon Sep 17 00:00:00 2001 From: Jeffrey D <1289344+verygoodsoftwarenotvirus@users.noreply.github.com> Date: Sat, 11 Apr 2026 16:46:20 -0500 Subject: [PATCH 07/12] test: switch from stretchr/testify to shoenig/test --- analytics/config/config_test.go | 61 ++-- analytics/config/do_test.go | 8 +- analytics/config/wire_test.go | 10 +- analytics/multisource/config_test.go | 40 +-- analytics/multisource/do_test.go | 8 +- analytics/multisource/reporter_test.go | 49 ++- analytics/noop/noop_test.go | 14 +- analytics/posthog/config_test.go | 6 +- analytics/posthog/posthog_test.go | 40 +-- analytics/rudderstack/config_test.go | 8 +- analytics/rudderstack/rudderstack_test.go | 48 +-- analytics/segment/config_test.go | 6 +- analytics/segment/segment_test.go | 40 +-- bitmask/bitmask_test.go | 184 +++++----- capitalism/config/config_test.go | 18 +- capitalism/config/do_test.go | 8 +- capitalism/noop/noop_test.go | 8 +- capitalism/stripe/config_test.go | 6 +- capitalism/stripe/stripe_test.go | 61 ++-- circuitbreaking/config/config_test.go | 99 +++--- circuitbreaking/config/do_test.go | 8 +- circuitbreaking/noop/noop_test.go | 8 +- compression/compressor_test.go | 64 ++-- cryptography/encryption/aes/aes_test.go | 18 +- cryptography/encryption/config/config_test.go | 26 +- cryptography/encryption/config/do_test.go | 8 +- cryptography/encryption/errors_test.go | 6 +- .../encryption/salsa20/salsa20_test.go | 24 +- cryptography/hashing/adler32/adler32_test.go | 6 +- cryptography/hashing/crc64/crc64_test.go | 6 +- cryptography/hashing/fnv/fnv_test.go | 6 +- cryptography/hashing/sha256/sha256_test.go | 6 +- cryptography/hashing/sha512/sha512_test.go | 6 +- database/config/config_test.go | 189 +++++----- database/config/do_test.go | 18 +- database/errors_test.go | 10 +- database/filtering/query_filter_test.go | 58 +-- database/mysql/do_test.go | 8 +- database/mysql/mysql_test.go | 72 ++-- .../mysql/tableaccess/access_manager_test.go | 72 ++-- database/null_values_test.go | 112 +++--- database/postgres/do_test.go | 8 +- database/postgres/postgres_test.go | 72 ++-- .../tableaccess/access_manager_test.go | 334 +++++++++--------- database/sqlite/do_test.go | 8 +- database/sqlite/sqlite_test.go | 72 ++-- .../sqlite/tableaccess/access_manager_test.go | 24 +- distributedlock/config/config_test.go | 59 ++-- distributedlock/memory/memory_test.go | 80 ++--- distributedlock/noop/noop_test.go | 34 +- distributedlock/postgres/config_test.go | 6 +- distributedlock/postgres/postgres_test.go | 218 ++++++------ distributedlock/redis/config_test.go | 12 +- distributedlock/redis/redis_test.go | 218 ++++++------ email/config/config_test.go | 61 ++-- email/config/do_test.go | 8 +- email/mailgun/mailgun_test.go | 40 +-- email/mailjet/mailjet_test.go | 34 +- email/noop/noop_test.go | 16 +- email/postmark/postmark_test.go | 36 +- email/resend/resend_test.go | 42 +-- email/sendgrid/sendgrid_test.go | 30 +- email/ses/config_test.go | 10 +- email/ses/ses_test.go | 46 +-- embeddings/cohere/cohere_test.go | 116 +++--- embeddings/config/config_test.go | 36 +- embeddings/config/do_test.go | 8 +- embeddings/embeddings_test.go | 20 +- embeddings/ollama/ollama_test.go | 108 +++--- embeddings/openai/openai_test.go | 116 +++--- encoding/client_encoder_test.go | 40 +-- encoding/config_test.go | 6 +- encoding/content_type_test.go | 26 +- encoding/do_test.go | 12 +- encoding/providers_test.go | 4 +- encoding/server_encoder_decoder_test.go | 70 ++-- encoding/utils_test.go | 30 +- errors/errors_test.go | 58 +-- errors/grpc/grpc_interceptor_test.go | 80 ++--- errors/grpc/grpc_test.go | 56 +-- errors/http/map_test.go | 72 ++-- eventstream/config/config_test.go | 30 +- eventstream/config/do_test.go | 12 +- eventstream/manager_test.go | 93 ++--- eventstream/sse/sse_test.go | 132 +++---- eventstream/websocket/config_test.go | 4 +- eventstream/websocket/websocket_test.go | 78 ++-- fake/fake.go | 4 +- fake/fake_test.go | 22 +- featureflags/config/config_test.go | 43 ++- featureflags/config/do_test.go | 8 +- .../launchdarkly/feature_flag_manager_test.go | 108 +++--- featureflags/noop/noop_test.go | 28 +- .../posthog/feature_flag_manager_test.go | 108 +++--- go.mod | 2 +- healthcheck/checkers_test.go | 41 ++- healthcheck/healthcheck_test.go | 36 +- httpclient/client_test.go | 32 +- httpclient/config_test.go | 24 +- identifiers/identifiers_test.go | 6 +- llm/anthropic/anthropic_test.go | 62 ++-- llm/anthropic/config_test.go | 6 +- llm/config/config_test.go | 43 ++- llm/config/do_test.go | 8 +- llm/llm_test.go | 6 +- llm/openai/config_test.go | 6 +- llm/openai/openai_test.go | 62 ++-- messagequeue/config/config_test.go | 66 ++-- messagequeue/config/do_test.go | 12 +- messagequeue/kafka/config_test.go | 8 +- messagequeue/kafka/consumer_test.go | 55 ++- messagequeue/kafka/publisher_test.go | 89 +++-- messagequeue/noop/noop_test.go | 16 +- messagequeue/publishers_test.go | 6 +- messagequeue/pubsub/config_test.go | 4 +- messagequeue/pubsub/consumer_test.go | 84 +++-- messagequeue/pubsub/publisher_test.go | 24 +- messagequeue/redis/config_test.go | 4 +- messagequeue/redis/consumer_test.go | 46 +-- messagequeue/redis/publisher_test.go | 86 ++--- messagequeue/sqs/consumer_test.go | 62 ++-- messagequeue/sqs/publisher_test.go | 86 ++--- notifications/async/ably/ably_test.go | 22 +- notifications/async/ably/config_test.go | 6 +- notifications/async/config/config_test.go | 46 +-- notifications/async/config/do_test.go | 8 +- notifications/async/noop/noop_test.go | 16 +- notifications/async/pusher/config_test.go | 6 +- notifications/async/pusher/pusher_test.go | 22 +- notifications/async/sse/config_test.go | 4 +- notifications/async/sse/sse_test.go | 34 +- notifications/async/websocket/config_test.go | 4 +- .../async/websocket/websocket_test.go | 38 +- notifications/mobile/apns/apns_sender_test.go | 99 +++--- notifications/mobile/config/config_test.go | 74 ++-- notifications/mobile/config/do_test.go | 12 +- notifications/mobile/fcm/fcm_sender_test.go | 67 ++-- .../mobile/multi_platform_push_sender_test.go | 37 +- numbers/numbers_test.go | 74 ++-- observability/config_test.go | 18 +- observability/do_test.go | 20 +- observability/errors_test.go | 32 +- observability/helpers_test.go | 12 +- observability/logging/config/config_test.go | 34 +- observability/logging/config/do_test.go | 8 +- observability/logging/logging_test.go | 76 ++-- .../logging/otelgrpc/slog_logger_test.go | 86 ++--- .../logging/slog/slog_logger_test.go | 36 +- observability/logging/zap/zap_logger_test.go | 46 +-- .../logging/zerolog/zerolog_logger_test.go | 40 +-- observability/metrics/config/config_test.go | 30 +- observability/metrics/config/do_test.go | 8 +- observability/metrics/metrics_test.go | 62 ++-- .../metrics/otelgrpc/config/config_test.go | 18 +- .../metrics/otelgrpc/config/do_test.go | 8 +- .../metrics/otelgrpc/metrics_test.go | 139 ++++---- observability/metrics/testing.go | 4 +- observability/profiling/config/config_test.go | 58 +-- observability/profiling/config/do_test.go | 8 +- observability/profiling/noop_test.go | 8 +- observability/profiling/pprof/config_test.go | 6 +- .../profiling/pprof/provider_test.go | 26 +- .../profiling/pyroscope/config_test.go | 8 +- .../profiling/pyroscope/provider_test.go | 36 +- observability/tracing/caller_test.go | 4 +- .../tracing/cloudtrace/config_test.go | 4 +- observability/tracing/config/config_test.go | 50 +-- observability/tracing/config/do_test.go | 8 +- observability/tracing/instrumentedsql_test.go | 6 +- .../tracing/oteltrace/config_test.go | 4 +- .../tracing/oteltrace/tracer_test.go | 6 +- observability/tracing/span_attachers_test.go | 4 +- observability/tracing/span_manager_test.go | 8 +- observability/tracing/spans_test.go | 4 +- observability/utils/otel_test.go | 6 +- panicking/panicker_test.go | 8 +- pointer/pointers_test.go | 76 ++-- qrcodes/do_test.go | 8 +- qrcodes/qrcodes_test.go | 22 +- random/do_test.go | 8 +- random/random_test.go | 58 +-- random/slices_test.go | 4 +- ratelimiting/config/config_test.go | 78 ++-- ratelimiting/config/do_test.go | 8 +- ratelimiting/noop/noop_test.go | 10 +- ratelimiting/ratelimiting_test.go | 44 +-- ratelimiting/redis/redis_test.go | 63 ++-- reflection/ast/helpers_test.go | 60 ++-- reflection/utils_test.go | 126 +++---- retry/config_test.go | 28 +- retry/noop/noop_test.go | 14 +- retry/retry_test.go | 22 +- routing/chi/config_test.go | 4 +- routing/chi/middleware_test.go | 4 +- routing/chi/request_id_test.go | 8 +- routing/chi/routeparams_test.go | 10 +- routing/chi/router_test.go | 48 +-- routing/config/config_test.go | 40 +-- routing/config/do_test.go | 8 +- search/text/algolia/algolia_test.go | 10 +- search/text/algolia/config_test.go | 14 +- search/text/algolia/index_test.go | 84 ++--- search/text/config/config_test.go | 63 ++-- search/text/elasticsearch/config_test.go | 22 +- .../text/elasticsearch/elasticsearch_test.go | 136 +++---- search/text/elasticsearch/index_test.go | 68 ++-- search/text/indexing/do_test.go | 7 +- search/text/indexing/indexer_test.go | 71 ++-- search/text/noop/noop_test.go | 16 +- search/vector/config/config_test.go | 47 ++- search/vector/noop/noop_test.go | 16 +- search/vector/pgvector/pgvector_test.go | 196 +++++----- search/vector/qdrant/qdrant_test.go | 312 ++++++++-------- secrets/config/config_test.go | 127 ++++--- secrets/config/do_test.go | 8 +- secrets/config/wire_test.go | 46 +-- secrets/env/env_test.go | 33 +- secrets/gcp/config_test.go | 6 +- secrets/gcp/gcp_test.go | 55 ++- secrets/kubectl/config_test.go | 12 +- secrets/kubectl/kubectl_test.go | 77 ++-- secrets/noop/noop_test.go | 10 +- secrets/ssm/config_test.go | 6 +- secrets/ssm/ssm_test.go | 63 ++-- server/grpc/do_test.go | 8 +- server/grpc/server_test.go | 90 ++--- server/http/config_test.go | 10 +- server/http/do_test.go | 8 +- server/http/http_server_test.go | 84 ++--- server/http/static_files_test.go | 22 +- testutils/testutil.go | 8 +- testutils/testutil_test.go | 30 +- types/main_test.go | 40 +-- uploads/config/config_test.go | 6 +- uploads/config/do_test.go | 10 +- uploads/images/images_test.go | 128 +++---- uploads/images/thumbnails_test.go | 68 ++-- .../objectstorage/bucket_backblaze_test.go | 12 +- .../objectstorage/bucket_filesystem_test.go | 6 +- uploads/objectstorage/bucket_gcp_test.go | 6 +- uploads/objectstorage/bucket_r2_test.go | 12 +- uploads/objectstorage/bucket_s3_test.go | 6 +- uploads/objectstorage/do_test.go | 12 +- uploads/objectstorage/files_test.go | 48 +-- uploads/objectstorage/providers_test.go | 6 +- uploads/objectstorage/uploader_test.go | 82 ++--- version/version_test.go | 18 +- 247 files changed, 4962 insertions(+), 4953 deletions(-) diff --git a/analytics/config/config_test.go b/analytics/config/config_test.go index fe45a89..d86b22b 100644 --- a/analytics/config/config_test.go +++ b/analytics/config/config_test.go @@ -14,8 +14,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -33,7 +32,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { }, } - require.NoError(t, cfg.ValidateWithContext(ctx)) + must.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with invalid token", func(t *testing.T) { @@ -46,7 +45,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { }, } - require.Error(t, cfg.ValidateWithContext(ctx)) + must.Error(t, cfg.ValidateWithContext(ctx)) }) } @@ -76,7 +75,7 @@ func TestConfig_ProvideCollector(T *testing.T) { } _, err := cfg.ProvideCollector(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - require.NoError(t, err) + must.NoError(t, err) } }) @@ -96,7 +95,7 @@ func TestConfig_ProvideCollector(T *testing.T) { } _, err := cfg.ProvideCollector(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - require.Error(t, err) + must.Error(t, err) } }) @@ -111,8 +110,8 @@ func TestConfig_ProvideCollector(T *testing.T) { } reporter, err := cfg.ProvideCollector(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - assert.Nil(t, reporter) - assert.Error(t, err) + test.Nil(t, reporter) + test.Error(t, err) }) T.Run("with rudderstack provider but nil rudderstack config", func(t *testing.T) { @@ -126,8 +125,8 @@ func TestConfig_ProvideCollector(T *testing.T) { } reporter, err := cfg.ProvideCollector(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - assert.Nil(t, reporter) - assert.Error(t, err) + test.Nil(t, reporter) + test.Error(t, err) }) T.Run("with posthog provider but nil posthog config", func(t *testing.T) { @@ -141,8 +140,8 @@ func TestConfig_ProvideCollector(T *testing.T) { } reporter, err := cfg.ProvideCollector(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - assert.Nil(t, reporter) - assert.Error(t, err) + test.Nil(t, reporter) + test.Error(t, err) }) T.Run("with unrecognized provider returns noop", func(t *testing.T) { @@ -156,8 +155,8 @@ func TestConfig_ProvideCollector(T *testing.T) { } reporter, err := cfg.ProvideCollector(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - assert.NotNil(t, reporter) - assert.NoError(t, err) + test.NotNil(t, reporter) + test.NoError(t, err) }) T.Run("with circuit breaker error", func(t *testing.T) { @@ -184,8 +183,8 @@ func TestConfig_ProvideCollector(T *testing.T) { } reporter, err := cfg.ProvideCollector(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp) - assert.Nil(t, reporter) - assert.Error(t, err) + test.Nil(t, reporter) + test.Error(t, err) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -200,9 +199,9 @@ func TestSourceConfig_EnsureDefaults(T *testing.T) { cfg := &SourceConfig{} cfg.EnsureDefaults() - assert.NotEmpty(t, cfg.CircuitBreaker.Name) - assert.NotZero(t, cfg.CircuitBreaker.ErrorRate) - assert.NotZero(t, cfg.CircuitBreaker.MinimumSampleThreshold) + test.NotEq(t, "", cfg.CircuitBreaker.Name) + test.NotEq(t, float64(0), cfg.CircuitBreaker.ErrorRate) + test.NotEq(t, uint64(0), cfg.CircuitBreaker.MinimumSampleThreshold) }) } @@ -215,7 +214,7 @@ func TestConfig_EnsureDefaults(T *testing.T) { cfg := &Config{} cfg.EnsureDefaults() - assert.NotEmpty(t, cfg.CircuitBreaker.Name) + test.NotEq(t, "", cfg.CircuitBreaker.Name) }) T.Run("with both proxy sources set", func(t *testing.T) { @@ -229,9 +228,9 @@ func TestConfig_EnsureDefaults(T *testing.T) { } cfg.EnsureDefaults() - assert.NotEmpty(t, cfg.CircuitBreaker.Name) - assert.NotEmpty(t, cfg.ProxySources.IOS.CircuitBreaker.Name) - assert.NotEmpty(t, cfg.ProxySources.Web.CircuitBreaker.Name) + test.NotEq(t, "", cfg.CircuitBreaker.Name) + test.NotEq(t, "", cfg.ProxySources.IOS.CircuitBreaker.Name) + test.NotEq(t, "", cfg.ProxySources.Web.CircuitBreaker.Name) }) } @@ -242,7 +241,7 @@ func TestProxySourcesConfig_ToMap(T *testing.T) { t.Parallel() p := ProxySourcesConfig{} - assert.Empty(t, p.ToMap()) + test.MapEmpty(t, p.ToMap()) }) T.Run("with only ios set", func(t *testing.T) { @@ -252,8 +251,8 @@ func TestProxySourcesConfig_ToMap(T *testing.T) { p := ProxySourcesConfig{IOS: ios} m := p.ToMap() - assert.Len(t, m, 1) - assert.Same(t, ios, m["ios"]) + test.MapLen(t, 1, m) + test.EqOp(t, ios, m["ios"]) }) T.Run("with only web set", func(t *testing.T) { @@ -263,8 +262,8 @@ func TestProxySourcesConfig_ToMap(T *testing.T) { p := ProxySourcesConfig{Web: web} m := p.ToMap() - assert.Len(t, m, 1) - assert.Same(t, web, m["web"]) + test.MapLen(t, 1, m) + test.EqOp(t, web, m["web"]) }) T.Run("with both sources set", func(t *testing.T) { @@ -275,8 +274,8 @@ func TestProxySourcesConfig_ToMap(T *testing.T) { p := ProxySourcesConfig{IOS: ios, Web: web} m := p.ToMap() - assert.Len(t, m, 2) - assert.Same(t, ios, m["ios"]) - assert.Same(t, web, m["web"]) + test.MapLen(t, 2, m) + test.EqOp(t, ios, m["ios"]) + test.EqOp(t, web, m["web"]) }) } diff --git a/analytics/config/do_test.go b/analytics/config/do_test.go index 88d6dc7..37a95ce 100644 --- a/analytics/config/do_test.go +++ b/analytics/config/do_test.go @@ -10,8 +10,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterEventReporter(T *testing.T) { @@ -35,7 +35,7 @@ func TestRegisterEventReporter(T *testing.T) { RegisterEventReporter(i) reporter, err := do.Invoke[analytics.EventReporter](i) - require.NoError(t, err) - assert.NotNil(t, reporter) + must.NoError(t, err) + test.NotNil(t, reporter) }) } diff --git a/analytics/config/wire_test.go b/analytics/config/wire_test.go index 96ff1c0..024b4a2 100644 --- a/analytics/config/wire_test.go +++ b/analytics/config/wire_test.go @@ -8,7 +8,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" ) func TestProvideCollector(T *testing.T) { @@ -22,8 +22,8 @@ func TestProvideCollector(T *testing.T) { logger := logging.NewNoopLogger() actual, err := ProvideEventReporter(ctx, cfg, logger, tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - require.NoError(t, err) - require.NotNil(t, actual) + must.NoError(t, err) + must.NotNil(t, actual) }) T.Run("with segment", func(t *testing.T) { @@ -41,7 +41,7 @@ func TestProvideCollector(T *testing.T) { logger := logging.NewNoopLogger() actual, err := ProvideEventReporter(ctx, cfg, logger, tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - require.NoError(t, err) - require.NotNil(t, actual) + must.NoError(t, err) + must.NotNil(t, actual) }) } diff --git a/analytics/multisource/config_test.go b/analytics/multisource/config_test.go index 4561f1b..9f3003a 100644 --- a/analytics/multisource/config_test.go +++ b/analytics/multisource/config_test.go @@ -10,8 +10,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestProvideMultiSourceEventReporter(T *testing.T) { @@ -23,9 +23,9 @@ func TestProvideMultiSourceEventReporter(T *testing.T) { ctx := t.Context() reporter, err := ProvideMultiSourceEventReporter(ctx, nil, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - require.NoError(t, err) - require.NotNil(t, reporter) - assert.Empty(t, reporter.reporters) + must.NoError(t, err) + must.NotNil(t, reporter) + test.MapEmpty(t, reporter.reporters) }) T.Run("with valid segment source", func(t *testing.T) { @@ -40,9 +40,9 @@ func TestProvideMultiSourceEventReporter(T *testing.T) { } reporter, err := ProvideMultiSourceEventReporter(ctx, sources, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - require.NoError(t, err) - require.NotNil(t, reporter) - assert.Len(t, reporter.reporters, 1) + must.NoError(t, err) + must.NotNil(t, reporter) + test.MapLen(t, 1, reporter.reporters) }) T.Run("with invalid source falls back to noop", func(t *testing.T) { @@ -57,9 +57,9 @@ func TestProvideMultiSourceEventReporter(T *testing.T) { } reporter, err := ProvideMultiSourceEventReporter(ctx, sources, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - require.NoError(t, err) - require.NotNil(t, reporter) - assert.Len(t, reporter.reporters, 1) + must.NoError(t, err) + must.NotNil(t, reporter) + test.MapLen(t, 1, reporter.reporters) }) T.Run("with unrecognized provider uses noop", func(t *testing.T) { @@ -73,9 +73,9 @@ func TestProvideMultiSourceEventReporter(T *testing.T) { } reporter, err := ProvideMultiSourceEventReporter(ctx, sources, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - require.NoError(t, err) - require.NotNil(t, reporter) - assert.Len(t, reporter.reporters, 1) + must.NoError(t, err) + must.NotNil(t, reporter) + test.MapLen(t, 1, reporter.reporters) }) T.Run("with multiple posthog sources reuses shared reporter", func(t *testing.T) { @@ -94,9 +94,9 @@ func TestProvideMultiSourceEventReporter(T *testing.T) { } reporter, err := ProvideMultiSourceEventReporter(ctx, sources, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - require.NoError(t, err) - require.NotNil(t, reporter) - assert.Len(t, reporter.reporters, 2) + must.NoError(t, err) + must.NotNil(t, reporter) + test.MapLen(t, 2, reporter.reporters) }) T.Run("with empty proxy sources map", func(t *testing.T) { @@ -106,8 +106,8 @@ func TestProvideMultiSourceEventReporter(T *testing.T) { sources := map[string]*analyticscfg.SourceConfig{} reporter, err := ProvideMultiSourceEventReporter(ctx, sources, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - require.NoError(t, err) - require.NotNil(t, reporter) - assert.Empty(t, reporter.reporters) + must.NoError(t, err) + must.NotNil(t, reporter) + test.MapEmpty(t, reporter.reporters) }) } diff --git a/analytics/multisource/do_test.go b/analytics/multisource/do_test.go index 30f2a0a..0dffa03 100644 --- a/analytics/multisource/do_test.go +++ b/analytics/multisource/do_test.go @@ -10,8 +10,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterMultiSourceEventReporter(T *testing.T) { @@ -35,7 +35,7 @@ func TestRegisterMultiSourceEventReporter(T *testing.T) { RegisterMultiSourceEventReporter(i) reporter, err := do.Invoke[*MultiSourceEventReporter](i) - require.NoError(t, err) - assert.NotNil(t, reporter) + must.NoError(t, err) + test.NotNil(t, reporter) }) } diff --git a/analytics/multisource/reporter_test.go b/analytics/multisource/reporter_test.go index 28acfe2..d2f4ca3 100644 --- a/analytics/multisource/reporter_test.go +++ b/analytics/multisource/reporter_test.go @@ -9,8 +9,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/analytics/noop" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" ) func TestNewMultiSourceEventReporter(T *testing.T) { @@ -20,8 +19,8 @@ func TestNewMultiSourceEventReporter(T *testing.T) { t.Parallel() r := NewMultiSourceEventReporter(nil, nil, nil) - require.NotNil(t, r) - assert.NotNil(t, r.reporters) + must.NotNil(t, r) + test.NotNil(t, r.reporters) }) T.Run("with populated reporters map", func(t *testing.T) { @@ -31,8 +30,8 @@ func TestNewMultiSourceEventReporter(T *testing.T) { "ios": noop.NewEventReporter(), } r := NewMultiSourceEventReporter(reporters, nil, nil) - require.NotNil(t, r) - assert.Len(t, r.reporters, 1) + must.NotNil(t, r) + test.MapLen(t, 1, r.reporters) }) } @@ -49,7 +48,7 @@ func TestMultiSourceEventReporter_getReporter(T *testing.T) { m := NewMultiSourceEventReporter(reporters, nil, nil) got := m.getReporter("ios") - assert.Equal(t, expected, got) + test.Eq(t, expected, got) }) T.Run("returns noop for unknown source", func(t *testing.T) { @@ -58,7 +57,7 @@ func TestMultiSourceEventReporter_getReporter(T *testing.T) { m := NewMultiSourceEventReporter(nil, nil, nil) got := m.getReporter("unknown") - assert.NotNil(t, got) + test.NotNil(t, got) }) T.Run("returns noop when reporter is nil in map", func(t *testing.T) { @@ -70,7 +69,7 @@ func TestMultiSourceEventReporter_getReporter(T *testing.T) { m := NewMultiSourceEventReporter(reporters, nil, nil) got := m.getReporter("ios") - assert.NotNil(t, got) + test.NotNil(t, got) }) } @@ -82,10 +81,10 @@ func TestMultiSourceEventReporter_TrackEvent(T *testing.T) { mockReporter := &analyticsmock.EventReporterMock{ EventOccurredFunc: func(_ context.Context, event, userID string, properties map[string]any) error { - assert.Equal(t, "signup", event) - assert.Equal(t, "user1", userID) - assert.Equal(t, "ios", properties[SourcePropertyKey]) - assert.Equal(t, "pro", properties["plan"]) + test.EqOp(t, "signup", event) + test.EqOp(t, "user1", userID) + test.Eq(t, "ios", properties[SourcePropertyKey]) + test.Eq(t, "pro", properties["plan"]) return nil }, } @@ -96,7 +95,7 @@ func TestMultiSourceEventReporter_TrackEvent(T *testing.T) { m := NewMultiSourceEventReporter(reporters, nil, nil) err := m.TrackEvent(context.Background(), "ios", "signup", "user1", map[string]any{"plan": "pro"}) - assert.NoError(t, err) + test.NoError(t, err) test.SliceLen(t, 1, mockReporter.EventOccurredCalls()) }) @@ -107,7 +106,7 @@ func TestMultiSourceEventReporter_TrackEvent(T *testing.T) { m := NewMultiSourceEventReporter(nil, nil, nil) err := m.TrackEvent(context.Background(), "unknown", "signup", "user1", nil) - assert.NoError(t, err) + test.NoError(t, err) }) } @@ -119,9 +118,9 @@ func TestMultiSourceEventReporter_TrackAnonymousEvent(T *testing.T) { mockReporter := &analyticsmock.EventReporterMock{ EventOccurredAnonymousFunc: func(_ context.Context, event, anonymousID string, properties map[string]any) error { - assert.Equal(t, "page_view", event) - assert.Equal(t, "anon1", anonymousID) - assert.Equal(t, "web", properties[SourcePropertyKey]) + test.EqOp(t, "page_view", event) + test.EqOp(t, "anon1", anonymousID) + test.Eq(t, "web", properties[SourcePropertyKey]) return nil }, } @@ -132,7 +131,7 @@ func TestMultiSourceEventReporter_TrackAnonymousEvent(T *testing.T) { m := NewMultiSourceEventReporter(reporters, nil, nil) err := m.TrackAnonymousEvent(context.Background(), "web", "page_view", "anon1", map[string]any{}) - assert.NoError(t, err) + test.NoError(t, err) test.SliceLen(t, 1, mockReporter.EventOccurredAnonymousCalls()) }) @@ -145,8 +144,8 @@ func Test_withSourceProperty(T *testing.T) { t.Parallel() result := withSourceProperty("ios", nil) - assert.Equal(t, "ios", result[SourcePropertyKey]) - assert.Len(t, result, 1) + test.Eq(t, "ios", result[SourcePropertyKey]) + test.MapLen(t, 1, result) }) T.Run("adds source to existing properties without mutation", func(t *testing.T) { @@ -155,12 +154,12 @@ func Test_withSourceProperty(T *testing.T) { original := map[string]any{"key": "value"} result := withSourceProperty("web", original) - assert.Equal(t, "web", result[SourcePropertyKey]) - assert.Equal(t, "value", result["key"]) - assert.Len(t, result, 2) + test.Eq(t, "web", result[SourcePropertyKey]) + test.Eq(t, "value", result["key"]) + test.MapLen(t, 2, result) // original should not be mutated _, exists := original[SourcePropertyKey] - assert.False(t, exists) + test.False(t, exists) }) } diff --git a/analytics/noop/noop_test.go b/analytics/noop/noop_test.go index 912c6ae..e8ef57a 100644 --- a/analytics/noop/noop_test.go +++ b/analytics/noop/noop_test.go @@ -4,8 +4,8 @@ import ( "context" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestNewEventReporter(T *testing.T) { @@ -15,7 +15,7 @@ func TestNewEventReporter(T *testing.T) { t.Parallel() r := NewEventReporter() - require.NotNil(t, r) + must.NotNil(t, r) }) } @@ -26,7 +26,7 @@ func TestEventReporter_Close(T *testing.T) { t.Parallel() r := NewEventReporter() - assert.NotPanics(t, func() { + test.NotPanic(t, func() { r.Close() }) }) @@ -40,7 +40,7 @@ func TestEventReporter_AddUser(T *testing.T) { r := NewEventReporter() err := r.AddUser(context.Background(), "user123", map[string]any{"key": "value"}) - assert.NoError(t, err) + test.NoError(t, err) }) } @@ -52,7 +52,7 @@ func TestEventReporter_EventOccurred(T *testing.T) { r := NewEventReporter() err := r.EventOccurred(context.Background(), "event_name", "user123", map[string]any{"key": "value"}) - assert.NoError(t, err) + test.NoError(t, err) }) } @@ -64,6 +64,6 @@ func TestEventReporter_EventOccurredAnonymous(T *testing.T) { r := NewEventReporter() err := r.EventOccurredAnonymous(context.Background(), "event_name", "anon123", map[string]any{"key": "value"}) - assert.NoError(t, err) + test.NoError(t, err) }) } diff --git a/analytics/posthog/config_test.go b/analytics/posthog/config_test.go index 1cd3f5c..2ccdc90 100644 --- a/analytics/posthog/config_test.go +++ b/analytics/posthog/config_test.go @@ -3,7 +3,7 @@ package posthog import ( "testing" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -14,7 +14,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { cfg := &Config{APIKey: t.Name()} - require.NoError(t, cfg.ValidateWithContext(t.Context())) + must.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("with empty API key", func(t *testing.T) { @@ -22,6 +22,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { cfg := &Config{} - require.Error(t, cfg.ValidateWithContext(t.Context())) + must.Error(t, cfg.ValidateWithContext(t.Context())) }) } diff --git a/analytics/posthog/posthog_test.go b/analytics/posthog/posthog_test.go index 85c28f8..4726efb 100644 --- a/analytics/posthog/posthog_test.go +++ b/analytics/posthog/posthog_test.go @@ -12,7 +12,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/shoenig/test" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -26,8 +26,8 @@ func TestNewPostHogEventReporter(T *testing.T) { cfg := &Config{APIKey: t.Name()} collector, err := NewPostHogEventReporter(logger, tracing.NewNoopTracerProvider(), nil, cfg.APIKey, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, collector) + must.NoError(t, err) + must.NotNil(t, collector) }) T.Run("with empty API key", func(t *testing.T) { @@ -37,8 +37,8 @@ func TestNewPostHogEventReporter(T *testing.T) { cfg := &Config{} collector, err := NewPostHogEventReporter(logger, tracing.NewNoopTracerProvider(), nil, cfg.APIKey, cbnoop.NewCircuitBreaker()) - require.Error(t, err) - require.Nil(t, collector) + must.Error(t, err) + must.Nil(t, collector) }) T.Run("with error creating event counter", func(t *testing.T) { @@ -52,8 +52,8 @@ func TestNewPostHogEventReporter(T *testing.T) { } collector, err := NewPostHogEventReporter(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, t.Name(), cbnoop.NewCircuitBreaker()) - require.Error(t, err) - require.Nil(t, collector) + must.Error(t, err) + must.Nil(t, collector) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -75,8 +75,8 @@ func TestNewPostHogEventReporter(T *testing.T) { } collector, err := NewPostHogEventReporter(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, t.Name(), cbnoop.NewCircuitBreaker()) - require.Error(t, err) - require.Nil(t, collector) + must.Error(t, err) + must.Nil(t, collector) test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) @@ -92,8 +92,8 @@ func TestPostHogEventReporter_Close(T *testing.T) { cfg := &Config{APIKey: t.Name()} collector, err := NewPostHogEventReporter(logger, tracing.NewNoopTracerProvider(), nil, cfg.APIKey, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, collector) + must.NoError(t, err) + must.NotNil(t, collector) collector.Close() }) @@ -114,10 +114,10 @@ func TestPostHogEventReporter_AddUser(T *testing.T) { } collector, err := NewPostHogEventReporter(logger, tracing.NewNoopTracerProvider(), nil, cfg.APIKey, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, collector) + must.NoError(t, err) + must.NotNil(t, collector) - require.NoError(t, collector.AddUser(ctx, exampleUserID, properties)) + must.NoError(t, collector.AddUser(ctx, exampleUserID, properties)) }) } @@ -136,10 +136,10 @@ func TestPostHogEventReporter_EventOccurred(T *testing.T) { } collector, err := NewPostHogEventReporter(logger, tracing.NewNoopTracerProvider(), nil, cfg.APIKey, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, collector) + must.NoError(t, err) + must.NotNil(t, collector) - require.NoError(t, collector.EventOccurred(ctx, t.Name(), exampleUserID, properties)) + must.NoError(t, collector.EventOccurred(ctx, t.Name(), exampleUserID, properties)) }) } @@ -158,9 +158,9 @@ func TestPostHogEventReporter_EventOccurredAnonymous(T *testing.T) { } collector, err := NewPostHogEventReporter(logger, tracing.NewNoopTracerProvider(), nil, cfg.APIKey, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, collector) + must.NoError(t, err) + must.NotNil(t, collector) - require.NoError(t, collector.EventOccurredAnonymous(ctx, t.Name(), exampleAnonymousID, properties)) + must.NoError(t, collector.EventOccurredAnonymous(ctx, t.Name(), exampleAnonymousID, properties)) }) } diff --git a/analytics/rudderstack/config_test.go b/analytics/rudderstack/config_test.go index 01807a8..53f0687 100644 --- a/analytics/rudderstack/config_test.go +++ b/analytics/rudderstack/config_test.go @@ -3,7 +3,7 @@ package rudderstack import ( "testing" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -17,7 +17,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { DataPlaneURL: t.Name(), } - require.NoError(t, cfg.ValidateWithContext(t.Context())) + must.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("with empty API key", func(t *testing.T) { @@ -27,7 +27,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { DataPlaneURL: t.Name(), } - require.Error(t, cfg.ValidateWithContext(t.Context())) + must.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("with empty data plane URL", func(t *testing.T) { @@ -37,6 +37,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { APIKey: t.Name(), } - require.Error(t, cfg.ValidateWithContext(t.Context())) + must.Error(t, cfg.ValidateWithContext(t.Context())) }) } diff --git a/analytics/rudderstack/rudderstack_test.go b/analytics/rudderstack/rudderstack_test.go index 5436fc1..096b92f 100644 --- a/analytics/rudderstack/rudderstack_test.go +++ b/analytics/rudderstack/rudderstack_test.go @@ -12,7 +12,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/shoenig/test" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -29,8 +29,8 @@ func TestNewRudderstackEventReporter(T *testing.T) { } collector, err := NewRudderstackEventReporter(logger, tracing.NewNoopTracerProvider(), nil, cfg, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, collector) + must.NoError(t, err) + must.NotNil(t, collector) }) T.Run("with nil config", func(t *testing.T) { @@ -39,8 +39,8 @@ func TestNewRudderstackEventReporter(T *testing.T) { logger := logging.NewNoopLogger() collector, err := NewRudderstackEventReporter(logger, tracing.NewNoopTracerProvider(), nil, nil, cbnoop.NewCircuitBreaker()) - require.Error(t, err) - require.Nil(t, collector) + must.Error(t, err) + must.Nil(t, collector) }) T.Run("with empty API key", func(t *testing.T) { @@ -53,8 +53,8 @@ func TestNewRudderstackEventReporter(T *testing.T) { } collector, err := NewRudderstackEventReporter(logger, tracing.NewNoopTracerProvider(), nil, cfg, cbnoop.NewCircuitBreaker()) - require.Error(t, err) - require.Nil(t, collector) + must.Error(t, err) + must.Nil(t, collector) }) T.Run("with empty DataPlane URL", func(t *testing.T) { @@ -67,8 +67,8 @@ func TestNewRudderstackEventReporter(T *testing.T) { } collector, err := NewRudderstackEventReporter(logger, tracing.NewNoopTracerProvider(), nil, cfg, cbnoop.NewCircuitBreaker()) - require.Error(t, err) - require.Nil(t, collector) + must.Error(t, err) + must.Nil(t, collector) }) T.Run("with error creating event counter", func(t *testing.T) { @@ -87,8 +87,8 @@ func TestNewRudderstackEventReporter(T *testing.T) { } collector, err := NewRudderstackEventReporter(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, cfg, cbnoop.NewCircuitBreaker()) - require.Error(t, err) - require.Nil(t, collector) + must.Error(t, err) + must.Nil(t, collector) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -115,8 +115,8 @@ func TestNewRudderstackEventReporter(T *testing.T) { } collector, err := NewRudderstackEventReporter(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, cfg, cbnoop.NewCircuitBreaker()) - require.Error(t, err) - require.Nil(t, collector) + must.Error(t, err) + must.Nil(t, collector) test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) @@ -135,8 +135,8 @@ func TestRudderstackEventReporter_Close(T *testing.T) { } collector, err := NewRudderstackEventReporter(logger, tracing.NewNoopTracerProvider(), nil, cfg, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, collector) + must.NoError(t, err) + must.NotNil(t, collector) collector.Close() }) @@ -161,10 +161,10 @@ func TestRudderstackEventReporter_AddUser(T *testing.T) { } collector, err := NewRudderstackEventReporter(logger, tracing.NewNoopTracerProvider(), nil, cfg, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, collector) + must.NoError(t, err) + must.NotNil(t, collector) - require.NoError(t, collector.AddUser(ctx, exampleUserID, properties)) + must.NoError(t, collector.AddUser(ctx, exampleUserID, properties)) }) } @@ -187,10 +187,10 @@ func TestRudderstackEventReporter_EventOccurred(T *testing.T) { } collector, err := NewRudderstackEventReporter(logger, tracing.NewNoopTracerProvider(), nil, cfg, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, collector) + must.NoError(t, err) + must.NotNil(t, collector) - require.NoError(t, collector.EventOccurred(ctx, t.Name(), exampleUserID, properties)) + must.NoError(t, collector.EventOccurred(ctx, t.Name(), exampleUserID, properties)) }) } @@ -213,9 +213,9 @@ func TestRudderstackEventReporter_EventOccurredAnonymous(T *testing.T) { } collector, err := NewRudderstackEventReporter(logger, tracing.NewNoopTracerProvider(), nil, cfg, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, collector) + must.NoError(t, err) + must.NotNil(t, collector) - require.NoError(t, collector.EventOccurredAnonymous(ctx, t.Name(), exampleAnonymousID, properties)) + must.NoError(t, collector.EventOccurredAnonymous(ctx, t.Name(), exampleAnonymousID, properties)) }) } diff --git a/analytics/segment/config_test.go b/analytics/segment/config_test.go index 993c7fa..242e718 100644 --- a/analytics/segment/config_test.go +++ b/analytics/segment/config_test.go @@ -3,7 +3,7 @@ package segment import ( "testing" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -14,7 +14,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { cfg := &Config{APIToken: t.Name()} - require.NoError(t, cfg.ValidateWithContext(t.Context())) + must.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("with empty API token", func(t *testing.T) { @@ -22,6 +22,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { cfg := &Config{} - require.Error(t, cfg.ValidateWithContext(t.Context())) + must.Error(t, cfg.ValidateWithContext(t.Context())) }) } diff --git a/analytics/segment/segment_test.go b/analytics/segment/segment_test.go index b03d21a..56cf2c1 100644 --- a/analytics/segment/segment_test.go +++ b/analytics/segment/segment_test.go @@ -12,7 +12,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/shoenig/test" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -25,8 +25,8 @@ func TestNewSegmentEventReporter(T *testing.T) { logger := logging.NewNoopLogger() collector, err := NewSegmentEventReporter(logger, tracing.NewNoopTracerProvider(), nil, t.Name(), cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, collector) + must.NoError(t, err) + must.NotNil(t, collector) }) T.Run("with empty API key", func(t *testing.T) { @@ -35,8 +35,8 @@ func TestNewSegmentEventReporter(T *testing.T) { logger := logging.NewNoopLogger() collector, err := NewSegmentEventReporter(logger, tracing.NewNoopTracerProvider(), nil, "", cbnoop.NewCircuitBreaker()) - require.Error(t, err) - require.Nil(t, collector) + must.Error(t, err) + must.Nil(t, collector) }) T.Run("with error creating event counter", func(t *testing.T) { @@ -50,8 +50,8 @@ func TestNewSegmentEventReporter(T *testing.T) { } collector, err := NewSegmentEventReporter(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, t.Name(), cbnoop.NewCircuitBreaker()) - require.Error(t, err) - require.Nil(t, collector) + must.Error(t, err) + must.Nil(t, collector) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -73,8 +73,8 @@ func TestNewSegmentEventReporter(T *testing.T) { } collector, err := NewSegmentEventReporter(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, t.Name(), cbnoop.NewCircuitBreaker()) - require.Error(t, err) - require.Nil(t, collector) + must.Error(t, err) + must.Nil(t, collector) test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) @@ -89,8 +89,8 @@ func TestSegmentEventReporter_Close(T *testing.T) { logger := logging.NewNoopLogger() collector, err := NewSegmentEventReporter(logger, tracing.NewNoopTracerProvider(), nil, t.Name(), cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, collector) + must.NoError(t, err) + must.NotNil(t, collector) collector.Close() }) @@ -110,10 +110,10 @@ func TestSegmentEventReporter_AddUser(T *testing.T) { } collector, err := NewSegmentEventReporter(logger, tracing.NewNoopTracerProvider(), nil, t.Name(), cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, collector) + must.NoError(t, err) + must.NotNil(t, collector) - require.NoError(t, collector.AddUser(ctx, exampleUserID, properties)) + must.NoError(t, collector.AddUser(ctx, exampleUserID, properties)) }) } @@ -131,10 +131,10 @@ func TestSegmentEventReporter_EventOccurred(T *testing.T) { } collector, err := NewSegmentEventReporter(logger, tracing.NewNoopTracerProvider(), nil, t.Name(), cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, collector) + must.NoError(t, err) + must.NotNil(t, collector) - require.NoError(t, collector.EventOccurred(ctx, t.Name(), exampleUserID, properties)) + must.NoError(t, collector.EventOccurred(ctx, t.Name(), exampleUserID, properties)) }) } @@ -152,9 +152,9 @@ func TestSegmentEventReporter_EventOccurredAnonymous(T *testing.T) { } collector, err := NewSegmentEventReporter(logger, tracing.NewNoopTracerProvider(), nil, t.Name(), cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, collector) + must.NoError(t, err) + must.NotNil(t, collector) - require.NoError(t, collector.EventOccurredAnonymous(ctx, t.Name(), exampleAnonymousID, properties)) + must.NoError(t, collector.EventOccurredAnonymous(ctx, t.Name(), exampleAnonymousID, properties)) }) } diff --git a/bitmask/bitmask_test.go b/bitmask/bitmask_test.go index 4245937..743c7b7 100644 --- a/bitmask/bitmask_test.go +++ b/bitmask/bitmask_test.go @@ -4,8 +4,8 @@ import ( "encoding/json" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) type testPerm uint8 @@ -25,8 +25,8 @@ func TestNew(T *testing.T) { mask := New[testPerm]() - assert.Equal(t, testPerm(0), mask.Value()) - assert.True(t, mask.IsEmpty()) + test.EqOp(t, testPerm(0), mask.Value()) + test.True(t, mask.IsEmpty()) }) T.Run("with single flag", func(t *testing.T) { @@ -34,7 +34,7 @@ func TestNew(T *testing.T) { mask := New(permRead) - assert.Equal(t, permRead, mask.Value()) + test.EqOp(t, permRead, mask.Value()) }) T.Run("with multiple flags", func(t *testing.T) { @@ -42,7 +42,7 @@ func TestNew(T *testing.T) { mask := New(permRead, permWrite) - assert.Equal(t, permRead|permWrite, mask.Value()) + test.EqOp(t, permRead|permWrite, mask.Value()) }) T.Run("with duplicate flags", func(t *testing.T) { @@ -50,7 +50,7 @@ func TestNew(T *testing.T) { mask := New(permRead, permRead) - assert.Equal(t, permRead, mask.Value()) + test.EqOp(t, permRead, mask.Value()) }) } @@ -62,7 +62,7 @@ func TestFromValue(T *testing.T) { mask := FromValue(testPerm(0)) - assert.True(t, mask.IsEmpty()) + test.True(t, mask.IsEmpty()) }) T.Run("with specific value", func(t *testing.T) { @@ -70,9 +70,9 @@ func TestFromValue(T *testing.T) { mask := FromValue(testPerm(5)) - assert.True(t, mask.Has(permRead)) - assert.True(t, mask.Has(permDelete)) - assert.False(t, mask.Has(permWrite)) + test.True(t, mask.Has(permRead)) + test.True(t, mask.Has(permDelete)) + test.False(t, mask.Has(permWrite)) }) } @@ -84,7 +84,7 @@ func TestBitmask_Value(T *testing.T) { mask := New(permRead, permDelete) - assert.Equal(t, permRead|permDelete, mask.Value()) + test.EqOp(t, permRead|permDelete, mask.Value()) }) } @@ -97,7 +97,7 @@ func TestBitmask_Set(T *testing.T) { base := New[testPerm]() mask := base.Set(permRead) - assert.True(t, mask.Has(permRead)) + test.True(t, mask.Has(permRead)) }) T.Run("set multiple flags", func(t *testing.T) { @@ -106,8 +106,8 @@ func TestBitmask_Set(T *testing.T) { base := New[testPerm]() mask := base.Set(permRead, permWrite) - assert.True(t, mask.Has(permRead)) - assert.True(t, mask.Has(permWrite)) + test.True(t, mask.Has(permRead)) + test.True(t, mask.Has(permWrite)) }) T.Run("set already set flag", func(t *testing.T) { @@ -116,7 +116,7 @@ func TestBitmask_Set(T *testing.T) { base := New(permRead) mask := base.Set(permRead) - assert.Equal(t, permRead, mask.Value()) + test.EqOp(t, permRead, mask.Value()) }) T.Run("does not mutate original", func(t *testing.T) { @@ -125,7 +125,7 @@ func TestBitmask_Set(T *testing.T) { original := New(permRead) _ = original.Set(permWrite) - assert.False(t, original.Has(permWrite)) + test.False(t, original.Has(permWrite)) }) } @@ -138,8 +138,8 @@ func TestBitmask_Clear(T *testing.T) { base := New(permRead, permWrite) mask := base.Clear(permWrite) - assert.True(t, mask.Has(permRead)) - assert.False(t, mask.Has(permWrite)) + test.True(t, mask.Has(permRead)) + test.False(t, mask.Has(permWrite)) }) T.Run("clear unset flag", func(t *testing.T) { @@ -148,7 +148,7 @@ func TestBitmask_Clear(T *testing.T) { base := New(permRead) mask := base.Clear(permWrite) - assert.Equal(t, permRead, mask.Value()) + test.EqOp(t, permRead, mask.Value()) }) T.Run("clear multiple flags", func(t *testing.T) { @@ -157,9 +157,9 @@ func TestBitmask_Clear(T *testing.T) { base := New(permRead, permWrite, permDelete) mask := base.Clear(permRead, permWrite) - assert.False(t, mask.Has(permRead)) - assert.False(t, mask.Has(permWrite)) - assert.True(t, mask.Has(permDelete)) + test.False(t, mask.Has(permRead)) + test.False(t, mask.Has(permWrite)) + test.True(t, mask.Has(permDelete)) }) T.Run("does not mutate original", func(t *testing.T) { @@ -168,7 +168,7 @@ func TestBitmask_Clear(T *testing.T) { original := New(permRead, permWrite) _ = original.Clear(permWrite) - assert.True(t, original.Has(permWrite)) + test.True(t, original.Has(permWrite)) }) } @@ -181,7 +181,7 @@ func TestBitmask_Toggle(T *testing.T) { base := New[testPerm]() mask := base.Toggle(permRead) - assert.True(t, mask.Has(permRead)) + test.True(t, mask.Has(permRead)) }) T.Run("toggle set flag clears it", func(t *testing.T) { @@ -190,7 +190,7 @@ func TestBitmask_Toggle(T *testing.T) { base := New(permRead) mask := base.Toggle(permRead) - assert.False(t, mask.Has(permRead)) + test.False(t, mask.Has(permRead)) }) T.Run("toggle multiple flags", func(t *testing.T) { @@ -199,8 +199,8 @@ func TestBitmask_Toggle(T *testing.T) { base := New(permRead) mask := base.Toggle(permRead, permWrite) - assert.False(t, mask.Has(permRead)) - assert.True(t, mask.Has(permWrite)) + test.False(t, mask.Has(permRead)) + test.True(t, mask.Has(permWrite)) }) T.Run("does not mutate original", func(t *testing.T) { @@ -209,7 +209,7 @@ func TestBitmask_Toggle(T *testing.T) { original := New(permRead) _ = original.Toggle(permRead) - assert.True(t, original.Has(permRead)) + test.True(t, original.Has(permRead)) }) } @@ -221,7 +221,7 @@ func TestBitmask_Has(T *testing.T) { mask := New(permRead, permWrite) - assert.True(t, mask.Has(permRead)) + test.True(t, mask.Has(permRead)) }) T.Run("returns false for unset flag", func(t *testing.T) { @@ -229,7 +229,7 @@ func TestBitmask_Has(T *testing.T) { mask := New(permRead) - assert.False(t, mask.Has(permWrite)) + test.False(t, mask.Has(permWrite)) }) T.Run("returns false for zero flag", func(t *testing.T) { @@ -237,7 +237,7 @@ func TestBitmask_Has(T *testing.T) { mask := New(permRead) - assert.False(t, mask.Has(0)) + test.False(t, mask.Has(0)) }) T.Run("with empty bitmask", func(t *testing.T) { @@ -245,7 +245,7 @@ func TestBitmask_Has(T *testing.T) { mask := New[testPerm]() - assert.False(t, mask.Has(permRead)) + test.False(t, mask.Has(permRead)) }) } @@ -257,7 +257,7 @@ func TestBitmask_HasAll(T *testing.T) { mask := New(permRead, permWrite, permDelete) - assert.True(t, mask.HasAll(permRead, permWrite)) + test.True(t, mask.HasAll(permRead, permWrite)) }) T.Run("returns false when one flag missing", func(t *testing.T) { @@ -265,7 +265,7 @@ func TestBitmask_HasAll(T *testing.T) { mask := New(permRead) - assert.False(t, mask.HasAll(permRead, permWrite)) + test.False(t, mask.HasAll(permRead, permWrite)) }) T.Run("returns false for empty flags", func(t *testing.T) { @@ -273,7 +273,7 @@ func TestBitmask_HasAll(T *testing.T) { mask := New(permRead) - assert.False(t, mask.HasAll()) + test.False(t, mask.HasAll()) }) T.Run("returns false for zero flag", func(t *testing.T) { @@ -281,7 +281,7 @@ func TestBitmask_HasAll(T *testing.T) { mask := New(permRead) - assert.False(t, mask.HasAll(0)) + test.False(t, mask.HasAll(0)) }) } @@ -293,7 +293,7 @@ func TestBitmask_HasAny(T *testing.T) { mask := New(permRead) - assert.True(t, mask.HasAny(permRead, permWrite)) + test.True(t, mask.HasAny(permRead, permWrite)) }) T.Run("returns false when no flags set", func(t *testing.T) { @@ -301,7 +301,7 @@ func TestBitmask_HasAny(T *testing.T) { mask := New(permRead) - assert.False(t, mask.HasAny(permWrite, permDelete)) + test.False(t, mask.HasAny(permWrite, permDelete)) }) T.Run("returns false for empty flags", func(t *testing.T) { @@ -309,7 +309,7 @@ func TestBitmask_HasAny(T *testing.T) { mask := New(permRead) - assert.False(t, mask.HasAny()) + test.False(t, mask.HasAny()) }) } @@ -321,7 +321,7 @@ func TestBitmask_IsEmpty(T *testing.T) { mask := New[testPerm]() - assert.True(t, mask.IsEmpty()) + test.True(t, mask.IsEmpty()) }) T.Run("returns false for non-empty bitmask", func(t *testing.T) { @@ -329,7 +329,7 @@ func TestBitmask_IsEmpty(T *testing.T) { mask := New(permRead) - assert.False(t, mask.IsEmpty()) + test.False(t, mask.IsEmpty()) }) T.Run("returns true for zero value", func(t *testing.T) { @@ -337,7 +337,7 @@ func TestBitmask_IsEmpty(T *testing.T) { var mask Bitmask[testPerm] - assert.True(t, mask.IsEmpty()) + test.True(t, mask.IsEmpty()) }) } @@ -349,7 +349,7 @@ func TestBitmask_Count(T *testing.T) { mask := New[testPerm]() - assert.Equal(t, 0, mask.Count()) + test.EqOp(t, 0, mask.Count()) }) T.Run("counts one bit", func(t *testing.T) { @@ -357,7 +357,7 @@ func TestBitmask_Count(T *testing.T) { mask := New(permRead) - assert.Equal(t, 1, mask.Count()) + test.EqOp(t, 1, mask.Count()) }) T.Run("counts multiple bits", func(t *testing.T) { @@ -365,7 +365,7 @@ func TestBitmask_Count(T *testing.T) { mask := New(permRead, permWrite, permDelete) - assert.Equal(t, 3, mask.Count()) + test.EqOp(t, 3, mask.Count()) }) T.Run("counts all bits", func(t *testing.T) { @@ -373,7 +373,7 @@ func TestBitmask_Count(T *testing.T) { mask := FromValue(^testPerm(0)) - assert.Equal(t, 8, mask.Count()) + test.EqOp(t, 8, mask.Count()) }) } @@ -387,8 +387,8 @@ func TestBitmask_Union(T *testing.T) { b := New(permWrite) result := a.Union(b) - assert.True(t, result.Has(permRead)) - assert.True(t, result.Has(permWrite)) + test.True(t, result.Has(permRead)) + test.True(t, result.Has(permWrite)) }) T.Run("union with empty", func(t *testing.T) { @@ -398,7 +398,7 @@ func TestBitmask_Union(T *testing.T) { b := New[testPerm]() result := a.Union(b) - assert.Equal(t, a.Value(), result.Value()) + test.EqOp(t, a.Value(), result.Value()) }) T.Run("union with self", func(t *testing.T) { @@ -407,7 +407,7 @@ func TestBitmask_Union(T *testing.T) { a := New(permRead, permWrite) result := a.Union(a) - assert.Equal(t, a.Value(), result.Value()) + test.EqOp(t, a.Value(), result.Value()) }) } @@ -421,9 +421,9 @@ func TestBitmask_Intersect(T *testing.T) { b := New(permWrite, permDelete) result := a.Intersect(b) - assert.False(t, result.Has(permRead)) - assert.True(t, result.Has(permWrite)) - assert.False(t, result.Has(permDelete)) + test.False(t, result.Has(permRead)) + test.True(t, result.Has(permWrite)) + test.False(t, result.Has(permDelete)) }) T.Run("intersect with no overlap", func(t *testing.T) { @@ -433,7 +433,7 @@ func TestBitmask_Intersect(T *testing.T) { b := New(permWrite) result := a.Intersect(b) - assert.True(t, result.IsEmpty()) + test.True(t, result.IsEmpty()) }) T.Run("intersect with self", func(t *testing.T) { @@ -442,7 +442,7 @@ func TestBitmask_Intersect(T *testing.T) { a := New(permRead, permWrite) result := a.Intersect(a) - assert.Equal(t, a.Value(), result.Value()) + test.EqOp(t, a.Value(), result.Value()) }) } @@ -456,9 +456,9 @@ func TestBitmask_Difference(T *testing.T) { b := New(permWrite) result := a.Difference(b) - assert.True(t, result.Has(permRead)) - assert.False(t, result.Has(permWrite)) - assert.True(t, result.Has(permDelete)) + test.True(t, result.Has(permRead)) + test.False(t, result.Has(permWrite)) + test.True(t, result.Has(permDelete)) }) T.Run("difference with no overlap", func(t *testing.T) { @@ -468,7 +468,7 @@ func TestBitmask_Difference(T *testing.T) { b := New(permWrite) result := a.Difference(b) - assert.Equal(t, a.Value(), result.Value()) + test.EqOp(t, a.Value(), result.Value()) }) T.Run("difference with self", func(t *testing.T) { @@ -477,7 +477,7 @@ func TestBitmask_Difference(T *testing.T) { a := New(permRead, permWrite) result := a.Difference(a) - assert.True(t, result.IsEmpty()) + test.True(t, result.IsEmpty()) }) } @@ -489,7 +489,7 @@ func TestBitmask_String(T *testing.T) { mask := New[testPerm]() - assert.Equal(t, "00000000", mask.String()) + test.EqOp(t, "00000000", mask.String()) }) T.Run("single flag", func(t *testing.T) { @@ -497,7 +497,7 @@ func TestBitmask_String(T *testing.T) { mask := New(permRead) - assert.Equal(t, "00000001", mask.String()) + test.EqOp(t, "00000001", mask.String()) }) T.Run("multiple flags", func(t *testing.T) { @@ -505,7 +505,7 @@ func TestBitmask_String(T *testing.T) { mask := New(permRead, permWrite) - assert.Equal(t, "00000011", mask.String()) + test.EqOp(t, "00000011", mask.String()) }) T.Run("all flags", func(t *testing.T) { @@ -513,7 +513,7 @@ func TestBitmask_String(T *testing.T) { mask := FromValue(^testPerm(0)) - assert.Equal(t, "11111111", mask.String()) + test.EqOp(t, "11111111", mask.String()) }) } @@ -526,8 +526,8 @@ func TestBitmask_MarshalJSON(T *testing.T) { mask := New(permRead, permWrite) data, err := json.Marshal(&mask) - require.NoError(t, err) - assert.Equal(t, "3", string(data)) + must.NoError(t, err) + test.EqOp(t, "3", string(data)) }) T.Run("marshals zero", func(t *testing.T) { @@ -536,8 +536,8 @@ func TestBitmask_MarshalJSON(T *testing.T) { mask := New[testPerm]() data, err := json.Marshal(&mask) - require.NoError(t, err) - assert.Equal(t, "0", string(data)) + must.NoError(t, err) + test.EqOp(t, "0", string(data)) }) T.Run("marshals in struct", func(t *testing.T) { @@ -550,8 +550,8 @@ func TestBitmask_MarshalJSON(T *testing.T) { w := wrapper{Perms: New(permRead, permDelete)} data, err := json.Marshal(&w) - require.NoError(t, err) - assert.Equal(t, `{"perms":5}`, string(data)) + must.NoError(t, err) + test.EqOp(t, `{"perms":5}`, string(data)) }) } @@ -564,9 +564,9 @@ func TestBitmask_UnmarshalJSON(T *testing.T) { var mask Bitmask[testPerm] err := json.Unmarshal([]byte("3"), &mask) - require.NoError(t, err) - assert.True(t, mask.Has(permRead)) - assert.True(t, mask.Has(permWrite)) + must.NoError(t, err) + test.True(t, mask.Has(permRead)) + test.True(t, mask.Has(permWrite)) }) T.Run("unmarshals zero", func(t *testing.T) { @@ -575,8 +575,8 @@ func TestBitmask_UnmarshalJSON(T *testing.T) { var mask Bitmask[testPerm] err := json.Unmarshal([]byte("0"), &mask) - require.NoError(t, err) - assert.True(t, mask.IsEmpty()) + must.NoError(t, err) + test.True(t, mask.IsEmpty()) }) T.Run("returns error for invalid input", func(t *testing.T) { @@ -585,7 +585,7 @@ func TestBitmask_UnmarshalJSON(T *testing.T) { var mask Bitmask[testPerm] err := json.Unmarshal([]byte(`"not a number"`), &mask) - assert.Error(t, err) + test.Error(t, err) }) T.Run("returns error for negative number", func(t *testing.T) { @@ -594,7 +594,7 @@ func TestBitmask_UnmarshalJSON(T *testing.T) { var mask Bitmask[testPerm] err := json.Unmarshal([]byte("-1"), &mask) - assert.Error(t, err) + test.Error(t, err) }) T.Run("unmarshals in struct", func(t *testing.T) { @@ -607,9 +607,9 @@ func TestBitmask_UnmarshalJSON(T *testing.T) { var w wrapper err := json.Unmarshal([]byte(`{"perms":5}`), &w) - require.NoError(t, err) - assert.True(t, w.Perms.Has(permRead)) - assert.True(t, w.Perms.Has(permDelete)) + must.NoError(t, err) + test.True(t, w.Perms.Has(permRead)) + test.True(t, w.Perms.Has(permDelete)) }) T.Run("round trip", func(t *testing.T) { @@ -617,13 +617,13 @@ func TestBitmask_UnmarshalJSON(T *testing.T) { original := New(permRead, permWrite, permAdmin) data, err := json.Marshal(&original) - require.NoError(t, err) + must.NoError(t, err) var restored Bitmask[testPerm] err = json.Unmarshal(data, &restored) - require.NoError(t, err) + must.NoError(t, err) - assert.Equal(t, original.Value(), restored.Value()) + test.EqOp(t, original.Value(), restored.Value()) }) } @@ -637,9 +637,9 @@ func TestBitmask_Immutability(T *testing.T) { b := a.Set(permWrite) c := b.Clear(permRead) - assert.Equal(t, permRead, a.Value()) - assert.Equal(t, permRead|permWrite, b.Value()) - assert.Equal(t, permWrite, c.Value()) + test.EqOp(t, permRead, a.Value()) + test.EqOp(t, permRead|permWrite, b.Value()) + test.EqOp(t, permWrite, c.Value()) }) } @@ -659,7 +659,7 @@ func TestBitmask_uint16(T *testing.T) { mask := New(f1, f3) - assert.Equal(t, "0000000000000101", mask.String()) + test.EqOp(t, "0000000000000101", mask.String()) }) T.Run("operations work", func(t *testing.T) { @@ -668,8 +668,8 @@ func TestBitmask_uint16(T *testing.T) { base := New(f1, f2) mask := base.Clear(f1) - assert.False(t, mask.Has(f1)) - assert.True(t, mask.Has(f2)) + test.False(t, mask.Has(f1)) + test.True(t, mask.Has(f2)) }) } @@ -683,7 +683,7 @@ func TestBitmask_uint32(T *testing.T) { mask := FromValue(flag32(0b11110000_00001111)) - assert.Equal(t, 8, mask.Count()) + test.EqOp(t, 8, mask.Count()) }) T.Run("string has 32 digits", func(t *testing.T) { @@ -691,6 +691,6 @@ func TestBitmask_uint32(T *testing.T) { mask := New(flag32(1)) - assert.Equal(t, 32, len(mask.String())) + test.EqOp(t, 32, len(mask.String())) }) } diff --git a/capitalism/config/config_test.go b/capitalism/config/config_test.go index dc67168..3fd08d1 100644 --- a/capitalism/config/config_test.go +++ b/capitalism/config/config_test.go @@ -7,8 +7,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -24,7 +24,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Stripe: &stripe.Config{APIKey: t.Name()}, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("returns nil when not enabled", func(t *testing.T) { @@ -35,7 +35,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Enabled: false, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with invalid config", func(t *testing.T) { @@ -47,7 +47,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: StripeProvider, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } @@ -63,8 +63,8 @@ func TestProvideCapitalismImplementation(T *testing.T) { } pm, err := ProvideCapitalismImplementation(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), cfg) - require.NoError(t, err) - assert.NotNil(t, pm) + must.NoError(t, err) + test.NotNil(t, pm) }) T.Run("with unknown provider", func(t *testing.T) { @@ -75,7 +75,7 @@ func TestProvideCapitalismImplementation(T *testing.T) { } pm, err := ProvideCapitalismImplementation(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), cfg) - assert.Nil(t, pm) - assert.Error(t, err) + test.Nil(t, pm) + test.Error(t, err) }) } diff --git a/capitalism/config/do_test.go b/capitalism/config/do_test.go index fb9491f..57648e2 100644 --- a/capitalism/config/do_test.go +++ b/capitalism/config/do_test.go @@ -9,8 +9,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterPaymentManager(T *testing.T) { @@ -30,7 +30,7 @@ func TestRegisterPaymentManager(T *testing.T) { RegisterPaymentManager(i) pm, err := do.Invoke[capitalism.PaymentManager](i) - require.NoError(t, err) - assert.NotNil(t, pm) + must.NoError(t, err) + test.NotNil(t, pm) }) } diff --git a/capitalism/noop/noop_test.go b/capitalism/noop/noop_test.go index 5d4e086..b07b9f1 100644 --- a/capitalism/noop/noop_test.go +++ b/capitalism/noop/noop_test.go @@ -4,8 +4,8 @@ import ( "net/http" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestPaymentManager_HandleEventWebhook(T *testing.T) { @@ -15,9 +15,9 @@ func TestPaymentManager_HandleEventWebhook(T *testing.T) { t.Parallel() mgr := NewPaymentManager() req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://example.com/webhook", http.NoBody) - require.NoError(t, err) + must.NoError(t, err) - assert.NoError(t, mgr.HandleEventWebhook(req)) + test.NoError(t, mgr.HandleEventWebhook(req)) }) } diff --git a/capitalism/stripe/config_test.go b/capitalism/stripe/config_test.go index 227f783..ca067fc 100644 --- a/capitalism/stripe/config_test.go +++ b/capitalism/stripe/config_test.go @@ -3,7 +3,7 @@ package stripe import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestStripeConfig_ValidateWithContext(T *testing.T) { @@ -17,7 +17,7 @@ func TestStripeConfig_ValidateWithContext(T *testing.T) { APIKey: "blah", } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with missing API key", func(t *testing.T) { @@ -28,6 +28,6 @@ func TestStripeConfig_ValidateWithContext(T *testing.T) { APIKey: "", } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } diff --git a/capitalism/stripe/stripe_test.go b/capitalism/stripe/stripe_test.go index 8701ff8..77a3c5d 100644 --- a/capitalism/stripe/stripe_test.go +++ b/capitalism/stripe/stripe_test.go @@ -15,8 +15,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/random" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "github.com/stripe/stripe-go/v75" "github.com/stripe/stripe-go/v75/webhook" ) @@ -35,7 +34,7 @@ func TestNewStripePaymentManager(T *testing.T) { logger := logging.NewNoopLogger() pm := ProvideStripePaymentManager(logger, tracing.NewNoopTracerProvider(), &Config{}) - assert.NotNil(t, pm) + test.NotNil(t, pm) }) T.Run("nil config", func(t *testing.T) { @@ -44,7 +43,7 @@ func TestNewStripePaymentManager(T *testing.T) { logger := logging.NewNoopLogger() pm := ProvideStripePaymentManager(logger, tracing.NewNoopTracerProvider(), nil) - assert.NotNil(t, pm) + test.NotNil(t, pm) }) } @@ -73,8 +72,8 @@ func Test_stripePaymentManager_HandleSubscriptionEventWebhook(T *testing.T) { } rawMessage, err := json.Marshal(paymentIntent) - require.NoError(t, err) - require.NotNil(t, rawMessage) + must.NoError(t, err) + must.NotNil(t, rawMessage) exampleInput := &stripe.Event{ APIResource: stripe.APIResource{}, @@ -96,8 +95,8 @@ func Test_stripePaymentManager_HandleSubscriptionEventWebhook(T *testing.T) { jsonBytes := pm.encoderDecoder.MustEncode(ctx, exampleInput) secret, err := random.GenerateHexEncodedString(ctx, 32) - require.NoError(t, err) - require.NotEmpty(t, secret) + must.NoError(t, err) + must.NotEq(t, "", secret) pm.webhookSecret = secret now := time.Now() @@ -108,7 +107,7 @@ func Test_stripePaymentManager_HandleSubscriptionEventWebhook(T *testing.T) { }) event, err := webhook.ConstructEvent(signedPayload.Payload, signedPayload.Header, signedPayload.Secret) - require.NoError(t, err) + must.NoError(t, err) eventPayload := pm.encoderDecoder.MustEncode(ctx, event) encoderDecoder := &mockencoding.ServerEncoderDecoderMock{ @@ -119,12 +118,12 @@ func Test_stripePaymentManager_HandleSubscriptionEventWebhook(T *testing.T) { pm.encoderDecoder = encoderDecoder req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://whatever.whocares.gov", bytes.NewReader(eventPayload)) - require.NoError(t, err) - require.NotNil(t, req) + must.NoError(t, err) + must.NotNil(t, req) req.Header.Set(stripeSignatureHeaderKey, signedPayload.Header) err = pm.HandleEventWebhook(req) - assert.NoError(t, err) + test.NoError(t, err) test.SliceLen(t, 1, encoderDecoder.DecodeBytesCalls()) }) @@ -136,12 +135,12 @@ func Test_stripePaymentManager_HandleSubscriptionEventWebhook(T *testing.T) { pm := ProvideStripePaymentManager(nil, nil, &Config{}).(*stripePaymentManager) req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://whatever.whocares.gov", http.NoBody) - require.NoError(t, err) - require.NotNil(t, req) + must.NoError(t, err) + must.NotNil(t, req) req.Body = &errReader{} err = pm.HandleEventWebhook(req) - assert.Error(t, err) + test.Error(t, err) }) T.Run("with invalid signature", func(t *testing.T) { @@ -152,12 +151,12 @@ func Test_stripePaymentManager_HandleSubscriptionEventWebhook(T *testing.T) { pm.webhookSecret = "some_secret" req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://whatever.whocares.gov", bytes.NewReader([]byte(`{}`))) - require.NoError(t, err) - require.NotNil(t, req) + must.NoError(t, err) + must.NotNil(t, req) req.Header.Set(stripeSignatureHeaderKey, "invalid_signature") err = pm.HandleEventWebhook(req) - assert.Error(t, err) + test.Error(t, err) }) T.Run("with decode error for payment intent", func(t *testing.T) { @@ -169,7 +168,7 @@ func Test_stripePaymentManager_HandleSubscriptionEventWebhook(T *testing.T) { paymentIntent := &stripe.PaymentIntent{} rawMessage, err := json.Marshal(paymentIntent) - require.NoError(t, err) + must.NoError(t, err) exampleInput := &stripe.Event{ APIVersion: "2023-08-16", @@ -181,8 +180,8 @@ func Test_stripePaymentManager_HandleSubscriptionEventWebhook(T *testing.T) { jsonBytes := pm.encoderDecoder.MustEncode(ctx, exampleInput) secret, err := random.GenerateHexEncodedString(ctx, 32) - require.NoError(t, err) - require.NotEmpty(t, secret) + must.NoError(t, err) + must.NotEq(t, "", secret) pm.webhookSecret = secret signedPayload := webhook.GenerateTestSignedPayload(&webhook.UnsignedPayload{ @@ -192,7 +191,7 @@ func Test_stripePaymentManager_HandleSubscriptionEventWebhook(T *testing.T) { }) event, err := webhook.ConstructEvent(signedPayload.Payload, signedPayload.Header, signedPayload.Secret) - require.NoError(t, err) + must.NoError(t, err) eventPayload := pm.encoderDecoder.MustEncode(ctx, event) encoderDecoder := &mockencoding.ServerEncoderDecoderMock{ @@ -203,12 +202,12 @@ func Test_stripePaymentManager_HandleSubscriptionEventWebhook(T *testing.T) { pm.encoderDecoder = encoderDecoder req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://whatever.whocares.gov", bytes.NewReader(eventPayload)) - require.NoError(t, err) - require.NotNil(t, req) + must.NoError(t, err) + must.NotNil(t, req) req.Header.Set(stripeSignatureHeaderKey, signedPayload.Header) err = pm.HandleEventWebhook(req) - assert.Error(t, err) + test.Error(t, err) test.SliceLen(t, 1, encoderDecoder.DecodeBytesCalls()) }) @@ -229,8 +228,8 @@ func Test_stripePaymentManager_HandleSubscriptionEventWebhook(T *testing.T) { jsonBytes := pm.encoderDecoder.MustEncode(ctx, exampleInput) secret, err := random.GenerateHexEncodedString(ctx, 32) - require.NoError(t, err) - require.NotEmpty(t, secret) + must.NoError(t, err) + must.NotEq(t, "", secret) pm.webhookSecret = secret signedPayload := webhook.GenerateTestSignedPayload(&webhook.UnsignedPayload{ @@ -240,15 +239,15 @@ func Test_stripePaymentManager_HandleSubscriptionEventWebhook(T *testing.T) { }) event, err := webhook.ConstructEvent(signedPayload.Payload, signedPayload.Header, signedPayload.Secret) - require.NoError(t, err) + must.NoError(t, err) eventPayload := pm.encoderDecoder.MustEncode(ctx, event) req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://whatever.whocares.gov", bytes.NewReader(eventPayload)) - require.NoError(t, err) - require.NotNil(t, req) + must.NoError(t, err) + must.NotNil(t, req) req.Header.Set(stripeSignatureHeaderKey, signedPayload.Header) err = pm.HandleEventWebhook(req) - assert.NoError(t, err) + test.NoError(t, err) }) } diff --git a/circuitbreaking/config/config_test.go b/circuitbreaking/config/config_test.go index 40581ed..a1f763b 100644 --- a/circuitbreaking/config/config_test.go +++ b/circuitbreaking/config/config_test.go @@ -14,7 +14,6 @@ import ( circuit "github.com/rubyist/circuitbreaker" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" "go.opentelemetry.io/otel/metric" ) @@ -32,7 +31,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { } err := cfg.ValidateWithContext(ctx) - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("with missing name", func(t *testing.T) { @@ -45,7 +44,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { } err := cfg.ValidateWithContext(ctx) - assert.Error(t, err) + test.Error(t, err) }) T.Run("with error rate exceeding max", func(t *testing.T) { @@ -58,7 +57,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { } err := cfg.ValidateWithContext(ctx) - assert.Error(t, err) + test.Error(t, err) }) } @@ -71,9 +70,9 @@ func TestConfig_EnsureDefaults(T *testing.T) { cfg := &Config{} cfg.EnsureDefaults() - assert.Equal(t, "UNKNOWN", cfg.Name) - assert.Equal(t, float64(100), cfg.ErrorRate) - assert.Equal(t, uint64(1_000_000), cfg.MinimumSampleThreshold) + test.EqOp(t, "UNKNOWN", cfg.Name) + test.EqOp(t, float64(100), cfg.ErrorRate) + test.EqOp(t, uint64(1_000_000), cfg.MinimumSampleThreshold) }) T.Run("does not override set values", func(t *testing.T) { @@ -86,9 +85,9 @@ func TestConfig_EnsureDefaults(T *testing.T) { } cfg.EnsureDefaults() - assert.Equal(t, "test", cfg.Name) - assert.Equal(t, 50.0, cfg.ErrorRate) - assert.Equal(t, uint64(500), cfg.MinimumSampleThreshold) + test.EqOp(t, "test", cfg.Name) + test.EqOp(t, 50.0, cfg.ErrorRate) + test.EqOp(t, uint64(500), cfg.MinimumSampleThreshold) }) } @@ -101,8 +100,8 @@ func TestProvideCircuitBreakerFromConfig(T *testing.T) { ctx := t.Context() cb, err := ProvideCircuitBreakerFromConfig(ctx, cfg, logging.NewNoopLogger(), metrics.NewNoopMetricsProvider()) - assert.NotNil(t, cb) - assert.NoError(t, err) + test.NotNil(t, cb) + test.NoError(t, err) }) T.Run("with error providing first metric", func(t *testing.T) { @@ -119,8 +118,8 @@ func TestProvideCircuitBreakerFromConfig(T *testing.T) { } cb, err := ProvideCircuitBreakerFromConfig(ctx, cfg, logging.NewNoopLogger(), mp) - assert.Nil(t, cb) - assert.Error(t, err) + test.Nil(t, cb) + test.Error(t, err) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -145,8 +144,8 @@ func TestProvideCircuitBreakerFromConfig(T *testing.T) { } cb, err := ProvideCircuitBreakerFromConfig(ctx, cfg, logging.NewNoopLogger(), mp) - assert.Nil(t, cb) - assert.Error(t, err) + test.Nil(t, cb) + test.Error(t, err) test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) @@ -172,8 +171,8 @@ func TestProvideCircuitBreakerFromConfig(T *testing.T) { } cb, err := ProvideCircuitBreakerFromConfig(ctx, cfg, logging.NewNoopLogger(), mp) - assert.Nil(t, cb) - assert.Error(t, err) + test.Nil(t, cb) + test.Error(t, err) test.SliceLen(t, 3, mp.NewInt64CounterCalls()) }) @@ -183,13 +182,13 @@ func TestProvideCircuitBreakerFromConfig(T *testing.T) { func TestEnsureCircuitBreaker(T *testing.T) { T.Run("with nil breaker", func(t *testing.T) { actual := EnsureCircuitBreaker(nil) - assert.NotNil(t, actual) + test.NotNil(t, actual) }) T.Run("with non-nil breaker", func(t *testing.T) { input := noop.NewCircuitBreaker() actual := EnsureCircuitBreaker(input) - assert.Equal(t, input, actual) + test.Eq(t, input, actual) }) } @@ -200,8 +199,8 @@ func TestConfig_ProvideCircuitBreaker(T *testing.T) { var cfg *Config cb, err := cfg.ProvideCircuitBreaker(ctx, logging.NewNoopLogger(), metrics.NewNoopMetricsProvider()) - assert.Nil(t, cb) - assert.Error(t, err) + test.Nil(t, cb) + test.Error(t, err) }) T.Run("with invalid config", func(t *testing.T) { @@ -213,8 +212,8 @@ func TestConfig_ProvideCircuitBreaker(T *testing.T) { } cb, err := cfg.ProvideCircuitBreaker(ctx, logging.NewNoopLogger(), metrics.NewNoopMetricsProvider()) - assert.NotNil(t, cb) - assert.NoError(t, err) + test.NotNil(t, cb) + test.NoError(t, err) }) } @@ -230,8 +229,8 @@ func TestBaseImplementation(T *testing.T) { } cb, err := cfg.ProvideCircuitBreaker(ctx, logging.NewNoopLogger(), metrics.NewNoopMetricsProvider()) - assert.NotNil(t, cb) - assert.NoError(t, err) + test.NotNil(t, cb) + test.NoError(t, err) cb.Failed() }) @@ -246,8 +245,8 @@ func TestBaseImplementation(T *testing.T) { } cb, err := cfg.ProvideCircuitBreaker(ctx, logging.NewNoopLogger(), metrics.NewNoopMetricsProvider()) - assert.NotNil(t, cb) - assert.NoError(t, err) + test.NotNil(t, cb) + test.NoError(t, err) cb.Succeeded() }) @@ -262,10 +261,10 @@ func TestBaseImplementation(T *testing.T) { } cb, err := cfg.ProvideCircuitBreaker(ctx, logging.NewNoopLogger(), metrics.NewNoopMetricsProvider()) - assert.NotNil(t, cb) - assert.NoError(t, err) + test.NotNil(t, cb) + test.NoError(t, err) - assert.True(t, cb.CanProceed()) + test.True(t, cb.CanProceed()) }) T.Run("CannotProceed", func(t *testing.T) { @@ -278,10 +277,10 @@ func TestBaseImplementation(T *testing.T) { } cb, err := cfg.ProvideCircuitBreaker(ctx, logging.NewNoopLogger(), metrics.NewNoopMetricsProvider()) - assert.NotNil(t, cb) - assert.NoError(t, err) + test.NotNil(t, cb) + test.NoError(t, err) - assert.False(t, cb.CannotProceed()) + test.False(t, cb.CannotProceed()) }) } @@ -306,11 +305,11 @@ func TestHandleCircuitBreakerEvents(T *testing.T) { } failure, err := mp.NewInt64Counter("failure") - assert.NoError(t, err) + test.NoError(t, err) reset, err := mp.NewInt64Counter("reset") - assert.NoError(t, err) + test.NoError(t, err) broken, err := mp.NewInt64Counter("broken") - assert.NoError(t, err) + test.NoError(t, err) events := make(chan circuit.BreakerEvent, 4) events <- circuit.BreakerTripped @@ -340,20 +339,22 @@ func TestCircuitBreaker_Integration(T *testing.T) { } cb, err := ProvideCircuitBreakerFromConfig(ctx, cfg, logging.NewNoopLogger(), metrics.NewNoopMetricsProvider()) - assert.NotNil(t, cb) - assert.NoError(t, err) + test.NotNil(t, cb) + test.NoError(t, err) - assert.True(t, cb.CanProceed()) + test.True(t, cb.CanProceed()) cb.Failed() - assert.True(t, cb.CannotProceed()) + test.True(t, cb.CannotProceed()) cb.Succeeded() - assert.Eventually( - t, - func() bool { - return cb.CanProceed() - }, - 5*time.Second, - 500*time.Millisecond, - ) + deadline := time.Now().Add(5 * time.Second) + var proceeded bool + for time.Now().Before(deadline) { + if cb.CanProceed() { + proceeded = true + break + } + time.Sleep(500 * time.Millisecond) + } + test.True(t, proceeded) }) } diff --git a/circuitbreaking/config/do_test.go b/circuitbreaking/config/do_test.go index 184dbac..b6882be 100644 --- a/circuitbreaking/config/do_test.go +++ b/circuitbreaking/config/do_test.go @@ -8,8 +8,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) //nolint:paralleltest // race condition in the core circuit breaker library, I think? @@ -27,7 +27,7 @@ func TestRegisterCircuitBreaker(T *testing.T) { RegisterCircuitBreaker(i) cb, err := do.Invoke[circuitbreaking.CircuitBreaker](i) - require.NoError(t, err) - assert.NotNil(t, cb) + must.NoError(t, err) + test.NotNil(t, cb) }) } diff --git a/circuitbreaking/noop/noop_test.go b/circuitbreaking/noop/noop_test.go index 8c6d402..f98226e 100644 --- a/circuitbreaking/noop/noop_test.go +++ b/circuitbreaking/noop/noop_test.go @@ -3,7 +3,7 @@ package noop import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestNewCircuitBreaker(T *testing.T) { @@ -13,7 +13,7 @@ func TestNewCircuitBreaker(T *testing.T) { t.Parallel() x := NewCircuitBreaker() - assert.NotNil(t, x) + test.NotNil(t, x) }) } @@ -46,7 +46,7 @@ func TestCircuitBreaker_CanProceed(T *testing.T) { t.Parallel() x := NewCircuitBreaker() - assert.True(t, x.CanProceed()) + test.True(t, x.CanProceed()) }) } @@ -57,6 +57,6 @@ func TestCircuitBreaker_CannotProceed(T *testing.T) { t.Parallel() x := NewCircuitBreaker() - assert.False(t, x.CannotProceed()) + test.False(t, x.CannotProceed()) }) } diff --git a/compression/compressor_test.go b/compression/compressor_test.go index 295af53..bb7ecd6 100644 --- a/compression/compressor_test.go +++ b/compression/compressor_test.go @@ -8,8 +8,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) type whatever struct { @@ -23,24 +23,24 @@ func TestNewCompressor(T *testing.T) { t.Parallel() comp, err := NewCompressor(algoZstd) - require.NoError(t, err) - require.NotNil(t, comp) + must.NoError(t, err) + must.NotNil(t, comp) }) T.Run("s2", func(t *testing.T) { t.Parallel() comp, err := NewCompressor(algoS2) - require.NoError(t, err) - require.NotNil(t, comp) + must.NoError(t, err) + must.NotNil(t, comp) }) T.Run("invalid algo", func(t *testing.T) { t.Parallel() comp, err := NewCompressor(algo(t.Name())) - require.Error(t, err) - require.Nil(t, comp) + must.Error(t, err) + must.Nil(t, comp) }) } @@ -52,7 +52,7 @@ func Test_compressor_CompressBytes(T *testing.T) { ctx := t.Context() comp, err := NewCompressor(algoZstd) - require.NoError(t, err) + must.NoError(t, err) x := &whatever{ Name: "testing", @@ -62,10 +62,10 @@ func Test_compressor_CompressBytes(T *testing.T) { expected := "KLUv_QQAmQAAeyJuYW1lIjoidGVzdGluZyJ9Ch6HXww=" compressed, err := comp.CompressBytes(encoder.MustEncodeJSON(ctx, x)) - assert.NoError(t, err) + test.NoError(t, err) actual := base64.URLEncoding.EncodeToString(compressed) - assert.Equal(t, expected, actual) + test.EqOp(t, expected, actual) }) T.Run("s2", func(t *testing.T) { @@ -73,7 +73,7 @@ func Test_compressor_CompressBytes(T *testing.T) { ctx := t.Context() comp, err := NewCompressor(algoS2) - require.NoError(t, err) + must.NoError(t, err) x := &whatever{ Name: "testing", @@ -83,10 +83,10 @@ func Test_compressor_CompressBytes(T *testing.T) { expected := "_wYAAFMyc1R3TwEXAABui7jXeyJuYW1lIjoidGVzdGluZyJ9Cg==" compressed, err := comp.CompressBytes(encoder.MustEncodeJSON(ctx, x)) - assert.NoError(t, err) + test.NoError(t, err) actual := base64.URLEncoding.EncodeToString(compressed) - assert.Equal(t, expected, actual) + test.EqOp(t, expected, actual) }) T.Run("invalid algo", func(t *testing.T) { @@ -94,7 +94,7 @@ func Test_compressor_CompressBytes(T *testing.T) { ctx := t.Context() comp, err := NewCompressor(algoS2) - require.NoError(t, err) + must.NoError(t, err) comp.(*compressor).algo = "invalid" @@ -105,8 +105,8 @@ func Test_compressor_CompressBytes(T *testing.T) { encoder := encoding.ProvideServerEncoderDecoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), encoding.ContentTypeJSON) compressed, err := comp.CompressBytes(encoder.MustEncodeJSON(ctx, x)) - assert.Error(t, err) - assert.Nil(t, compressed) + test.Error(t, err) + test.Nil(t, compressed) }) } @@ -124,7 +124,7 @@ func Test_compressor_DecompressBytes(T *testing.T) { ctx := t.Context() comp, err := NewCompressor(a) - require.NoError(t, err) + must.NoError(t, err) x := &whatever{ Name: "testing", @@ -133,15 +133,15 @@ func Test_compressor_DecompressBytes(T *testing.T) { encoder := encoding.ProvideServerEncoderDecoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), encoding.ContentTypeJSON) compressed, err := comp.CompressBytes(encoder.MustEncodeJSON(ctx, x)) - assert.NoError(t, err) + test.NoError(t, err) decompressed, err := comp.DecompressBytes(compressed) - assert.NoError(t, err) + test.NoError(t, err) var y *whatever - require.NoError(t, encoder.DecodeBytes(ctx, decompressed, &y)) + must.NoError(t, encoder.DecodeBytes(ctx, decompressed, &y)) - assert.Equal(t, x, y) + test.Eq(t, x, y) }) } @@ -150,7 +150,7 @@ func Test_compressor_DecompressBytes(T *testing.T) { ctx := t.Context() comp, err := NewCompressor(algoZstd) - require.NoError(t, err) + must.NoError(t, err) x := &whatever{ Name: "testing", @@ -159,34 +159,34 @@ func Test_compressor_DecompressBytes(T *testing.T) { encoder := encoding.ProvideServerEncoderDecoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), encoding.ContentTypeJSON) compressed, err := comp.CompressBytes(encoder.MustEncodeJSON(ctx, x)) - assert.NoError(t, err) + test.NoError(t, err) comp.(*compressor).algo = "invalid" decompressed, err := comp.DecompressBytes(compressed) - assert.Error(t, err) - assert.Nil(t, decompressed) + test.Error(t, err) + test.Nil(t, decompressed) }) T.Run("with invalid zstd data", func(t *testing.T) { t.Parallel() comp, err := NewCompressor(algoZstd) - require.NoError(t, err) + must.NoError(t, err) decompressed, err := comp.DecompressBytes([]byte("not valid zstd data")) - assert.Error(t, err) - assert.Nil(t, decompressed) + test.Error(t, err) + test.Nil(t, decompressed) }) T.Run("with invalid s2 data", func(t *testing.T) { t.Parallel() comp, err := NewCompressor(algoS2) - require.NoError(t, err) + must.NoError(t, err) decompressed, err := comp.DecompressBytes([]byte("not valid s2 data")) - assert.Error(t, err) - assert.Nil(t, decompressed) + test.Error(t, err) + test.Nil(t, decompressed) }) } diff --git a/cryptography/encryption/aes/aes_test.go b/cryptography/encryption/aes/aes_test.go index 38a7e71..f8e55bf 100644 --- a/cryptography/encryption/aes/aes_test.go +++ b/cryptography/encryption/aes/aes_test.go @@ -7,8 +7,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/verygoodsoftwarenotvirus/platform/v5/random" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestStandardEncryptor(T *testing.T) { @@ -20,18 +20,18 @@ func TestStandardEncryptor(T *testing.T) { ctx := t.Context() expected := t.Name() secret, err := random.GenerateHexEncodedString(ctx, 16) - require.NoError(t, err) + must.NoError(t, err) encryptor, err := NewEncryptorDecryptor(tracing.NewNoopTracerProvider(), logging.NewNoopLogger(), []byte(secret)) - require.NotNil(t, encryptor) - require.NoError(t, err) + must.NotNil(t, encryptor) + must.NoError(t, err) encrypted, err := encryptor.Encrypt(ctx, expected) - assert.NoError(t, err) - assert.NotEmpty(t, encrypted) + test.NoError(t, err) + test.NotEq(t, "", encrypted) actual, err := encryptor.Decrypt(ctx, encrypted) - assert.NoError(t, err) - assert.Equal(t, expected, actual) + test.NoError(t, err) + test.EqOp(t, expected, actual) }) } diff --git a/cryptography/encryption/config/config_test.go b/cryptography/encryption/config/config_test.go index bed65ab..47a60e9 100644 --- a/cryptography/encryption/config/config_test.go +++ b/cryptography/encryption/config/config_test.go @@ -6,7 +6,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) const testKey = "blahblahblahblahblahblahblahblah" @@ -19,7 +19,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { ctx := t.Context() cfg := &Config{Provider: ProviderAES} - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("salsa20 provider", func(t *testing.T) { @@ -27,7 +27,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { ctx := t.Context() cfg := &Config{Provider: ProviderSalsa20} - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("empty provider", func(t *testing.T) { @@ -35,7 +35,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { ctx := t.Context() cfg := &Config{} - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("invalid provider", func(t *testing.T) { @@ -43,7 +43,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { ctx := t.Context() cfg := &Config{Provider: "invalid"} - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } @@ -58,31 +58,31 @@ func TestProvideEncryptorDecryptor(T *testing.T) { t.Parallel() encDec, err := ProvideEncryptorDecryptor(&Config{Provider: ProviderAES}, tracerProvider, logger, key) - assert.NoError(t, err) - assert.NotNil(t, encDec) + test.NoError(t, err) + test.NotNil(t, encDec) }) T.Run("salsa20 provider", func(t *testing.T) { t.Parallel() encDec, err := ProvideEncryptorDecryptor(&Config{Provider: ProviderSalsa20}, tracerProvider, logger, key) - assert.NoError(t, err) - assert.NotNil(t, encDec) + test.NoError(t, err) + test.NotNil(t, encDec) }) T.Run("empty provider defaults to salsa20", func(t *testing.T) { t.Parallel() encDec, err := ProvideEncryptorDecryptor(&Config{}, tracerProvider, logger, key) - assert.NoError(t, err) - assert.NotNil(t, encDec) + test.NoError(t, err) + test.NotNil(t, encDec) }) T.Run("invalid provider defaults to salsa20", func(t *testing.T) { t.Parallel() encDec, err := ProvideEncryptorDecryptor(&Config{Provider: "invalid"}, tracerProvider, logger, key) - assert.NoError(t, err) - assert.NotNil(t, encDec) + test.NoError(t, err) + test.NotNil(t, encDec) }) } diff --git a/cryptography/encryption/config/do_test.go b/cryptography/encryption/config/do_test.go index 7727336..f4460f5 100644 --- a/cryptography/encryption/config/do_test.go +++ b/cryptography/encryption/config/do_test.go @@ -8,8 +8,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterEncryptorDecryptor(T *testing.T) { @@ -27,7 +27,7 @@ func TestRegisterEncryptorDecryptor(T *testing.T) { RegisterEncryptorDecryptor(i) encDec, err := do.Invoke[encryption.EncryptorDecryptor](i) - require.NoError(t, err) - assert.NotNil(t, encDec) + must.NoError(t, err) + test.NotNil(t, encDec) }) } diff --git a/cryptography/encryption/errors_test.go b/cryptography/encryption/errors_test.go index 11cb05e..c63e770 100644 --- a/cryptography/encryption/errors_test.go +++ b/cryptography/encryption/errors_test.go @@ -3,7 +3,7 @@ package encryption import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestErrIncorrectKeyLength(T *testing.T) { @@ -12,7 +12,7 @@ func TestErrIncorrectKeyLength(T *testing.T) { T.Run("is not nil", func(t *testing.T) { t.Parallel() - assert.NotNil(t, ErrIncorrectKeyLength) - assert.Equal(t, "secret is not the right length", ErrIncorrectKeyLength.Error()) + test.NotNil(t, ErrIncorrectKeyLength) + test.EqError(t, ErrIncorrectKeyLength, "secret is not the right length") }) } diff --git a/cryptography/encryption/salsa20/salsa20_test.go b/cryptography/encryption/salsa20/salsa20_test.go index d952dd5..43b0a8c 100644 --- a/cryptography/encryption/salsa20/salsa20_test.go +++ b/cryptography/encryption/salsa20/salsa20_test.go @@ -7,8 +7,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/verygoodsoftwarenotvirus/platform/v5/random" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestStandardEncryptor(T *testing.T) { @@ -20,24 +20,24 @@ func TestStandardEncryptor(T *testing.T) { ctx := t.Context() expected := t.Name() secret, err := random.GenerateHexEncodedString(ctx, 16) - require.NoError(t, err) + must.NoError(t, err) encryptor, err := NewEncryptorDecryptor(tracing.NewNoopTracerProvider(), logging.NewNoopLogger(), []byte(secret)) - require.NotNil(t, encryptor) - require.NoError(t, err) + must.NotNil(t, encryptor) + must.NoError(t, err) encrypted, err := encryptor.Encrypt(ctx, expected) - assert.NoError(t, err) - assert.NotEmpty(t, encrypted) + test.NoError(t, err) + test.NotEq(t, "", encrypted) encrypted2, err := encryptor.Encrypt(ctx, expected) - assert.NoError(t, err) - assert.NotEmpty(t, encrypted2) + test.NoError(t, err) + test.NotEq(t, "", encrypted2) - assert.Equal(t, encrypted, encrypted2) + test.EqOp(t, encrypted, encrypted2) actual, err := encryptor.Decrypt(ctx, encrypted) - assert.NoError(t, err) - assert.Equal(t, expected, actual) + test.NoError(t, err) + test.EqOp(t, expected, actual) }) } diff --git a/cryptography/hashing/adler32/adler32_test.go b/cryptography/hashing/adler32/adler32_test.go index 92aa989..02ac402 100644 --- a/cryptography/hashing/adler32/adler32_test.go +++ b/cryptography/hashing/adler32/adler32_test.go @@ -3,7 +3,7 @@ package adler32 import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func Test_adler32Hasher_Hash(T *testing.T) { @@ -15,7 +15,7 @@ func Test_adler32Hasher_Hash(T *testing.T) { hasher := NewAdler32Hasher() result, err := hasher.Hash(t.Name()) - assert.NoError(t, err) - assert.Equal(t, "546573745f61646c657233324861736865725f486173682f7374616e6461726400000001", result) + test.NoError(t, err) + test.EqOp(t, "546573745f61646c657233324861736865725f486173682f7374616e6461726400000001", result) }) } diff --git a/cryptography/hashing/crc64/crc64_test.go b/cryptography/hashing/crc64/crc64_test.go index cec064e..25292fa 100644 --- a/cryptography/hashing/crc64/crc64_test.go +++ b/cryptography/hashing/crc64/crc64_test.go @@ -3,7 +3,7 @@ package crc64 import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func Test_crc64Hasher_Hash(T *testing.T) { @@ -15,7 +15,7 @@ func Test_crc64Hasher_Hash(T *testing.T) { hasher := NewCRC64Hasher() result, err := hasher.Hash(t.Name()) - assert.NoError(t, err) - assert.Equal(t, "546573745f63726336344861736865725f486173682f7374616e646172640000000000000000", result) + test.NoError(t, err) + test.EqOp(t, "546573745f63726336344861736865725f486173682f7374616e646172640000000000000000", result) }) } diff --git a/cryptography/hashing/fnv/fnv_test.go b/cryptography/hashing/fnv/fnv_test.go index f5d03f4..5b252a6 100644 --- a/cryptography/hashing/fnv/fnv_test.go +++ b/cryptography/hashing/fnv/fnv_test.go @@ -3,7 +3,7 @@ package fnv import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func Test_fnvHasher_Hash(T *testing.T) { @@ -15,7 +15,7 @@ func Test_fnvHasher_Hash(T *testing.T) { hasher := NewFNVHasher() result, err := hasher.Hash(t.Name()) - assert.NoError(t, err) - assert.Equal(t, "546573745f666e764861736865725f486173682f7374616e646172646c62272e07bb014262b821756295c58d", result) + test.NoError(t, err) + test.EqOp(t, "546573745f666e764861736865725f486173682f7374616e646172646c62272e07bb014262b821756295c58d", result) }) } diff --git a/cryptography/hashing/sha256/sha256_test.go b/cryptography/hashing/sha256/sha256_test.go index c0b7ba4..bc4630a 100644 --- a/cryptography/hashing/sha256/sha256_test.go +++ b/cryptography/hashing/sha256/sha256_test.go @@ -3,7 +3,7 @@ package sha256 import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func Test_sha256Hasher_Hash(T *testing.T) { @@ -15,7 +15,7 @@ func Test_sha256Hasher_Hash(T *testing.T) { hasher := NewSHA256Hasher() result, err := hasher.Hash(t.Name()) - assert.NoError(t, err) - assert.Equal(t, "546573745f7368613235364861736865725f486173682f7374616e64617264e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", result) + test.NoError(t, err) + test.EqOp(t, "546573745f7368613235364861736865725f486173682f7374616e64617264e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", result) }) } diff --git a/cryptography/hashing/sha512/sha512_test.go b/cryptography/hashing/sha512/sha512_test.go index 53e5e6d..b56e20d 100644 --- a/cryptography/hashing/sha512/sha512_test.go +++ b/cryptography/hashing/sha512/sha512_test.go @@ -3,7 +3,7 @@ package sha512 import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func Test_sha512Hasher_Hash(T *testing.T) { @@ -15,7 +15,7 @@ func Test_sha512Hasher_Hash(T *testing.T) { hasher := NewSHA512Hasher() result, err := hasher.Hash(t.Name()) - assert.NoError(t, err) - assert.Equal(t, "546573745f7368613531324861736865725f486173682f7374616e64617264cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e", result) + test.NoError(t, err) + test.EqOp(t, "546573745f7368613531324861736865725f486173682f7374616e64617264cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e", result) }) } diff --git a/database/config/config_test.go b/database/config/config_test.go index 51ed69f..c4f5021 100644 --- a/database/config/config_test.go +++ b/database/config/config_test.go @@ -3,6 +3,7 @@ package databasecfg import ( "context" "database/sql" + "errors" "testing" "time" @@ -10,10 +11,12 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) +var errStubMigrator = errors.New("stub migrator error") + type stubMigrator struct { err error called bool @@ -41,7 +44,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { }, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) } @@ -55,13 +58,13 @@ func TestConnectionDetails_LoadFromURL(T *testing.T) { d := &ConnectionDetails{} - assert.NoError(t, d.LoadFromURL(exampleURI)) + test.NoError(t, d.LoadFromURL(exampleURI)) - assert.Equal(t, d.Username, "dbuser") - assert.Equal(t, d.Password, "hunter2") - assert.Equal(t, d.Host, "pgdatabase") - assert.Equal(t, d.Database, "database") - assert.Equal(t, d.DisableSSL, true) + test.EqOp(t, d.Username, "dbuser") + test.EqOp(t, d.Password, "hunter2") + test.EqOp(t, d.Host, "pgdatabase") + test.EqOp(t, d.Database, "database") + test.EqOp(t, d.DisableSSL, true) }) T.Run("with invalid port", func(t *testing.T) { @@ -71,7 +74,7 @@ func TestConnectionDetails_LoadFromURL(T *testing.T) { d := &ConnectionDetails{} - assert.Error(t, d.LoadFromURL(exampleURI)) + test.Error(t, d.LoadFromURL(exampleURI)) }) T.Run("with invalid URL", func(t *testing.T) { @@ -79,7 +82,7 @@ func TestConnectionDetails_LoadFromURL(T *testing.T) { d := &ConnectionDetails{} - assert.Error(t, d.LoadFromURL("://not-a-url")) + test.Error(t, d.LoadFromURL("://not-a-url")) }) T.Run("with missing port", func(t *testing.T) { @@ -87,7 +90,7 @@ func TestConnectionDetails_LoadFromURL(T *testing.T) { d := &ConnectionDetails{} - assert.Error(t, d.LoadFromURL("postgres://dbuser:hunter2@pgdatabase/database")) + test.Error(t, d.LoadFromURL("postgres://dbuser:hunter2@pgdatabase/database")) }) T.Run("without sslmode disable", func(t *testing.T) { @@ -96,9 +99,9 @@ func TestConnectionDetails_LoadFromURL(T *testing.T) { exampleURI := "postgres://dbuser:hunter2@pgdatabase:5432/database" d := &ConnectionDetails{} - require.NoError(t, d.LoadFromURL(exampleURI)) + must.NoError(t, d.LoadFromURL(exampleURI)) - assert.False(t, d.DisableSSL) + test.False(t, d.DisableSSL) }) } @@ -111,11 +114,11 @@ func TestConfig_EnsureDefaults(T *testing.T) { cfg := &Config{} cfg.EnsureDefaults() - assert.Equal(t, ProviderPostgres, cfg.Provider) - assert.Equal(t, defaultPingWaitPeriod, cfg.PingWaitPeriod) - assert.Equal(t, defaultConnMaxLifetime, cfg.ConnMaxLifetime) - assert.Equal(t, uint16(defaultMaxIdleConns), cfg.MaxIdleConns) - assert.Equal(t, uint16(defaultMaxOpenConns), cfg.MaxOpenConns) + test.EqOp(t, ProviderPostgres, cfg.Provider) + test.EqOp(t, defaultPingWaitPeriod, cfg.PingWaitPeriod) + test.EqOp(t, defaultConnMaxLifetime, cfg.ConnMaxLifetime) + test.EqOp(t, uint16(defaultMaxIdleConns), cfg.MaxIdleConns) + test.EqOp(t, uint16(defaultMaxOpenConns), cfg.MaxOpenConns) }) T.Run("does not override set values", func(t *testing.T) { @@ -130,11 +133,11 @@ func TestConfig_EnsureDefaults(T *testing.T) { } cfg.EnsureDefaults() - assert.Equal(t, "custom", cfg.Provider) - assert.Equal(t, 5*time.Second, cfg.PingWaitPeriod) - assert.Equal(t, 1*time.Hour, cfg.ConnMaxLifetime) - assert.Equal(t, uint16(10), cfg.MaxIdleConns) - assert.Equal(t, uint16(20), cfg.MaxOpenConns) + test.EqOp(t, "custom", cfg.Provider) + test.EqOp(t, 5*time.Second, cfg.PingWaitPeriod) + test.EqOp(t, 1*time.Hour, cfg.ConnMaxLifetime) + test.EqOp(t, uint16(10), cfg.MaxIdleConns) + test.EqOp(t, uint16(20), cfg.MaxOpenConns) }) } @@ -155,7 +158,7 @@ func TestConfig_GetReadConnectionString(T *testing.T) { } expected := "user=user password=pass database=db host=localhost port=5432" - assert.Equal(t, expected, cfg.GetReadConnectionString()) + test.EqOp(t, expected, cfg.GetReadConnectionString()) }) } @@ -176,7 +179,7 @@ func TestConfig_GetWriteConnectionString(T *testing.T) { } expected := "user=writer password=secret database=mydb host=writehost port=5433" - assert.Equal(t, expected, cfg.GetWriteConnectionString()) + test.EqOp(t, expected, cfg.GetWriteConnectionString()) }) } @@ -187,14 +190,14 @@ func TestConfig_GetMaxPingAttempts(T *testing.T) { t.Parallel() cfg := &Config{MaxPingAttempts: 42} - assert.Equal(t, uint64(42), cfg.GetMaxPingAttempts()) + test.EqOp(t, uint64(42), cfg.GetMaxPingAttempts()) }) T.Run("zero value", func(t *testing.T) { t.Parallel() cfg := &Config{} - assert.Equal(t, uint64(0), cfg.GetMaxPingAttempts()) + test.EqOp(t, uint64(0), cfg.GetMaxPingAttempts()) }) } @@ -205,7 +208,7 @@ func TestConfig_GetPingWaitPeriod(T *testing.T) { t.Parallel() cfg := &Config{PingWaitPeriod: 3 * time.Second} - assert.Equal(t, 3*time.Second, cfg.GetPingWaitPeriod()) + test.EqOp(t, 3*time.Second, cfg.GetPingWaitPeriod()) }) } @@ -216,14 +219,14 @@ func TestConfig_GetMaxIdleConns(T *testing.T) { t.Parallel() cfg := &Config{} - assert.Equal(t, 5, cfg.GetMaxIdleConns()) + test.EqOp(t, 5, cfg.GetMaxIdleConns()) }) T.Run("returns set value", func(t *testing.T) { t.Parallel() cfg := &Config{MaxIdleConns: 12} - assert.Equal(t, 12, cfg.GetMaxIdleConns()) + test.EqOp(t, 12, cfg.GetMaxIdleConns()) }) } @@ -234,14 +237,14 @@ func TestConfig_GetMaxOpenConns(T *testing.T) { t.Parallel() cfg := &Config{} - assert.Equal(t, 7, cfg.GetMaxOpenConns()) + test.EqOp(t, 7, cfg.GetMaxOpenConns()) }) T.Run("returns set value", func(t *testing.T) { t.Parallel() cfg := &Config{MaxOpenConns: 15} - assert.Equal(t, 15, cfg.GetMaxOpenConns()) + test.EqOp(t, 15, cfg.GetMaxOpenConns()) }) } @@ -252,21 +255,21 @@ func TestConfig_GetConnMaxLifetime(T *testing.T) { t.Parallel() cfg := &Config{} - assert.Equal(t, 30*time.Minute, cfg.GetConnMaxLifetime()) + test.EqOp(t, 30*time.Minute, cfg.GetConnMaxLifetime()) }) T.Run("returns default when negative", func(t *testing.T) { t.Parallel() cfg := &Config{ConnMaxLifetime: -1 * time.Second} - assert.Equal(t, 30*time.Minute, cfg.GetConnMaxLifetime()) + test.EqOp(t, 30*time.Minute, cfg.GetConnMaxLifetime()) }) T.Run("returns set value", func(t *testing.T) { t.Parallel() cfg := &Config{ConnMaxLifetime: 1 * time.Hour} - assert.Equal(t, 1*time.Hour, cfg.GetConnMaxLifetime()) + test.EqOp(t, 1*time.Hour, cfg.GetConnMaxLifetime()) }) } @@ -293,7 +296,7 @@ func TestConfig_ValidateWithContext_additional(T *testing.T) { }, } - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) } @@ -304,20 +307,20 @@ func TestConfig_LoadConnectionDetailsFromURL(T *testing.T) { t.Parallel() cfg := &Config{} - require.NoError(t, cfg.LoadConnectionDetailsFromURL("postgres://u:p@h:1234/d")) + must.NoError(t, cfg.LoadConnectionDetailsFromURL("postgres://u:p@h:1234/d")) - assert.Equal(t, "u", cfg.ReadConnection.Username) - assert.Equal(t, "p", cfg.ReadConnection.Password) - assert.Equal(t, "h", cfg.ReadConnection.Host) - assert.Equal(t, uint16(1234), cfg.ReadConnection.Port) - assert.Equal(t, "d", cfg.ReadConnection.Database) + test.EqOp(t, "u", cfg.ReadConnection.Username) + test.EqOp(t, "p", cfg.ReadConnection.Password) + test.EqOp(t, "h", cfg.ReadConnection.Host) + test.EqOp(t, uint16(1234), cfg.ReadConnection.Port) + test.EqOp(t, "d", cfg.ReadConnection.Database) }) T.Run("with invalid URL", func(t *testing.T) { t.Parallel() cfg := &Config{} - assert.Error(t, cfg.LoadConnectionDetailsFromURL("://bad")) + test.Error(t, cfg.LoadConnectionDetailsFromURL("://bad")) }) } @@ -336,7 +339,7 @@ func TestConnectionDetails_String(T *testing.T) { } expected := "user=admin password=secret database=mydb host=dbhost port=5432" - assert.Equal(t, expected, d.String()) + test.EqOp(t, expected, d.String()) }) } @@ -355,7 +358,7 @@ func TestConnectionDetails_URI(T *testing.T) { } expected := "postgres://admin:secret@dbhost:5432/mydb?sslmode=disable" - assert.Equal(t, expected, d.URI()) + test.EqOp(t, expected, d.URI()) }) } @@ -373,14 +376,14 @@ func TestConnectionDetails_ValidateWithContext(T *testing.T) { Port: 5432, } - assert.NoError(t, d.ValidateWithContext(t.Context())) + test.NoError(t, d.ValidateWithContext(t.Context())) }) T.Run("missing fields", func(t *testing.T) { t.Parallel() d := &ConnectionDetails{} - assert.Error(t, d.ValidateWithContext(t.Context())) + test.Error(t, d.ValidateWithContext(t.Context())) }) } @@ -399,7 +402,7 @@ func TestConnectionDetails_MySQLDSN(T *testing.T) { } expected := "admin:secret@tcp(dbhost:3306)/mydb" - assert.Equal(t, expected, d.MySQLDSN()) + test.EqOp(t, expected, d.MySQLDSN()) }) } @@ -413,7 +416,7 @@ func TestConnectionDetails_SQLiteDSN(T *testing.T) { Database: "/tmp/test.db", } - assert.Equal(t, "/tmp/test.db", d.SQLiteDSN()) + test.EqOp(t, "/tmp/test.db", d.SQLiteDSN()) }) T.Run("memory", func(t *testing.T) { @@ -423,7 +426,7 @@ func TestConnectionDetails_SQLiteDSN(T *testing.T) { Database: ":memory:", } - assert.Equal(t, ":memory:", d.SQLiteDSN()) + test.EqOp(t, ":memory:", d.SQLiteDSN()) }) } @@ -445,7 +448,7 @@ func TestConfig_GetReadConnectionString_ProviderAware(T *testing.T) { } expected := "user=user password=pass database=db host=localhost port=5432" - assert.Equal(t, expected, cfg.GetReadConnectionString()) + test.EqOp(t, expected, cfg.GetReadConnectionString()) }) T.Run("mysql provider", func(t *testing.T) { @@ -463,7 +466,7 @@ func TestConfig_GetReadConnectionString_ProviderAware(T *testing.T) { } expected := "user:pass@tcp(localhost:3306)/db" - assert.Equal(t, expected, cfg.GetReadConnectionString()) + test.EqOp(t, expected, cfg.GetReadConnectionString()) }) T.Run("sqlite provider", func(t *testing.T) { @@ -476,7 +479,7 @@ func TestConfig_GetReadConnectionString_ProviderAware(T *testing.T) { }, } - assert.Equal(t, "/tmp/test.db", cfg.GetReadConnectionString()) + test.EqOp(t, "/tmp/test.db", cfg.GetReadConnectionString()) }) } @@ -498,7 +501,7 @@ func TestConfig_GetWriteConnectionString_ProviderAware(T *testing.T) { } expected := "writer:secret@tcp(writehost:3306)/mydb" - assert.Equal(t, expected, cfg.GetWriteConnectionString()) + test.EqOp(t, expected, cfg.GetWriteConnectionString()) }) T.Run("sqlite provider", func(t *testing.T) { @@ -511,7 +514,7 @@ func TestConfig_GetWriteConnectionString_ProviderAware(T *testing.T) { }, } - assert.Equal(t, ":memory:", cfg.GetWriteConnectionString()) + test.EqOp(t, ":memory:", cfg.GetWriteConnectionString()) }) } @@ -521,25 +524,25 @@ func TestConfig_driverName(T *testing.T) { T.Run("postgres default", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderPostgres} - assert.Equal(t, "pgx", cfg.driverName()) + test.EqOp(t, "pgx", cfg.driverName()) }) T.Run("mysql", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderMySQL} - assert.Equal(t, "mysql", cfg.driverName()) + test.EqOp(t, "mysql", cfg.driverName()) }) T.Run("sqlite", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderSQLite} - assert.Equal(t, "sqlite", cfg.driverName()) + test.EqOp(t, "sqlite", cfg.driverName()) }) T.Run("unknown falls back to pgx", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: "unknown"} - assert.Equal(t, "pgx", cfg.driverName()) + test.EqOp(t, "pgx", cfg.driverName()) }) } @@ -556,9 +559,9 @@ func TestConfig_ConnectToReadDatabase(T *testing.T) { } db, err := cfg.ConnectToReadDatabase() - require.NoError(t, err) - require.NotNil(t, db) - require.NoError(t, db.Close()) + must.NoError(t, err) + must.NotNil(t, db) + must.NoError(t, db.Close()) }) T.Run("postgres lazy open", func(t *testing.T) { @@ -575,9 +578,9 @@ func TestConfig_ConnectToReadDatabase(T *testing.T) { } db, err := cfg.ConnectToReadDatabase() - require.NoError(t, err) - require.NotNil(t, db) - require.NoError(t, db.Close()) + must.NoError(t, err) + must.NotNil(t, db) + must.NoError(t, db.Close()) }) T.Run("mysql with bogus DSN returns error", func(t *testing.T) { @@ -586,8 +589,8 @@ func TestConfig_ConnectToReadDatabase(T *testing.T) { Provider: ProviderMySQL, } db, err := cfg.connectToDatabase("not a valid mysql dsn") - assert.Nil(t, db) - assert.Error(t, err) + test.Nil(t, db) + test.Error(t, err) }) } @@ -604,9 +607,9 @@ func TestConfig_ConnectToWriteDatabase(T *testing.T) { } db, err := cfg.ConnectToWriteDatabase() - require.NoError(t, err) - require.NotNil(t, db) - require.NoError(t, db.Close()) + must.NoError(t, err) + must.NotNil(t, db) + must.NoError(t, db.Close()) }) } @@ -621,9 +624,9 @@ func TestProvideDatabase(T *testing.T) { } client, err := ProvideDatabase(t.Context(), nil, nil, cfg, nil, nil) - assert.Nil(t, client) - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid database provider") + test.Nil(t, client) + test.Error(t, err) + test.StrContains(t, err.Error(), "invalid database provider") }) T.Run("postgres lazy open", func(t *testing.T) { @@ -648,8 +651,8 @@ func TestProvideDatabase(T *testing.T) { } client, err := ProvideDatabase(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), cfg, nil, nil) - require.NoError(t, err) - require.NotNil(t, client) + must.NoError(t, err) + must.NotNil(t, client) }) T.Run("mysql lazy open", func(t *testing.T) { @@ -674,8 +677,8 @@ func TestProvideDatabase(T *testing.T) { } client, err := ProvideDatabase(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), cfg, nil, nil) - require.NoError(t, err) - require.NotNil(t, client) + must.NoError(t, err) + must.NotNil(t, client) }) T.Run("sqlite in-memory", func(t *testing.T) { @@ -692,8 +695,8 @@ func TestProvideDatabase(T *testing.T) { } client, err := ProvideDatabase(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), cfg, nil, nil) - require.NoError(t, err) - require.NotNil(t, client) + must.NoError(t, err) + must.NotNil(t, client) }) T.Run("sqlite with enable database metrics and nil metrics provider", func(t *testing.T) { @@ -711,8 +714,8 @@ func TestProvideDatabase(T *testing.T) { } client, err := ProvideDatabase(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), cfg, nil, nil) - require.NoError(t, err) - require.NotNil(t, client) + must.NoError(t, err) + must.NotNil(t, client) }) T.Run("sqlite with enable database metrics and noop metrics provider", func(t *testing.T) { @@ -730,8 +733,8 @@ func TestProvideDatabase(T *testing.T) { } client, err := ProvideDatabase(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), cfg, nil, metrics.NewNoopMetricsProvider()) - require.NoError(t, err) - require.NotNil(t, client) + must.NoError(t, err) + must.NotNil(t, client) }) T.Run("sqlite with migrations", func(t *testing.T) { @@ -750,9 +753,9 @@ func TestProvideDatabase(T *testing.T) { migrator := &stubMigrator{} client, err := ProvideDatabase(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), cfg, migrator, nil) - require.NoError(t, err) - require.NotNil(t, client) - assert.True(t, migrator.called) + must.NoError(t, err) + must.NotNil(t, client) + test.True(t, migrator.called) }) T.Run("sqlite with bad path returns error", func(t *testing.T) { @@ -769,8 +772,8 @@ func TestProvideDatabase(T *testing.T) { } client, err := ProvideDatabase(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), cfg, nil, nil) - assert.Nil(t, client) - assert.Error(t, err) + test.Nil(t, client) + test.Error(t, err) }) T.Run("sqlite with migrations error", func(t *testing.T) { @@ -787,10 +790,10 @@ func TestProvideDatabase(T *testing.T) { }, } - migrator := &stubMigrator{err: assert.AnError} + migrator := &stubMigrator{err: errStubMigrator} client, err := ProvideDatabase(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), cfg, migrator, nil) - assert.Nil(t, client) - assert.Error(t, err) - assert.True(t, migrator.called) + test.Nil(t, client) + test.Error(t, err) + test.True(t, migrator.called) }) } diff --git a/database/config/do_test.go b/database/config/do_test.go index 6d074ae..f3676c9 100644 --- a/database/config/do_test.go +++ b/database/config/do_test.go @@ -10,8 +10,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterClientConfig(T *testing.T) { @@ -31,8 +31,8 @@ func TestRegisterClientConfig(T *testing.T) { RegisterClientConfig(i) cc, err := do.Invoke[database.ClientConfig](i) - require.NoError(t, err) - assert.NotNil(t, cc) + must.NoError(t, err) + test.NotNil(t, cc) }) } @@ -61,12 +61,12 @@ func TestRegisterDatabase(T *testing.T) { RegisterDatabase(i) client, err := do.Invoke[database.Client](i) - require.NoError(t, err) - assert.NotNil(t, client) + must.NoError(t, err) + test.NotNil(t, client) cc, err := do.Invoke[database.ClientConfig](i) - require.NoError(t, err) - assert.NotNil(t, cc) + must.NoError(t, err) + test.NotNil(t, cc) }) } @@ -80,6 +80,6 @@ func TestProvideClientConfig(T *testing.T) { Provider: ProviderPostgres, } cc := ProvideClientConfig(cfg) - require.NotNil(t, cc) + must.NotNil(t, cc) }) } diff --git a/database/errors_test.go b/database/errors_test.go index b86639e..3815128 100644 --- a/database/errors_test.go +++ b/database/errors_test.go @@ -3,7 +3,7 @@ package database import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestErrUserAlreadyExists(T *testing.T) { @@ -12,8 +12,8 @@ func TestErrUserAlreadyExists(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotNil(t, ErrUserAlreadyExists) - assert.Contains(t, ErrUserAlreadyExists.Error(), "user already exists") + test.NotNil(t, ErrUserAlreadyExists) + test.StrContains(t, ErrUserAlreadyExists.Error(), "user already exists") }) } @@ -23,7 +23,7 @@ func TestErrDatabaseNotReady(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotNil(t, ErrDatabaseNotReady) - assert.Contains(t, ErrDatabaseNotReady.Error(), "database is not ready yet") + test.NotNil(t, ErrDatabaseNotReady) + test.StrContains(t, ErrDatabaseNotReady.Error(), "database is not ready yet") }) } diff --git a/database/filtering/query_filter_test.go b/database/filtering/query_filter_test.go index e348035..1ad6be3 100644 --- a/database/filtering/query_filter_test.go +++ b/database/filtering/query_filter_test.go @@ -10,8 +10,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" textsearch "github.com/verygoodsoftwarenotvirus/platform/v5/search/text" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestDefaultQueryFilter(T *testing.T) { @@ -22,11 +22,11 @@ func TestDefaultQueryFilter(T *testing.T) { qf := DefaultQueryFilter() - require.NotNil(t, qf) - require.NotNil(t, qf.MaxResponseSize) - assert.Equal(t, uint8(DefaultQueryFilterLimit), *qf.MaxResponseSize) - require.NotNil(t, qf.SortBy) - assert.Equal(t, SortAscending, qf.SortBy) + must.NotNil(t, qf) + must.NotNil(t, qf.MaxResponseSize) + test.EqOp(t, uint8(DefaultQueryFilterLimit), *qf.MaxResponseSize) + must.NotNil(t, qf.SortBy) + test.EqOp(t, SortAscending, qf.SortBy) }) } @@ -49,7 +49,7 @@ func TestQueryFilter_AttachToLogger(T *testing.T) { IncludeArchived: new(true), } - assert.NotNil(t, qf.AttachToLogger(logger)) + test.NotNil(t, qf.AttachToLogger(logger)) }) T.Run("with nil", func(t *testing.T) { @@ -57,7 +57,7 @@ func TestQueryFilter_AttachToLogger(T *testing.T) { logger := logging.NewNoopLogger() - assert.NotNil(t, (*QueryFilter)(nil).AttachToLogger(logger)) + test.NotNil(t, (*QueryFilter)(nil).AttachToLogger(logger)) }) } @@ -68,7 +68,7 @@ func TestQueryFilter_FromParams(T *testing.T) { t.Parallel() tt, err := time.Parse(time.RFC3339Nano, time.Now().UTC().Truncate(time.Second).Format(time.RFC3339Nano)) - require.NoError(t, err) + must.NoError(t, err) actual := &QueryFilter{} expected := &QueryFilter{ @@ -96,12 +96,12 @@ func TestQueryFilter_FromParams(T *testing.T) { actual.FromParams(exampleInput) - assert.Equal(t, expected, actual) + test.Eq(t, expected, actual) exampleInput[QueryKeySortBy] = []string{*SortAscending} actual.FromParams(exampleInput) - assert.Equal(t, SortAscending, actual.SortBy) + test.EqOp(t, SortAscending, actual.SortBy) }) } @@ -115,7 +115,7 @@ func TestQueryFilter_SetCursor(T *testing.T) { qf := &QueryFilter{} qf.SetCursor(&expected) - assert.Equal(t, expected, *qf.Cursor) + test.EqOp(t, expected, *qf.Cursor) }) T.Run("with nil", func(t *testing.T) { @@ -125,7 +125,7 @@ func TestQueryFilter_SetCursor(T *testing.T) { qf := &QueryFilter{Cursor: &original} qf.SetCursor(nil) - assert.Equal(t, original, *qf.Cursor) + test.EqOp(t, original, *qf.Cursor) }) } @@ -136,7 +136,7 @@ func TestQueryFilter_ToValues(T *testing.T) { t.Parallel() tt, err := time.Parse(time.RFC3339Nano, time.Now().UTC().Truncate(time.Second).Format(time.RFC3339Nano)) - require.NoError(t, err) + must.NoError(t, err) qf := &QueryFilter{ Cursor: new(t.Name()), @@ -161,7 +161,7 @@ func TestQueryFilter_ToValues(T *testing.T) { } actual := qf.ToValues() - assert.Equal(t, expected, actual) + test.Eq(t, expected, actual) }) T.Run("with nil", func(t *testing.T) { @@ -169,7 +169,7 @@ func TestQueryFilter_ToValues(T *testing.T) { qf := (*QueryFilter)(nil) expected := DefaultQueryFilter().ToValues() actual := qf.ToValues() - assert.Equal(t, expected, actual) + test.Eq(t, expected, actual) }) } @@ -182,7 +182,7 @@ func TestExtractQueryFilter(T *testing.T) { ctx := t.Context() tt, err := time.Parse(time.RFC3339Nano, time.Now().UTC().Truncate(time.Second).Format(time.RFC3339Nano)) - require.NoError(t, err) + must.NoError(t, err) expected := &QueryFilter{ Cursor: new(t.Name()), @@ -205,12 +205,12 @@ func TestExtractQueryFilter(T *testing.T) { } req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://verygoodsoftwarenotvirus.ru", http.NoBody) - assert.NoError(t, err) - require.NotNil(t, req) + test.NoError(t, err) + must.NotNil(t, req) req.URL.RawQuery = exampleInput.Encode() actual := ExtractQueryFilterFromRequest(req) - assert.Equal(t, expected, actual) + test.Eq(t, expected, actual) }) T.Run("with missing values", func(t *testing.T) { @@ -229,12 +229,12 @@ func TestExtractQueryFilter(T *testing.T) { } req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://verygoodsoftwarenotvirus.ru", http.NoBody) - assert.NoError(t, err) - require.NotNil(t, req) + test.NoError(t, err) + must.NotNil(t, req) req.URL.RawQuery = exampleInput.Encode() actual := ExtractQueryFilterFromRequest(req) - assert.Equal(t, expected, actual) + test.Eq(t, expected, actual) }) } @@ -255,7 +255,7 @@ func TestQueryFilter_ToPagination(T *testing.T) { } actual := qf.ToPagination() - assert.Equal(t, expected, actual) + test.Eq(t, expected, actual) }) T.Run("with nil value", func(t *testing.T) { @@ -264,7 +264,7 @@ func TestQueryFilter_ToPagination(T *testing.T) { qf := (*QueryFilter)(nil) actual := qf.ToPagination() - assert.NotNil(t, actual) + test.NotNil(t, actual) }) } @@ -297,7 +297,7 @@ func TestNewQueryFilteredResult(T *testing.T) { } actual := NewQueryFilteredResult(data, filteredCount, totalCount, idExtractor, qf) - assert.Equal(t, expected, actual) + test.Eq(t, expected, actual) }) T.Run("with empty data", func(t *testing.T) { @@ -326,7 +326,7 @@ func TestNewQueryFilteredResult(T *testing.T) { } actual := NewQueryFilteredResult(data, filteredCount, totalCount, idExtractor, qf) - assert.Equal(t, expected, actual) + test.Eq(t, expected, actual) }) T.Run("with no cursor", func(t *testing.T) { @@ -354,6 +354,6 @@ func TestNewQueryFilteredResult(T *testing.T) { } actual := NewQueryFilteredResult(data, filteredCount, totalCount, idExtractor, qf) - assert.Equal(t, expected, actual) + test.Eq(t, expected, actual) }) } diff --git a/database/mysql/do_test.go b/database/mysql/do_test.go index 8104613..727a514 100644 --- a/database/mysql/do_test.go +++ b/database/mysql/do_test.go @@ -10,8 +10,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterDatabaseClient(T *testing.T) { @@ -33,7 +33,7 @@ func TestRegisterDatabaseClient(T *testing.T) { RegisterDatabaseClient(i) client, err := do.Invoke[database.Client](i) - require.NoError(t, err) - assert.NotNil(t, client) + must.NoError(t, err) + test.NotNil(t, client) }) } diff --git a/database/mysql/mysql_test.go b/database/mysql/mysql_test.go index 1f5f3bf..a59894e 100644 --- a/database/mysql/mysql_test.go +++ b/database/mysql/mysql_test.go @@ -12,8 +12,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) // testClientConfig is a test implementation of database.ClientConfig. @@ -65,7 +65,7 @@ func buildTestClient(t *testing.T) (*Client, sqlmock.Sqlmock) { t.Helper() fakeDB, sqlMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) - require.NoError(t, err) + must.NoError(t, err) c := &Client{ readDB: fakeDB, @@ -97,7 +97,7 @@ func TestQuerier_IsReady(T *testing.T) { // same DB for read/write, so only one ping db.ExpectPing().WillDelayFor(0) - assert.True(t, c.IsReady(ctx)) + test.True(t, c.IsReady(ctx)) }) T.Run("with read DB ping error", func(t *testing.T) { @@ -109,7 +109,7 @@ func TestQuerier_IsReady(T *testing.T) { db.ExpectPing().WillReturnError(errors.New("blah")) - assert.False(t, c.IsReady(ctx)) + test.False(t, c.IsReady(ctx)) }) T.Run("with write DB ping error", func(t *testing.T) { @@ -118,10 +118,10 @@ func TestQuerier_IsReady(T *testing.T) { ctx := t.Context() readDB, readMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) - require.NoError(t, err) + must.NoError(t, err) writeDB, writeMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) - require.NoError(t, err) + must.NoError(t, err) c := &Client{ readDB: readDB, @@ -134,7 +134,7 @@ func TestQuerier_IsReady(T *testing.T) { readMock.ExpectPing().WillDelayFor(0) writeMock.ExpectPing().WillReturnError(errors.New("blah")) - assert.False(t, c.IsReady(ctx)) + test.False(t, c.IsReady(ctx)) }) T.Run("exhausting all available queries", func(t *testing.T) { @@ -148,7 +148,7 @@ func TestQuerier_IsReady(T *testing.T) { db.ExpectPing().WillReturnError(errors.New("blah")) - assert.False(t, c.IsReady(ctx)) + test.False(t, c.IsReady(ctx)) }) } @@ -166,8 +166,8 @@ func TestProvideDatabaseClient(T *testing.T) { } actual, err := ProvideDatabaseClient(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), exampleConfig, nil) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with no connection strings", func(t *testing.T) { @@ -178,8 +178,8 @@ func TestProvideDatabaseClient(T *testing.T) { exampleConfig := &testClientConfig{} actual, err := ProvideDatabaseClient(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), exampleConfig, nil) - assert.Nil(t, actual) - assert.Error(t, err) + test.Nil(t, actual) + test.Error(t, err) }) T.Run("with only read connection string", func(t *testing.T) { @@ -193,8 +193,8 @@ func TestProvideDatabaseClient(T *testing.T) { } actual, err := ProvideDatabaseClient(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), exampleConfig, nil) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with only write connection string", func(t *testing.T) { @@ -208,8 +208,8 @@ func TestProvideDatabaseClient(T *testing.T) { } actual, err := ProvideDatabaseClient(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), exampleConfig, nil) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with metrics provider", func(t *testing.T) { @@ -223,8 +223,8 @@ func TestProvideDatabaseClient(T *testing.T) { } actual, err := ProvideDatabaseClient(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), exampleConfig, metrics.NewNoopMetricsProvider()) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with metrics provider and single connection", func(t *testing.T) { @@ -238,8 +238,8 @@ func TestProvideDatabaseClient(T *testing.T) { } actual, err := ProvideDatabaseClient(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), exampleConfig, metrics.NewNoopMetricsProvider()) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) } @@ -249,7 +249,7 @@ func TestDefaultTimeFunc(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotZero(t, defaultTimeFunc()) + test.False(t, defaultTimeFunc().IsZero()) }) } @@ -261,7 +261,7 @@ func TestQuerier_currentTime(T *testing.T) { c, _ := buildTestClient(t) - assert.NotEmpty(t, c.CurrentTime()) + test.False(t, c.CurrentTime().IsZero()) }) T.Run("handles nil", func(t *testing.T) { @@ -269,7 +269,7 @@ func TestQuerier_currentTime(T *testing.T) { var c *Client - assert.NotEmpty(t, c.CurrentTime()) + test.False(t, c.CurrentTime().IsZero()) }) } @@ -286,7 +286,7 @@ func TestQuerier_rollbackTransaction(T *testing.T) { db.ExpectRollback().WillReturnError(errors.New("blah")) tx, err := c.writeDB.BeginTx(ctx, nil) - require.NoError(t, err) + must.NoError(t, err) c.RollbackTransaction(ctx, tx) }) @@ -301,7 +301,7 @@ func TestQuerier_rollbackTransaction(T *testing.T) { db.ExpectRollback() tx, err := c.writeDB.BeginTx(ctx, nil) - require.NoError(t, err) + must.NoError(t, err) c.RollbackTransaction(ctx, tx) }) @@ -315,7 +315,7 @@ func TestClient_ReadDB(T *testing.T) { c, _ := buildTestClient(t) - assert.NotNil(t, c.ReadDB()) + test.NotNil(t, c.ReadDB()) }) } @@ -327,7 +327,7 @@ func TestClient_WriteDB(T *testing.T) { c, _ := buildTestClient(t) - assert.NotNil(t, c.WriteDB()) + test.NotNil(t, c.WriteDB()) }) } @@ -341,17 +341,17 @@ func TestClient_Close(T *testing.T) { db.ExpectClose() - assert.NoError(t, c.Close()) + test.NoError(t, c.Close()) }) T.Run("with separate read and write DBs", func(t *testing.T) { t.Parallel() readDB, readMock, err := sqlmock.New() - require.NoError(t, err) + must.NoError(t, err) writeDB, writeMock, err := sqlmock.New() - require.NoError(t, err) + must.NoError(t, err) c := &Client{ readDB: readDB, @@ -363,7 +363,7 @@ func TestClient_Close(T *testing.T) { readMock.ExpectClose() writeMock.ExpectClose() - assert.NoError(t, c.Close()) + test.NoError(t, c.Close()) }) T.Run("with read close error", func(t *testing.T) { @@ -373,17 +373,17 @@ func TestClient_Close(T *testing.T) { db.ExpectClose().WillReturnError(errors.New("blah")) - assert.Error(t, c.Close()) + test.Error(t, c.Close()) }) T.Run("with write close error", func(t *testing.T) { t.Parallel() readDB, readMock, err := sqlmock.New() - require.NoError(t, err) + must.NoError(t, err) writeDB, writeMock, err := sqlmock.New() - require.NoError(t, err) + must.NoError(t, err) c := &Client{ readDB: readDB, @@ -395,6 +395,6 @@ func TestClient_Close(T *testing.T) { readMock.ExpectClose() writeMock.ExpectClose().WillReturnError(errors.New("blah")) - assert.Error(t, c.Close()) + test.Error(t, c.Close()) }) } diff --git a/database/mysql/tableaccess/access_manager_test.go b/database/mysql/tableaccess/access_manager_test.go index 134d809..9ed2cbb 100644 --- a/database/mysql/tableaccess/access_manager_test.go +++ b/database/mysql/tableaccess/access_manager_test.go @@ -13,8 +13,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/retry" _ "github.com/go-sql-driver/mysql" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "github.com/testcontainers/testcontainers-go" mysqlcontainers "github.com/testcontainers/testcontainers-go/modules/mysql" "github.com/testcontainers/testcontainers-go/wait" @@ -80,8 +80,8 @@ func buildDatabaseConnectionForTest(t *testing.T, ctx context.Context) (*sql.DB, ) return containerErr }) - require.NoError(t, err) - require.NotNil(t, container) + must.NoError(t, err) + must.NotNil(t, container) // Connect as root for admin operations (CREATE USER, GRANT, etc.). // WithDefaultCredentials sets MYSQL_ROOT_PASSWORD to the same value as MYSQL_PASSWORD. @@ -89,9 +89,9 @@ func buildDatabaseConnectionForTest(t *testing.T, ctx context.Context) (*sql.DB, // Replace the non-root user with root in the DSN. connStr = "root:" + dbPassword + "@" + connStr[strings.Index(connStr, "@")+1:] db, err := sql.Open("mysql", connStr) - require.NoError(t, err) + must.NoError(t, err) - require.NoError(t, db.PingContext(ctx)) + must.NoError(t, db.PingContext(ctx)) return db, container } @@ -140,7 +140,7 @@ func TestQuoteIdent(T *testing.T) { T.Run(tt.name, func(t *testing.T) { t.Parallel() result := quoteIdent(tt.input) - assert.Equal(t, tt.expected, result) + test.EqOp(t, tt.expected, result) }) } } @@ -184,7 +184,7 @@ func TestQuoteLiteral(T *testing.T) { T.Run(tt.name, func(t *testing.T) { t.Parallel() result := quoteLiteral(tt.input) - assert.Equal(t, tt.expected, result) + test.EqOp(t, tt.expected, result) }) } } @@ -206,13 +206,13 @@ func TestIsValidPrivilege(T *testing.T) { } for _, p := range validPrivileges { - assert.True(t, isValidPrivilege(p), "expected %q to be valid", p) + test.True(t, isValidPrivilege(p), test.Sprintf("expected %q to be valid", p)) } }) T.Run("invalid privilege", func(t *testing.T) { t.Parallel() - assert.False(t, isValidPrivilege("INVALID")) + test.False(t, isValidPrivilege("INVALID")) }) } @@ -224,8 +224,8 @@ func TestManager_GrantUserAccessToTable_InvalidPrivilege(T *testing.T) { m := NewManager(nil) err := m.GrantUserAccessToTable(t.Context(), "user", "schema", "table", "INVALID") - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid privilege") + test.Error(t, err) + test.StrContains(t, err.Error(), "invalid privilege") }) } @@ -236,7 +236,7 @@ func TestNewManager(T *testing.T) { t.Parallel() m := NewManager(nil) - assert.NotNil(t, m) + test.NotNil(t, m) }) } @@ -260,11 +260,11 @@ func TestManager_CreateUser(T *testing.T) { password := "testpass123" err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) exists, err := mgr.UserExists(ctx, username) - assert.NoError(t, err) - assert.True(t, exists) + test.NoError(t, err) + test.True(t, exists) }) T.Run("duplicate user", func(t *testing.T) { @@ -284,10 +284,10 @@ func TestManager_CreateUser(T *testing.T) { password := "testpass123" err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) err = mgr.CreateUser(ctx, username, password) - assert.Error(t, err) + test.Error(t, err) }) } @@ -311,14 +311,14 @@ func TestManager_DeleteUser(T *testing.T) { password := "testpass123" err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) err = mgr.DeleteUser(ctx, username) - assert.NoError(t, err) + test.NoError(t, err) exists, err := mgr.UserExists(ctx, username) - assert.NoError(t, err) - assert.False(t, exists) + test.NoError(t, err) + test.False(t, exists) }) T.Run("delete non-existent user", func(t *testing.T) { @@ -335,7 +335,7 @@ func TestManager_DeleteUser(T *testing.T) { mgr := NewManager(adminDB) err := mgr.DeleteUser(ctx, "nonexistentuser") - assert.NoError(t, err) + test.NoError(t, err) }) } @@ -357,15 +357,15 @@ func TestManager_CreateDatabase(T *testing.T) { owner := "dbowner" err := mgr.CreateUser(ctx, owner, "pass") - require.NoError(t, err) + must.NoError(t, err) dbName := "testdb" err = mgr.CreateDatabase(ctx, dbName, owner) - assert.NoError(t, err) + test.NoError(t, err) exists, err := mgr.DatabaseExists(ctx, dbName) - assert.NoError(t, err) - assert.True(t, exists) + test.NoError(t, err) + test.True(t, exists) }) } @@ -387,18 +387,18 @@ func TestManager_DeleteDatabase(T *testing.T) { owner := "deldbowner" err := mgr.CreateUser(ctx, owner, "pass") - require.NoError(t, err) + must.NoError(t, err) dbName := "deldb" err = mgr.CreateDatabase(ctx, dbName, owner) - require.NoError(t, err) + must.NoError(t, err) err = mgr.DeleteDatabase(ctx, dbName) - assert.NoError(t, err) + test.NoError(t, err) exists, err := mgr.DatabaseExists(ctx, dbName) - assert.NoError(t, err) - assert.False(t, exists) + test.NoError(t, err) + test.False(t, exists) }) } @@ -419,8 +419,8 @@ func TestManager_UserExists(T *testing.T) { mgr := NewManager(adminDB) exists, err := mgr.UserExists(ctx, "nonexistent_user_xyz") - assert.NoError(t, err) - assert.False(t, exists) + test.NoError(t, err) + test.False(t, exists) }) } @@ -441,7 +441,7 @@ func TestManager_DatabaseExists(T *testing.T) { mgr := NewManager(adminDB) exists, err := mgr.DatabaseExists(ctx, "nonexistent_db_xyz") - assert.NoError(t, err) - assert.False(t, exists) + test.NoError(t, err) + test.False(t, exists) }) } diff --git a/database/null_values_test.go b/database/null_values_test.go index fad0f12..85a4799 100644 --- a/database/null_values_test.go +++ b/database/null_values_test.go @@ -6,7 +6,7 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func Test_timeFromNullTime(T *testing.T) { @@ -18,7 +18,7 @@ func Test_timeFromNullTime(T *testing.T) { now := time.Now() nt := sql.NullTime{Time: now, Valid: true} - assert.Equal(t, now, TimeFromNullTime(nt)) + test.EqOp(t, now, TimeFromNullTime(nt)) }) T.Run("with invalid time", func(t *testing.T) { @@ -26,7 +26,7 @@ func Test_timeFromNullTime(T *testing.T) { nt := sql.NullTime{Valid: false} - assert.Zero(t, TimeFromNullTime(nt)) + test.True(t, TimeFromNullTime(nt).IsZero()) }) } @@ -39,7 +39,7 @@ func Test_timePointerFromNullTime(T *testing.T) { expected := time.Now() actual := TimePointerFromNullTime(sql.NullTime{Time: expected, Valid: true}) - assert.Equal(t, expected, *actual) + test.EqOp(t, expected, *actual) }) T.Run("with invalid value", func(t *testing.T) { @@ -47,7 +47,7 @@ func Test_timePointerFromNullTime(T *testing.T) { actual := TimePointerFromNullTime(sql.NullTime{Time: time.Now(), Valid: false}) - assert.Nil(t, actual) + test.Nil(t, actual) }) } @@ -60,8 +60,8 @@ func Test_stringPointerFromNullString(T *testing.T) { expected := t.Name() actual := StringPointerFromNullString(sql.NullString{String: expected, Valid: true}) - assert.NotNil(t, actual) - assert.Equal(t, expected, *actual) + test.NotNil(t, actual) + test.EqOp(t, expected, *actual) }) T.Run("with invalid value", func(t *testing.T) { @@ -69,7 +69,7 @@ func Test_stringPointerFromNullString(T *testing.T) { actual := StringPointerFromNullString(sql.NullString{String: t.Name(), Valid: false}) - assert.Nil(t, actual) + test.Nil(t, actual) }) } @@ -80,14 +80,14 @@ func Test_stringFromNullString(T *testing.T) { t.Parallel() input := sql.NullString{String: t.Name(), Valid: true} - assert.Equal(t, input.String, StringFromNullString(input)) + test.EqOp(t, input.String, StringFromNullString(input)) }) T.Run("with invalid value", func(t *testing.T) { t.Parallel() input := sql.NullString{} - assert.Empty(t, StringFromNullString(input)) + test.EqOp(t, "", StringFromNullString(input)) }) } @@ -98,7 +98,7 @@ func Test_nullStringFromString(T *testing.T) { t.Parallel() expected := sql.NullString{String: t.Name(), Valid: true} - assert.Equal(t, expected, NullStringFromString(t.Name())) + test.EqOp(t, expected, NullStringFromString(t.Name())) }) } @@ -109,14 +109,14 @@ func Test_nullStringFromStringPointer(T *testing.T) { t.Parallel() expected := sql.NullString{String: t.Name(), Valid: true} - assert.Equal(t, expected, NullStringFromStringPointer(new(t.Name()))) + test.EqOp(t, expected, NullStringFromStringPointer(new(t.Name()))) }) T.Run("with nil value", func(t *testing.T) { t.Parallel() expected := sql.NullString{String: ""} - assert.Equal(t, expected, NullStringFromStringPointer(nil)) + test.EqOp(t, expected, NullStringFromStringPointer(nil)) }) } @@ -128,7 +128,7 @@ func Test_nullTimeFromTime(T *testing.T) { exampleTime := time.Now() expected := sql.NullTime{Time: exampleTime, Valid: true} - assert.Equal(t, expected, NullTimeFromTime(exampleTime)) + test.EqOp(t, expected, NullTimeFromTime(exampleTime)) }) } @@ -140,14 +140,14 @@ func Test_nullTimeFromTimePointer(T *testing.T) { exampleTime := time.Now() expected := sql.NullTime{Time: exampleTime, Valid: true} - assert.Equal(t, expected, NullTimeFromTimePointer(new(exampleTime))) + test.EqOp(t, expected, NullTimeFromTimePointer(new(exampleTime))) }) T.Run("with nil value", func(t *testing.T) { t.Parallel() expected := sql.NullTime{} - assert.Equal(t, expected, NullTimeFromTimePointer(nil)) + test.EqOp(t, expected, NullTimeFromTimePointer(nil)) }) } @@ -158,14 +158,14 @@ func Test_nullInt32FromUint8Pointer(T *testing.T) { t.Parallel() expected := sql.NullInt32{Int32: 123, Valid: true} - assert.Equal(t, expected, NullInt32FromUint8Pointer(new(uint8(expected.Int32)))) + test.EqOp(t, expected, NullInt32FromUint8Pointer(new(uint8(expected.Int32)))) }) T.Run("with nil value", func(t *testing.T) { t.Parallel() expected := sql.NullInt32{} - assert.Equal(t, expected, NullInt32FromUint8Pointer(nil)) + test.EqOp(t, expected, NullInt32FromUint8Pointer(nil)) }) } @@ -176,14 +176,14 @@ func Test_nullInt32FromUint16Pointer(T *testing.T) { t.Parallel() expected := sql.NullInt32{Int32: 123, Valid: true} - assert.Equal(t, expected, NullInt32FromUint16Pointer(new(uint16(expected.Int32)))) + test.EqOp(t, expected, NullInt32FromUint16Pointer(new(uint16(expected.Int32)))) }) T.Run("with nil value", func(t *testing.T) { t.Parallel() expected := sql.NullInt32{} - assert.Equal(t, expected, NullInt32FromUint16Pointer(nil)) + test.EqOp(t, expected, NullInt32FromUint16Pointer(nil)) }) } @@ -194,7 +194,7 @@ func Test_nullInt32FromUint16(T *testing.T) { t.Parallel() expected := sql.NullInt32{Int32: 123, Valid: true} - assert.Equal(t, expected, NullInt32FromUint16(uint16(expected.Int32))) + test.EqOp(t, expected, NullInt32FromUint16(uint16(expected.Int32))) }) } @@ -205,7 +205,7 @@ func Test_nullBoolFromBool(T *testing.T) { t.Parallel() expected := sql.NullBool{Bool: true, Valid: true} - assert.Equal(t, expected, NullBoolFromBool(true)) + test.EqOp(t, expected, NullBoolFromBool(true)) }) } @@ -218,8 +218,8 @@ func Test_nullBoolFromBoolPointer(T *testing.T) { b := true result := NullBoolFromBoolPointer(&b) - assert.True(t, result.Valid) - assert.True(t, result.Bool) + test.True(t, result.Valid) + test.True(t, result.Bool) }) T.Run("with nil pointer", func(t *testing.T) { @@ -227,7 +227,7 @@ func Test_nullBoolFromBoolPointer(T *testing.T) { result := NullBoolFromBoolPointer(nil) - assert.False(t, result.Valid) + test.False(t, result.Valid) }) } @@ -238,14 +238,14 @@ func Test_boolFromNullBool(T *testing.T) { t.Parallel() input := sql.NullBool{Bool: true, Valid: true} - assert.Equal(t, input.Bool, BoolFromNullBool(input)) + test.EqOp(t, input.Bool, BoolFromNullBool(input)) }) T.Run("with invalid value", func(t *testing.T) { t.Parallel() input := sql.NullBool{Bool: true, Valid: false} - assert.False(t, BoolFromNullBool(input)) + test.False(t, BoolFromNullBool(input)) }) } @@ -256,14 +256,14 @@ func Test_nullInt32FromInt32Pointer(T *testing.T) { t.Parallel() expected := sql.NullInt32{Int32: 123, Valid: true} - assert.Equal(t, expected, NullInt32FromInt32Pointer(new(expected.Int32))) + test.EqOp(t, expected, NullInt32FromInt32Pointer(new(expected.Int32))) }) T.Run("with nil value", func(t *testing.T) { t.Parallel() expected := sql.NullInt32{} - assert.Equal(t, expected, NullInt32FromInt32Pointer(nil)) + test.EqOp(t, expected, NullInt32FromInt32Pointer(nil)) }) } @@ -274,14 +274,14 @@ func Test_nullInt32FromUint32Pointer(T *testing.T) { t.Parallel() expected := sql.NullInt32{Int32: 123, Valid: true} - assert.Equal(t, expected, NullInt32FromUint32Pointer(new(uint32(expected.Int32)))) + test.EqOp(t, expected, NullInt32FromUint32Pointer(new(uint32(expected.Int32)))) }) T.Run("with nil value", func(t *testing.T) { t.Parallel() expected := sql.NullInt32{} - assert.Equal(t, expected, NullInt32FromUint32Pointer(nil)) + test.EqOp(t, expected, NullInt32FromUint32Pointer(nil)) }) } @@ -292,14 +292,14 @@ func Test_int32PointerFromNullInt32(T *testing.T) { t.Parallel() input := sql.NullInt32{Int32: 123, Valid: true} - assert.Equal(t, new(input.Int32), Int32PointerFromNullInt32(input)) + test.Eq(t, new(input.Int32), Int32PointerFromNullInt32(input)) }) T.Run("with invalid value", func(t *testing.T) { t.Parallel() input := sql.NullInt32{Int32: 123, Valid: false} - assert.Nil(t, Int32PointerFromNullInt32(input)) + test.Nil(t, Int32PointerFromNullInt32(input)) }) } @@ -310,14 +310,14 @@ func Test_float32PointerFromNullString(T *testing.T) { t.Parallel() input := sql.NullString{String: "1.23", Valid: true} - assert.Equal(t, new(float32(1.23)), Float32PointerFromNullString(input)) + test.Eq(t, new(float32(1.23)), Float32PointerFromNullString(input)) }) T.Run("with invalid value", func(t *testing.T) { t.Parallel() input := sql.NullString{String: "1.23", Valid: false} - assert.Nil(t, Float32PointerFromNullString(input)) + test.Nil(t, Float32PointerFromNullString(input)) }) } @@ -328,14 +328,14 @@ func Test_float64PointerFromNullString(T *testing.T) { t.Parallel() input := sql.NullString{String: "1.23", Valid: true} - assert.Equal(t, new(1.23), Float64PointerFromNullString(input)) + test.Eq(t, new(1.23), Float64PointerFromNullString(input)) }) T.Run("with invalid value", func(t *testing.T) { t.Parallel() input := sql.NullString{String: "1.23", Valid: false} - assert.Nil(t, Float64PointerFromNullString(input)) + test.Nil(t, Float64PointerFromNullString(input)) }) } @@ -346,7 +346,7 @@ func Test_stringFromFloat32(T *testing.T) { t.Parallel() value := float32(1.23) - assert.Equal(t, "1.23", StringFromFloat32(value)) + test.EqOp(t, "1.23", StringFromFloat32(value)) }) } @@ -356,13 +356,13 @@ func Test_float32FromString(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.Equal(t, float32(1.23), Float32FromString("1.23")) + test.EqOp(t, float32(1.23), Float32FromString("1.23")) }) T.Run("with invalid value", func(t *testing.T) { t.Parallel() - assert.Zero(t, Float32FromString(t.Name())) + test.Zero(t, Float32FromString(t.Name())) }) } @@ -373,14 +373,14 @@ func Test_float32FromNullString(T *testing.T) { t.Parallel() input := sql.NullString{String: "1.23", Valid: true} - assert.Equal(t, float32(1.23), Float32FromNullString(input)) + test.EqOp(t, float32(1.23), Float32FromNullString(input)) }) T.Run("with invalid value", func(t *testing.T) { t.Parallel() input := sql.NullString{String: "1.23", Valid: false} - assert.Zero(t, Float32FromNullString(input)) + test.Zero(t, Float32FromNullString(input)) }) } @@ -392,14 +392,14 @@ func Test_nullStringFromFloat32Pointer(T *testing.T) { value := float32(1.23) expected := sql.NullString{String: fmt.Sprintf("%v", value), Valid: true} - assert.Equal(t, expected, NullStringFromFloat32Pointer(new(value))) + test.EqOp(t, expected, NullStringFromFloat32Pointer(new(value))) }) T.Run("with nil value", func(t *testing.T) { t.Parallel() expected := sql.NullString{} - assert.Equal(t, expected, NullStringFromFloat32Pointer(nil)) + test.EqOp(t, expected, NullStringFromFloat32Pointer(nil)) }) } @@ -411,7 +411,7 @@ func Test_nullStringFromFloat32(T *testing.T) { value := float32(1.23) expected := sql.NullString{String: fmt.Sprintf("%v", value), Valid: true} - assert.Equal(t, expected, NullStringFromFloat32(value)) + test.EqOp(t, expected, NullStringFromFloat32(value)) }) } @@ -422,7 +422,7 @@ func Test_stringFromFloat64(T *testing.T) { t.Parallel() value := float64(1.23) - assert.Equal(t, "1.23", StringFromFloat64(value)) + test.EqOp(t, "1.23", StringFromFloat64(value)) }) } @@ -434,14 +434,14 @@ func Test_nullStringFromFloat64Pointer(T *testing.T) { value := float64(1.23) expected := sql.NullString{String: fmt.Sprintf("%v", value), Valid: true} - assert.Equal(t, expected, NullStringFromFloat64Pointer(new(value))) + test.EqOp(t, expected, NullStringFromFloat64Pointer(new(value))) }) T.Run("with nil value", func(t *testing.T) { t.Parallel() expected := sql.NullString{} - assert.Equal(t, expected, NullStringFromFloat64Pointer(nil)) + test.EqOp(t, expected, NullStringFromFloat64Pointer(nil)) }) } @@ -452,14 +452,14 @@ func Test_nullInt64FromUint32Pointer(T *testing.T) { t.Parallel() expected := sql.NullInt64{Int64: 123, Valid: true} - assert.Equal(t, expected, NullInt64FromUint32Pointer(new(uint32(expected.Int64)))) + test.EqOp(t, expected, NullInt64FromUint32Pointer(new(uint32(expected.Int64)))) }) T.Run("with nil value", func(t *testing.T) { t.Parallel() expected := sql.NullInt64{} - assert.Equal(t, expected, NullInt64FromUint32Pointer(nil)) + test.EqOp(t, expected, NullInt64FromUint32Pointer(nil)) }) } @@ -470,14 +470,14 @@ func Test_uint16PointerFromNullInt32(T *testing.T) { t.Parallel() input := sql.NullInt32{Int32: 123, Valid: true} - assert.Equal(t, new(uint16(input.Int32)), Uint16PointerFromNullInt32(input)) + test.Eq(t, new(uint16(input.Int32)), Uint16PointerFromNullInt32(input)) }) T.Run("with invalid value", func(t *testing.T) { t.Parallel() input := sql.NullInt32{Int32: 123, Valid: false} - assert.Nil(t, Uint16PointerFromNullInt32(input)) + test.Nil(t, Uint16PointerFromNullInt32(input)) }) } @@ -488,14 +488,14 @@ func Test_uint32PointerFromNullInt32(T *testing.T) { t.Parallel() input := sql.NullInt32{Int32: 123, Valid: true} - assert.Equal(t, new(uint32(input.Int32)), Uint32PointerFromNullInt32(input)) + test.Eq(t, new(uint32(input.Int32)), Uint32PointerFromNullInt32(input)) }) T.Run("with invalid value", func(t *testing.T) { t.Parallel() input := sql.NullInt32{Int32: 123, Valid: false} - assert.Nil(t, Uint32PointerFromNullInt32(input)) + test.Nil(t, Uint32PointerFromNullInt32(input)) }) } @@ -506,13 +506,13 @@ func Test_uint32PointerFromNullInt64(T *testing.T) { t.Parallel() input := sql.NullInt64{Int64: 123, Valid: true} - assert.Equal(t, new(uint32(input.Int64)), Uint32PointerFromNullInt64(input)) + test.Eq(t, new(uint32(input.Int64)), Uint32PointerFromNullInt64(input)) }) T.Run("with invalid value", func(t *testing.T) { t.Parallel() input := sql.NullInt64{Int64: 123, Valid: false} - assert.Nil(t, Uint32PointerFromNullInt64(input)) + test.Nil(t, Uint32PointerFromNullInt64(input)) }) } diff --git a/database/postgres/do_test.go b/database/postgres/do_test.go index c97024a..a273306 100644 --- a/database/postgres/do_test.go +++ b/database/postgres/do_test.go @@ -10,8 +10,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterDatabaseClient(T *testing.T) { @@ -33,7 +33,7 @@ func TestRegisterDatabaseClient(T *testing.T) { RegisterDatabaseClient(i) client, err := do.Invoke[database.Client](i) - require.NoError(t, err) - assert.NotNil(t, client) + must.NoError(t, err) + test.NotNil(t, client) }) } diff --git a/database/postgres/postgres_test.go b/database/postgres/postgres_test.go index 544d76c..13034ea 100644 --- a/database/postgres/postgres_test.go +++ b/database/postgres/postgres_test.go @@ -12,8 +12,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) // testClientConfig is a test implementation of database.ClientConfig. @@ -65,7 +65,7 @@ func buildTestClient(t *testing.T) (*Client, sqlmock.Sqlmock) { t.Helper() fakeDB, sqlMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) - require.NoError(t, err) + must.NoError(t, err) c := &Client{ readDB: fakeDB, @@ -97,7 +97,7 @@ func TestQuerier_IsReady(T *testing.T) { // same DB for read/write, so only one ping db.ExpectPing().WillDelayFor(0) - assert.True(t, c.IsReady(ctx)) + test.True(t, c.IsReady(ctx)) }) T.Run("with read DB ping error", func(t *testing.T) { @@ -109,7 +109,7 @@ func TestQuerier_IsReady(T *testing.T) { db.ExpectPing().WillReturnError(errors.New("blah")) - assert.False(t, c.IsReady(ctx)) + test.False(t, c.IsReady(ctx)) }) T.Run("with write DB ping error", func(t *testing.T) { @@ -118,10 +118,10 @@ func TestQuerier_IsReady(T *testing.T) { ctx := t.Context() readDB, readMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) - require.NoError(t, err) + must.NoError(t, err) writeDB, writeMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) - require.NoError(t, err) + must.NoError(t, err) c := &Client{ readDB: readDB, @@ -134,7 +134,7 @@ func TestQuerier_IsReady(T *testing.T) { readMock.ExpectPing().WillDelayFor(0) writeMock.ExpectPing().WillReturnError(errors.New("blah")) - assert.False(t, c.IsReady(ctx)) + test.False(t, c.IsReady(ctx)) }) T.Run("exhausting all available queries", func(t *testing.T) { @@ -148,7 +148,7 @@ func TestQuerier_IsReady(T *testing.T) { db.ExpectPing().WillReturnError(errors.New("blah")) - assert.False(t, c.IsReady(ctx)) + test.False(t, c.IsReady(ctx)) }) } @@ -166,8 +166,8 @@ func TestProvideDatabaseClient(T *testing.T) { } actual, err := ProvideDatabaseClient(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), exampleConfig, nil) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with no connection strings", func(t *testing.T) { @@ -178,8 +178,8 @@ func TestProvideDatabaseClient(T *testing.T) { exampleConfig := &testClientConfig{} actual, err := ProvideDatabaseClient(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), exampleConfig, nil) - assert.Nil(t, actual) - assert.Error(t, err) + test.Nil(t, actual) + test.Error(t, err) }) T.Run("with only read connection string", func(t *testing.T) { @@ -193,8 +193,8 @@ func TestProvideDatabaseClient(T *testing.T) { } actual, err := ProvideDatabaseClient(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), exampleConfig, nil) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with only write connection string", func(t *testing.T) { @@ -208,8 +208,8 @@ func TestProvideDatabaseClient(T *testing.T) { } actual, err := ProvideDatabaseClient(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), exampleConfig, nil) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with metrics provider", func(t *testing.T) { @@ -223,8 +223,8 @@ func TestProvideDatabaseClient(T *testing.T) { } actual, err := ProvideDatabaseClient(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), exampleConfig, metrics.NewNoopMetricsProvider()) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with metrics provider and single connection", func(t *testing.T) { @@ -238,8 +238,8 @@ func TestProvideDatabaseClient(T *testing.T) { } actual, err := ProvideDatabaseClient(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), exampleConfig, metrics.NewNoopMetricsProvider()) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) } @@ -249,7 +249,7 @@ func TestDefaultTimeFunc(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotZero(t, defaultTimeFunc()) + test.False(t, defaultTimeFunc().IsZero()) }) } @@ -261,7 +261,7 @@ func TestQuerier_currentTime(T *testing.T) { c, _ := buildTestClient(t) - assert.NotEmpty(t, c.CurrentTime()) + test.False(t, c.CurrentTime().IsZero()) }) T.Run("handles nil", func(t *testing.T) { @@ -269,7 +269,7 @@ func TestQuerier_currentTime(T *testing.T) { var c *Client - assert.NotEmpty(t, c.CurrentTime()) + test.False(t, c.CurrentTime().IsZero()) }) } @@ -286,7 +286,7 @@ func TestQuerier_rollbackTransaction(T *testing.T) { db.ExpectRollback().WillReturnError(errors.New("blah")) tx, err := c.writeDB.BeginTx(ctx, nil) - require.NoError(t, err) + must.NoError(t, err) c.RollbackTransaction(ctx, tx) }) @@ -301,7 +301,7 @@ func TestQuerier_rollbackTransaction(T *testing.T) { db.ExpectRollback() tx, err := c.writeDB.BeginTx(ctx, nil) - require.NoError(t, err) + must.NoError(t, err) c.RollbackTransaction(ctx, tx) }) @@ -315,7 +315,7 @@ func TestClient_ReadDB(T *testing.T) { c, _ := buildTestClient(t) - assert.NotNil(t, c.ReadDB()) + test.NotNil(t, c.ReadDB()) }) } @@ -327,7 +327,7 @@ func TestClient_WriteDB(T *testing.T) { c, _ := buildTestClient(t) - assert.NotNil(t, c.WriteDB()) + test.NotNil(t, c.WriteDB()) }) } @@ -341,17 +341,17 @@ func TestClient_Close(T *testing.T) { db.ExpectClose() - assert.NoError(t, c.Close()) + test.NoError(t, c.Close()) }) T.Run("with separate read and write DBs", func(t *testing.T) { t.Parallel() readDB, readMock, err := sqlmock.New() - require.NoError(t, err) + must.NoError(t, err) writeDB, writeMock, err := sqlmock.New() - require.NoError(t, err) + must.NoError(t, err) c := &Client{ readDB: readDB, @@ -363,7 +363,7 @@ func TestClient_Close(T *testing.T) { readMock.ExpectClose() writeMock.ExpectClose() - assert.NoError(t, c.Close()) + test.NoError(t, c.Close()) }) T.Run("with read close error", func(t *testing.T) { @@ -373,17 +373,17 @@ func TestClient_Close(T *testing.T) { db.ExpectClose().WillReturnError(errors.New("blah")) - assert.Error(t, c.Close()) + test.Error(t, c.Close()) }) T.Run("with write close error", func(t *testing.T) { t.Parallel() readDB, readMock, err := sqlmock.New() - require.NoError(t, err) + must.NoError(t, err) writeDB, writeMock, err := sqlmock.New() - require.NoError(t, err) + must.NoError(t, err) c := &Client{ readDB: readDB, @@ -395,6 +395,6 @@ func TestClient_Close(T *testing.T) { readMock.ExpectClose() writeMock.ExpectClose().WillReturnError(errors.New("blah")) - assert.Error(t, c.Close()) + test.Error(t, c.Close()) }) } diff --git a/database/postgres/tableaccess/access_manager_test.go b/database/postgres/tableaccess/access_manager_test.go index 9bb4c6d..80f7b25 100644 --- a/database/postgres/tableaccess/access_manager_test.go +++ b/database/postgres/tableaccess/access_manager_test.go @@ -13,8 +13,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/retry" _ "github.com/jackc/pgx/v5/stdlib" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/wait" @@ -70,10 +70,10 @@ func buildConnectionString(t *testing.T, container *postgres.PostgresContainer, ctx := t.Context() containerPort, err := container.MappedPort(ctx, "5432/tcp") - require.NoError(t, err) + must.NoError(t, err) host, err := container.Host(ctx) - require.NoError(t, err) + must.NoError(t, err) return fmt.Sprintf("postgres://%s:%s@%s/%s", username, password, net.JoinHostPort(host, containerPort.Port()), dbName) } @@ -97,11 +97,11 @@ func buildDatabaseConnectionForTest(t *testing.T, ctx context.Context) (*sql.DB, ) return containerErr }) - require.NoError(t, err) - require.NotNil(t, container) + must.NoError(t, err) + must.NotNil(t, container) db, err := sql.Open("pgx", container.MustConnectionString(ctx, "sslmode=disable")) - require.NoError(t, err) + must.NoError(t, err) return db, container } @@ -150,7 +150,7 @@ func TestQuoteIdent(T *testing.T) { T.Run(tt.name, func(t *testing.T) { t.Parallel() result := quoteIdent(tt.input) - assert.Equal(t, tt.expected, result) + test.EqOp(t, tt.expected, result) }) } } @@ -200,7 +200,7 @@ word'`, T.Run(tt.name, func(t *testing.T) { t.Parallel() result := quoteLiteral(tt.input) - assert.Equal(t, tt.expected, result) + test.EqOp(t, tt.expected, result) }) } } @@ -274,7 +274,7 @@ func TestIsValidPrivilege(T *testing.T) { T.Run(tt.name, func(t *testing.T) { t.Parallel() result := isValidPrivilege(tt.privilege) - assert.Equal(t, tt.expected, result) + test.EqOp(t, tt.expected, result) }) } } @@ -299,12 +299,12 @@ func TestManager_CreateUser(T *testing.T) { password := "testpass123" err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Verify user was created exists, err := mgr.UserExists(ctx, username) - assert.NoError(t, err) - assert.True(t, exists) + test.NoError(t, err) + test.True(t, exists) }) T.Run("duplicate user", func(t *testing.T) { @@ -325,11 +325,11 @@ func TestManager_CreateUser(T *testing.T) { // Create user first time err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Try to create same user again err = mgr.CreateUser(ctx, username, password) - assert.Error(t, err) + test.Error(t, err) }) T.Run("special characters in username", func(t *testing.T) { @@ -349,12 +349,12 @@ func TestManager_CreateUser(T *testing.T) { password := "testpass123" err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Verify user was created exists, err := mgr.UserExists(ctx, username) - assert.NoError(t, err) - assert.True(t, exists) + test.NoError(t, err) + test.True(t, exists) }) T.Run("special characters in password", func(t *testing.T) { @@ -374,12 +374,12 @@ func TestManager_CreateUser(T *testing.T) { password := `pass'word"with"quotes` err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Verify user was created exists, err := mgr.UserExists(ctx, username) - assert.NoError(t, err) - assert.True(t, exists) + test.NoError(t, err) + test.True(t, exists) }) } @@ -404,21 +404,21 @@ func TestManager_DeleteUser(T *testing.T) { // Create user first err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Verify user exists exists, err := mgr.UserExists(ctx, username) - assert.NoError(t, err) - assert.True(t, exists) + test.NoError(t, err) + test.True(t, exists) // Delete user err = mgr.DeleteUser(ctx, username) - assert.NoError(t, err) + test.NoError(t, err) // Verify user no longer exists exists, err = mgr.UserExists(ctx, username) - assert.NoError(t, err) - assert.False(t, exists) + test.NoError(t, err) + test.False(t, exists) }) T.Run("delete non-existent user", func(t *testing.T) { @@ -438,7 +438,7 @@ func TestManager_DeleteUser(T *testing.T) { // Delete non-existent user should not error due to IF EXISTS err := mgr.DeleteUser(ctx, username) - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("special characters in username", func(t *testing.T) { @@ -459,16 +459,16 @@ func TestManager_DeleteUser(T *testing.T) { // Create user first err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Delete user err = mgr.DeleteUser(ctx, username) - assert.NoError(t, err) + test.NoError(t, err) // Verify user no longer exists exists, err := mgr.UserExists(ctx, username) - assert.NoError(t, err) - assert.False(t, exists) + test.NoError(t, err) + test.False(t, exists) }) } @@ -493,12 +493,12 @@ func TestManager_UserExists(T *testing.T) { // Create user err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Check if user exists exists, err := mgr.UserExists(ctx, username) - assert.NoError(t, err) - assert.True(t, exists) + test.NoError(t, err) + test.True(t, exists) }) T.Run("non-existing user", func(t *testing.T) { @@ -518,8 +518,8 @@ func TestManager_UserExists(T *testing.T) { // Check if user exists exists, err := mgr.UserExists(ctx, username) - assert.NoError(t, err) - assert.False(t, exists) + test.NoError(t, err) + test.False(t, exists) }) T.Run("special characters in username", func(t *testing.T) { @@ -540,12 +540,12 @@ func TestManager_UserExists(T *testing.T) { // Create user err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Check if user exists exists, err := mgr.UserExists(ctx, username) - assert.NoError(t, err) - assert.True(t, exists) + test.NoError(t, err) + test.True(t, exists) }) } @@ -571,16 +571,16 @@ func TestManager_CreateDatabase(T *testing.T) { // Create user first err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Create database err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Verify database was created exists, err := mgr.DatabaseExists(ctx, databaseName) - assert.NoError(t, err) - assert.True(t, exists) + test.NoError(t, err) + test.True(t, exists) }) T.Run("duplicate database", func(t *testing.T) { @@ -602,15 +602,15 @@ func TestManager_CreateDatabase(T *testing.T) { // Create user first err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Create database first time err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Try to create same database again err = mgr.CreateDatabase(ctx, databaseName, username) - assert.Error(t, err) + test.Error(t, err) }) T.Run("special characters in database name", func(t *testing.T) { @@ -632,16 +632,16 @@ func TestManager_CreateDatabase(T *testing.T) { // Create user first err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Create database err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Verify database was created exists, err := mgr.DatabaseExists(ctx, databaseName) - assert.NoError(t, err) - assert.True(t, exists) + test.NoError(t, err) + test.True(t, exists) }) T.Run("special characters in owner name", func(t *testing.T) { @@ -663,16 +663,16 @@ func TestManager_CreateDatabase(T *testing.T) { // Create user first err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Create database err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Verify database was created exists, err := mgr.DatabaseExists(ctx, databaseName) - assert.NoError(t, err) - assert.True(t, exists) + test.NoError(t, err) + test.True(t, exists) }) } @@ -698,25 +698,25 @@ func TestManager_DeleteDatabase(T *testing.T) { // Create user first err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Create database err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Verify database exists exists, err := mgr.DatabaseExists(ctx, databaseName) - assert.NoError(t, err) - assert.True(t, exists) + test.NoError(t, err) + test.True(t, exists) // Delete database err = mgr.DeleteDatabase(ctx, databaseName) - assert.NoError(t, err) + test.NoError(t, err) // Verify database no longer exists exists, err = mgr.DatabaseExists(ctx, databaseName) - assert.NoError(t, err) - assert.False(t, exists) + test.NoError(t, err) + test.False(t, exists) }) T.Run("delete non-existent database", func(t *testing.T) { @@ -736,7 +736,7 @@ func TestManager_DeleteDatabase(T *testing.T) { // Delete non-existent database should not error due to IF EXISTS err := mgr.DeleteDatabase(ctx, databaseName) - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("special characters in database name", func(t *testing.T) { @@ -758,20 +758,20 @@ func TestManager_DeleteDatabase(T *testing.T) { // Create user first err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Create database err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Delete database err = mgr.DeleteDatabase(ctx, databaseName) - assert.NoError(t, err) + test.NoError(t, err) // Verify database no longer exists exists, err := mgr.DatabaseExists(ctx, databaseName) - assert.NoError(t, err) - assert.False(t, exists) + test.NoError(t, err) + test.False(t, exists) }) } @@ -797,16 +797,16 @@ func TestManager_DatabaseExists(T *testing.T) { // Create user first err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Create database err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Check if database exists exists, err := mgr.DatabaseExists(ctx, databaseName) - assert.NoError(t, err) - assert.True(t, exists) + test.NoError(t, err) + test.True(t, exists) }) T.Run("non-existing database", func(t *testing.T) { @@ -826,8 +826,8 @@ func TestManager_DatabaseExists(T *testing.T) { // Check if database exists exists, err := mgr.DatabaseExists(ctx, databaseName) - assert.NoError(t, err) - assert.False(t, exists) + test.NoError(t, err) + test.False(t, exists) }) T.Run("special characters in database name", func(t *testing.T) { @@ -849,16 +849,16 @@ func TestManager_DatabaseExists(T *testing.T) { // Create user first err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Create database err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Check if database exists exists, err := mgr.DatabaseExists(ctx, databaseName) - assert.NoError(t, err) - assert.True(t, exists) + test.NoError(t, err) + test.True(t, exists) }) } @@ -884,15 +884,15 @@ func TestManager_UserCanAccessDatabase(T *testing.T) { // Create user and database err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Check access canAccess, err := mgr.UserCanAccessDatabase(ctx, username, databaseName) - assert.NoError(t, err) - assert.True(t, canAccess) + test.NoError(t, err) + test.True(t, canAccess) }) T.Run("user does not have access", func(t *testing.T) { @@ -915,22 +915,22 @@ func TestManager_UserCanAccessDatabase(T *testing.T) { // Create user and database with different owner err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) err = mgr.CreateUser(ctx, ownerUsername, password) - assert.NoError(t, err) + test.NoError(t, err) err = mgr.CreateDatabase(ctx, databaseName, ownerUsername) - assert.NoError(t, err) + test.NoError(t, err) // Grant CONNECT privilege to the user for the database _, err = adminDB.ExecContext(ctx, fmt.Sprintf("GRANT CONNECT ON DATABASE %s TO %s", quoteIdent(databaseName), quoteIdent(username))) - assert.NoError(t, err) + test.NoError(t, err) // Check access - user should have access now canAccess, err := mgr.UserCanAccessDatabase(ctx, username, databaseName) - assert.NoError(t, err) - assert.True(t, canAccess) + test.NoError(t, err) + test.True(t, canAccess) }) T.Run("non-existent user", func(t *testing.T) { @@ -951,8 +951,8 @@ func TestManager_UserCanAccessDatabase(T *testing.T) { // Check access for non-existent user - should return error canAccess, err := mgr.UserCanAccessDatabase(ctx, username, databaseName) - assert.Error(t, err) - assert.False(t, canAccess) + test.Error(t, err) + test.False(t, canAccess) }) T.Run("non-existent database", func(t *testing.T) { @@ -974,12 +974,12 @@ func TestManager_UserCanAccessDatabase(T *testing.T) { // Create user err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Check access to non-existent database - should return error canAccess, err := mgr.UserCanAccessDatabase(ctx, username, databaseName) - assert.Error(t, err) - assert.False(t, canAccess) + test.Error(t, err) + test.False(t, canAccess) }) T.Run("special characters in usernames and database names", func(t *testing.T) { @@ -1001,15 +1001,15 @@ func TestManager_UserCanAccessDatabase(T *testing.T) { // Create user and database err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Check access canAccess, err := mgr.UserCanAccessDatabase(ctx, username, databaseName) - assert.NoError(t, err) - assert.True(t, canAccess) + test.NoError(t, err) + test.True(t, canAccess) }) } @@ -1037,18 +1037,18 @@ func TestManager_GrantUserAccessToTable(T *testing.T) { // Create user and database err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Create a test table _, err = adminDB.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s.%s (id SERIAL PRIMARY KEY, name TEXT)", quoteIdent(schema), quoteIdent(table))) - assert.NoError(t, err) + test.NoError(t, err) // Grant access err = mgr.GrantUserAccessToTable(ctx, username, schema, table, "SELECT") - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("all valid privileges", func(t *testing.T) { @@ -1072,19 +1072,19 @@ func TestManager_GrantUserAccessToTable(T *testing.T) { // Create user and database err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Create a test table _, err = adminDB.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s.%s (id SERIAL PRIMARY KEY, name TEXT)", quoteIdent(schema), quoteIdent(table))) - assert.NoError(t, err) + test.NoError(t, err) privileges := []string{"SELECT", "INSERT", "UPDATE", "DELETE", "TRUNCATE", "REFERENCES", "TRIGGER"} for _, privilege := range privileges { err = mgr.GrantUserAccessToTable(ctx, username, schema, table, privilege) - assert.NoError(t, err, "Failed to grant %s privilege", privilege) + test.NoError(t, err, test.Sprintf("Failed to grant %s privilege", privilege)) } }) @@ -1109,19 +1109,19 @@ func TestManager_GrantUserAccessToTable(T *testing.T) { // Create user and database err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Create a test table _, err = adminDB.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s.%s (id SERIAL PRIMARY KEY, name TEXT)", quoteIdent(schema), quoteIdent(table))) - assert.NoError(t, err) + test.NoError(t, err) // Try to grant invalid privilege err = mgr.GrantUserAccessToTable(ctx, username, schema, table, "INVALID_PRIVILEGE") - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid privilege") + test.Error(t, err) + test.StrContains(t, err.Error(), "invalid privilege") }) T.Run("case sensitive privilege", func(t *testing.T) { @@ -1145,19 +1145,19 @@ func TestManager_GrantUserAccessToTable(T *testing.T) { // Create user and database err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Create a test table _, err = adminDB.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s.%s (id SERIAL PRIMARY KEY, name TEXT)", quoteIdent(schema), quoteIdent(table))) - assert.NoError(t, err) + test.NoError(t, err) // Try to grant lowercase privilege (should fail) err = mgr.GrantUserAccessToTable(ctx, username, schema, table, "select") - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid privilege") + test.Error(t, err) + test.StrContains(t, err.Error(), "invalid privilege") }) T.Run("special characters in identifiers", func(t *testing.T) { @@ -1181,21 +1181,21 @@ func TestManager_GrantUserAccessToTable(T *testing.T) { // Create user and database err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Create schema and table _, err = adminDB.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA %s", quoteIdent(schema))) - assert.NoError(t, err) + test.NoError(t, err) _, err = adminDB.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s.%s (id SERIAL PRIMARY KEY, name TEXT)", quoteIdent(schema), quoteIdent(table))) - assert.NoError(t, err) + test.NoError(t, err) // Grant access err = mgr.GrantUserAccessToTable(ctx, username, schema, table, "SELECT") - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("non-existent user", func(t *testing.T) { @@ -1217,11 +1217,11 @@ func TestManager_GrantUserAccessToTable(T *testing.T) { // Create a test table _, err := adminDB.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s.%s (id SERIAL PRIMARY KEY, name TEXT)", quoteIdent(schema), quoteIdent(table))) - assert.NoError(t, err) + test.NoError(t, err) // Try to grant access to non-existent user err = mgr.GrantUserAccessToTable(ctx, username, schema, table, "SELECT") - assert.Error(t, err) + test.Error(t, err) }) T.Run("non-existent table", func(t *testing.T) { @@ -1244,11 +1244,11 @@ func TestManager_GrantUserAccessToTable(T *testing.T) { // Create user err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Try to grant access to non-existent table err = mgr.GrantUserAccessToTable(ctx, username, schema, table, "SELECT") - assert.Error(t, err) + test.Error(t, err) }) } @@ -1274,12 +1274,12 @@ func TestManager_SQLInjectionProtection(T *testing.T) { // This should not cause SQL injection due to proper quoting err := mgr.CreateUser(ctx, maliciousUsername, password) - assert.NoError(t, err) + test.NoError(t, err) // Verify the user was created with the literal name (not executed as SQL) exists, err := mgr.UserExists(ctx, maliciousUsername) - assert.NoError(t, err) - assert.True(t, exists) + test.NoError(t, err) + test.True(t, exists) }) T.Run("password injection in CreateUser", func(t *testing.T) { @@ -1301,12 +1301,12 @@ func TestManager_SQLInjectionProtection(T *testing.T) { // This should not cause SQL injection due to proper quoting err := mgr.CreateUser(ctx, username, maliciousPassword) - assert.NoError(t, err) + test.NoError(t, err) // Verify the user was created exists, err := mgr.UserExists(ctx, username) - assert.NoError(t, err) - assert.True(t, exists) + test.NoError(t, err) + test.True(t, exists) }) T.Run("database name injection in CreateDatabase", func(t *testing.T) { @@ -1329,16 +1329,16 @@ func TestManager_SQLInjectionProtection(T *testing.T) { // Create user first err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // This should not cause SQL injection due to proper quoting err = mgr.CreateDatabase(ctx, maliciousDbName, username) - assert.NoError(t, err) + test.NoError(t, err) // Verify the database was created with the literal name exists, err := mgr.DatabaseExists(ctx, maliciousDbName) - assert.NoError(t, err) - assert.True(t, exists) + test.NoError(t, err) + test.True(t, exists) }) T.Run("table name injection in GrantUserAccessToTable", func(t *testing.T) { @@ -1363,18 +1363,18 @@ func TestManager_SQLInjectionProtection(T *testing.T) { // Create user and database err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Create a test table with the malicious name _, err = adminDB.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s.%s (id SERIAL PRIMARY KEY, name TEXT)", quoteIdent(schema), quoteIdent(maliciousTable))) - assert.NoError(t, err) + test.NoError(t, err) // This should not cause SQL injection due to proper quoting err = mgr.GrantUserAccessToTable(ctx, username, schema, maliciousTable, "SELECT") - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("schema name injection in GrantUserAccessToTable", func(t *testing.T) { @@ -1399,21 +1399,21 @@ func TestManager_SQLInjectionProtection(T *testing.T) { // Create user and database err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Create schema and table with malicious names _, err = adminDB.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA %s", quoteIdent(maliciousSchema))) - assert.NoError(t, err) + test.NoError(t, err) _, err = adminDB.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s.%s (id SERIAL PRIMARY KEY, name TEXT)", quoteIdent(maliciousSchema), quoteIdent(table))) - assert.NoError(t, err) + test.NoError(t, err) // This should not cause SQL injection due to proper quoting err = mgr.GrantUserAccessToTable(ctx, username, maliciousSchema, table, "SELECT") - assert.NoError(t, err) + test.NoError(t, err) }) } @@ -1442,11 +1442,11 @@ func TestManager_ErrorCases(T *testing.T) { // Operations with cancelled context should fail err := mgr.CreateUser(cancelledCtx, username, password) - assert.Error(t, err) + test.Error(t, err) exists, err := mgr.UserExists(cancelledCtx, username) - assert.Error(t, err) - assert.False(t, exists) + test.Error(t, err) + test.False(t, exists) }) T.Run("context timeout", func(t *testing.T) { @@ -1474,7 +1474,7 @@ func TestManager_ErrorCases(T *testing.T) { // Operations with timed out context should fail err := mgr.CreateUser(timeoutCtx, username, password) - assert.Error(t, err) + test.Error(t, err) }) T.Run("empty username", func(t *testing.T) { @@ -1495,7 +1495,7 @@ func TestManager_ErrorCases(T *testing.T) { // Empty username should fail err := mgr.CreateUser(ctx, username, password) - assert.Error(t, err) + test.Error(t, err) }) T.Run("empty database name", func(t *testing.T) { @@ -1517,11 +1517,11 @@ func TestManager_ErrorCases(T *testing.T) { // Create user first err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) // Empty database name should fail err = mgr.CreateDatabase(ctx, databaseName, username) - assert.Error(t, err) + test.Error(t, err) }) T.Run("empty table name in GrantUserAccessToTable", func(t *testing.T) { @@ -1545,14 +1545,14 @@ func TestManager_ErrorCases(T *testing.T) { // Create user and database err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Empty table name should fail err = mgr.GrantUserAccessToTable(ctx, username, schema, table, "SELECT") - assert.Error(t, err) + test.Error(t, err) }) T.Run("empty schema name in GrantUserAccessToTable", func(t *testing.T) { @@ -1576,14 +1576,14 @@ func TestManager_ErrorCases(T *testing.T) { // Create user and database err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Empty schema name should fail err = mgr.GrantUserAccessToTable(ctx, username, schema, table, "SELECT") - assert.Error(t, err) + test.Error(t, err) }) T.Run("empty privilege in GrantUserAccessToTable", func(t *testing.T) { @@ -1608,19 +1608,19 @@ func TestManager_ErrorCases(T *testing.T) { // Create user and database err := mgr.CreateUser(ctx, username, password) - assert.NoError(t, err) + test.NoError(t, err) err = mgr.CreateDatabase(ctx, databaseName, username) - assert.NoError(t, err) + test.NoError(t, err) // Create a test table _, err = adminDB.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s.%s (id SERIAL PRIMARY KEY, name TEXT)", quoteIdent(schema), quoteIdent(table))) - assert.NoError(t, err) + test.NoError(t, err) // Empty privilege should fail err = mgr.GrantUserAccessToTable(ctx, username, schema, table, privilege) - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid privilege") + test.Error(t, err) + test.StrContains(t, err.Error(), "invalid privilege") }) } @@ -1645,18 +1645,18 @@ func TestNewManager(T *testing.T) { password := "hunter2" databaseName := "records" - assert.NoError(t, mgr.CreateUser(ctx, username, password)) - assert.NoError(t, mgr.CreateDatabase(ctx, databaseName, username)) + test.NoError(t, mgr.CreateUser(ctx, username, password)) + test.NoError(t, mgr.CreateDatabase(ctx, databaseName, username)) canAccess, err := mgr.UserCanAccessDatabase(ctx, username, databaseName) - assert.NoError(t, err) - assert.True(t, canAccess) + test.NoError(t, err) + test.True(t, canAccess) db2, err := sql.Open("pgx", buildConnectionString(t, container, databaseName, username, password)) - require.NoError(t, err) + must.NoError(t, err) var dbName string db2.QueryRowContext(ctx, `SELECT current_database()`).Scan(&dbName) - assert.Equal(t, databaseName, dbName) + test.EqOp(t, databaseName, dbName) }) } diff --git a/database/sqlite/do_test.go b/database/sqlite/do_test.go index 93fbfc2..41d67fe 100644 --- a/database/sqlite/do_test.go +++ b/database/sqlite/do_test.go @@ -10,8 +10,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterDatabaseClient(T *testing.T) { @@ -33,7 +33,7 @@ func TestRegisterDatabaseClient(T *testing.T) { RegisterDatabaseClient(i) client, err := do.Invoke[database.Client](i) - require.NoError(t, err) - assert.NotNil(t, client) + must.NoError(t, err) + test.NotNil(t, client) }) } diff --git a/database/sqlite/sqlite_test.go b/database/sqlite/sqlite_test.go index dac4e86..75ab993 100644 --- a/database/sqlite/sqlite_test.go +++ b/database/sqlite/sqlite_test.go @@ -12,8 +12,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) // testClientConfig is a test implementation of database.ClientConfig. @@ -65,7 +65,7 @@ func buildTestClient(t *testing.T) (*Client, sqlmock.Sqlmock) { t.Helper() fakeDB, sqlMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) - require.NoError(t, err) + must.NoError(t, err) c := &Client{ readDB: fakeDB, @@ -97,7 +97,7 @@ func TestQuerier_IsReady(T *testing.T) { // same DB for read/write, so only one ping db.ExpectPing().WillDelayFor(0) - assert.True(t, c.IsReady(ctx)) + test.True(t, c.IsReady(ctx)) }) T.Run("with read DB ping error", func(t *testing.T) { @@ -109,7 +109,7 @@ func TestQuerier_IsReady(T *testing.T) { db.ExpectPing().WillReturnError(errors.New("blah")) - assert.False(t, c.IsReady(ctx)) + test.False(t, c.IsReady(ctx)) }) T.Run("with write DB ping error", func(t *testing.T) { @@ -118,10 +118,10 @@ func TestQuerier_IsReady(T *testing.T) { ctx := t.Context() readDB, readMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) - require.NoError(t, err) + must.NoError(t, err) writeDB, writeMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) - require.NoError(t, err) + must.NoError(t, err) c := &Client{ readDB: readDB, @@ -134,7 +134,7 @@ func TestQuerier_IsReady(T *testing.T) { readMock.ExpectPing().WillDelayFor(0) writeMock.ExpectPing().WillReturnError(errors.New("blah")) - assert.False(t, c.IsReady(ctx)) + test.False(t, c.IsReady(ctx)) }) T.Run("exhausting all available queries", func(t *testing.T) { @@ -148,7 +148,7 @@ func TestQuerier_IsReady(T *testing.T) { db.ExpectPing().WillReturnError(errors.New("blah")) - assert.False(t, c.IsReady(ctx)) + test.False(t, c.IsReady(ctx)) }) } @@ -166,8 +166,8 @@ func TestProvideDatabaseClient(T *testing.T) { } actual, err := ProvideDatabaseClient(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), exampleConfig, nil) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with no connection strings", func(t *testing.T) { @@ -178,8 +178,8 @@ func TestProvideDatabaseClient(T *testing.T) { exampleConfig := &testClientConfig{} actual, err := ProvideDatabaseClient(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), exampleConfig, nil) - assert.Nil(t, actual) - assert.Error(t, err) + test.Nil(t, actual) + test.Error(t, err) }) T.Run("with only read connection string", func(t *testing.T) { @@ -193,8 +193,8 @@ func TestProvideDatabaseClient(T *testing.T) { } actual, err := ProvideDatabaseClient(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), exampleConfig, nil) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with only write connection string", func(t *testing.T) { @@ -208,8 +208,8 @@ func TestProvideDatabaseClient(T *testing.T) { } actual, err := ProvideDatabaseClient(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), exampleConfig, nil) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with metrics provider", func(t *testing.T) { @@ -223,8 +223,8 @@ func TestProvideDatabaseClient(T *testing.T) { } actual, err := ProvideDatabaseClient(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), exampleConfig, metrics.NewNoopMetricsProvider()) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with metrics provider and single connection", func(t *testing.T) { @@ -238,8 +238,8 @@ func TestProvideDatabaseClient(T *testing.T) { } actual, err := ProvideDatabaseClient(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), exampleConfig, metrics.NewNoopMetricsProvider()) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) } @@ -249,7 +249,7 @@ func TestDefaultTimeFunc(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotZero(t, defaultTimeFunc()) + test.False(t, defaultTimeFunc().IsZero()) }) } @@ -261,7 +261,7 @@ func TestQuerier_currentTime(T *testing.T) { c, _ := buildTestClient(t) - assert.NotEmpty(t, c.CurrentTime()) + test.False(t, c.CurrentTime().IsZero()) }) T.Run("handles nil", func(t *testing.T) { @@ -269,7 +269,7 @@ func TestQuerier_currentTime(T *testing.T) { var c *Client - assert.NotEmpty(t, c.CurrentTime()) + test.False(t, c.CurrentTime().IsZero()) }) } @@ -286,7 +286,7 @@ func TestQuerier_rollbackTransaction(T *testing.T) { db.ExpectRollback().WillReturnError(errors.New("blah")) tx, err := c.writeDB.BeginTx(ctx, nil) - require.NoError(t, err) + must.NoError(t, err) c.RollbackTransaction(ctx, tx) }) @@ -301,7 +301,7 @@ func TestQuerier_rollbackTransaction(T *testing.T) { db.ExpectRollback() tx, err := c.writeDB.BeginTx(ctx, nil) - require.NoError(t, err) + must.NoError(t, err) c.RollbackTransaction(ctx, tx) }) @@ -315,7 +315,7 @@ func TestClient_ReadDB(T *testing.T) { c, _ := buildTestClient(t) - assert.NotNil(t, c.ReadDB()) + test.NotNil(t, c.ReadDB()) }) } @@ -327,7 +327,7 @@ func TestClient_WriteDB(T *testing.T) { c, _ := buildTestClient(t) - assert.NotNil(t, c.WriteDB()) + test.NotNil(t, c.WriteDB()) }) } @@ -341,17 +341,17 @@ func TestClient_Close(T *testing.T) { db.ExpectClose() - assert.NoError(t, c.Close()) + test.NoError(t, c.Close()) }) T.Run("with separate read and write DBs", func(t *testing.T) { t.Parallel() readDB, readMock, err := sqlmock.New() - require.NoError(t, err) + must.NoError(t, err) writeDB, writeMock, err := sqlmock.New() - require.NoError(t, err) + must.NoError(t, err) c := &Client{ readDB: readDB, @@ -363,7 +363,7 @@ func TestClient_Close(T *testing.T) { readMock.ExpectClose() writeMock.ExpectClose() - assert.NoError(t, c.Close()) + test.NoError(t, c.Close()) }) T.Run("with read close error", func(t *testing.T) { @@ -373,17 +373,17 @@ func TestClient_Close(T *testing.T) { db.ExpectClose().WillReturnError(errors.New("blah")) - assert.Error(t, c.Close()) + test.Error(t, c.Close()) }) T.Run("with write close error", func(t *testing.T) { t.Parallel() readDB, readMock, err := sqlmock.New() - require.NoError(t, err) + must.NoError(t, err) writeDB, writeMock, err := sqlmock.New() - require.NoError(t, err) + must.NoError(t, err) c := &Client{ readDB: readDB, @@ -395,6 +395,6 @@ func TestClient_Close(T *testing.T) { readMock.ExpectClose() writeMock.ExpectClose().WillReturnError(errors.New("blah")) - assert.Error(t, c.Close()) + test.Error(t, c.Close()) }) } diff --git a/database/sqlite/tableaccess/access_manager_test.go b/database/sqlite/tableaccess/access_manager_test.go index f70f59b..f8ccb83 100644 --- a/database/sqlite/tableaccess/access_manager_test.go +++ b/database/sqlite/tableaccess/access_manager_test.go @@ -3,7 +3,7 @@ package tableaccess import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestManager_CreateUser(T *testing.T) { @@ -14,7 +14,7 @@ func TestManager_CreateUser(T *testing.T) { m := NewManager() err := m.CreateUser(t.Context(), "user", "pass") - assert.ErrorIs(t, err, ErrNotSupported) + test.ErrorIs(t, err, ErrNotSupported) }) } @@ -26,7 +26,7 @@ func TestManager_DeleteUser(T *testing.T) { m := NewManager() err := m.DeleteUser(t.Context(), "user") - assert.ErrorIs(t, err, ErrNotSupported) + test.ErrorIs(t, err, ErrNotSupported) }) } @@ -38,7 +38,7 @@ func TestManager_CreateDatabase(T *testing.T) { m := NewManager() err := m.CreateDatabase(t.Context(), "db", "owner") - assert.ErrorIs(t, err, ErrNotSupported) + test.ErrorIs(t, err, ErrNotSupported) }) } @@ -50,7 +50,7 @@ func TestManager_DeleteDatabase(T *testing.T) { m := NewManager() err := m.DeleteDatabase(t.Context(), "db") - assert.ErrorIs(t, err, ErrNotSupported) + test.ErrorIs(t, err, ErrNotSupported) }) } @@ -62,8 +62,8 @@ func TestManager_UserExists(T *testing.T) { m := NewManager() exists, err := m.UserExists(t.Context(), "user") - assert.False(t, exists) - assert.ErrorIs(t, err, ErrNotSupported) + test.False(t, exists) + test.ErrorIs(t, err, ErrNotSupported) }) } @@ -75,8 +75,8 @@ func TestManager_DatabaseExists(T *testing.T) { m := NewManager() exists, err := m.DatabaseExists(t.Context(), "db") - assert.False(t, exists) - assert.ErrorIs(t, err, ErrNotSupported) + test.False(t, exists) + test.ErrorIs(t, err, ErrNotSupported) }) } @@ -88,7 +88,7 @@ func TestManager_GrantUserAccessToTable(T *testing.T) { m := NewManager() err := m.GrantUserAccessToTable(t.Context(), "user", "schema", "table", "SELECT") - assert.ErrorIs(t, err, ErrNotSupported) + test.ErrorIs(t, err, ErrNotSupported) }) } @@ -100,7 +100,7 @@ func TestManager_UserCanAccessDatabase(T *testing.T) { m := NewManager() canAccess, err := m.UserCanAccessDatabase(t.Context(), "user", "db") - assert.False(t, canAccess) - assert.ErrorIs(t, err, ErrNotSupported) + test.False(t, canAccess) + test.ErrorIs(t, err, ErrNotSupported) }) } diff --git a/distributedlock/config/config_test.go b/distributedlock/config/config_test.go index 856df70..6808f1f 100644 --- a/distributedlock/config/config_test.go +++ b/distributedlock/config/config_test.go @@ -18,8 +18,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -47,7 +46,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { KeyPrefix: "lock:", }, } - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("postgres provider", func(t *testing.T) { @@ -56,43 +55,43 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: PostgresProvider, Postgres: &pglock.Config{}, } - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("memory provider", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: MemoryProvider} - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("noop provider", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: NoopProvider} - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("redis without config", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: RedisProvider} - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("postgres without config", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: PostgresProvider} - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("invalid provider", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: "made-up"} - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("empty provider is valid (noop)", func(t *testing.T) { t.Parallel() cfg := &Config{} - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) } @@ -109,7 +108,7 @@ func TestProvideLocker(T *testing.T) { metrics.NewNoopMetricsProvider(), nil, ) - assert.ErrorIs(t, err, distributedlock.ErrNilConfig) + test.ErrorIs(t, err, distributedlock.ErrNilConfig) }) T.Run("memory provider returns a working locker", func(t *testing.T) { @@ -122,11 +121,11 @@ func TestProvideLocker(T *testing.T) { metrics.NewNoopMetricsProvider(), nil, ) - require.NoError(t, err) - require.NotNil(t, l) + must.NoError(t, err) + must.NotNil(t, l) lock, err := l.Acquire(t.Context(), "k", time.Second) - require.NoError(t, err) - require.NoError(t, lock.Release(t.Context())) + must.NoError(t, err) + must.NoError(t, lock.Release(t.Context())) }) T.Run("noop provider", func(t *testing.T) { @@ -139,8 +138,8 @@ func TestProvideLocker(T *testing.T) { metrics.NewNoopMetricsProvider(), nil, ) - require.NoError(t, err) - require.NotNil(t, l) + must.NoError(t, err) + must.NotNil(t, l) }) T.Run("unknown provider returns noop", func(t *testing.T) { @@ -153,8 +152,8 @@ func TestProvideLocker(T *testing.T) { metrics.NewNoopMetricsProvider(), nil, ) - require.NoError(t, err) - require.NotNil(t, l) + must.NoError(t, err) + must.NotNil(t, l) }) T.Run("empty provider returns noop", func(t *testing.T) { @@ -167,8 +166,8 @@ func TestProvideLocker(T *testing.T) { metrics.NewNoopMetricsProvider(), nil, ) - require.NoError(t, err) - require.NotNil(t, l) + must.NoError(t, err) + must.NotNil(t, l) }) T.Run("provider with whitespace returns noop", func(t *testing.T) { @@ -181,8 +180,8 @@ func TestProvideLocker(T *testing.T) { metrics.NewNoopMetricsProvider(), nil, ) - require.NoError(t, err) - require.NotNil(t, l) + must.NoError(t, err) + must.NotNil(t, l) }) T.Run("redis provider", func(t *testing.T) { @@ -201,8 +200,8 @@ func TestProvideLocker(T *testing.T) { metrics.NewNoopMetricsProvider(), nil, ) - require.NoError(t, err) - require.NotNil(t, l) + must.NoError(t, err) + must.NotNil(t, l) }) T.Run("postgres provider", func(t *testing.T) { @@ -218,8 +217,8 @@ func TestProvideLocker(T *testing.T) { metrics.NewNoopMetricsProvider(), &stubDBClient{}, ) - require.NoError(t, err) - require.NotNil(t, l) + must.NoError(t, err) + must.NotNil(t, l) }) T.Run("circuit breaker init failure", func(t *testing.T) { @@ -248,9 +247,9 @@ func TestProvideLocker(T *testing.T) { mp, nil, ) - require.Error(t, err) - assert.Nil(t, l) - assert.Contains(t, err.Error(), "distributedlock circuit breaker") + must.Error(t, err) + test.Nil(t, l) + test.StrContains(t, err.Error(), "distributedlock circuit breaker") test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) diff --git a/distributedlock/memory/memory_test.go b/distributedlock/memory/memory_test.go index ee3ca8b..88a66a6 100644 --- a/distributedlock/memory/memory_test.go +++ b/distributedlock/memory/memory_test.go @@ -8,15 +8,15 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/distributedlock" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func newTestLocker(t *testing.T) distributedlock.Locker { t.Helper() l, err := NewLocker(nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, l) + must.NoError(t, err) + must.NotNil(t, l) return l } @@ -26,8 +26,8 @@ func TestNewLocker(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() l, err := NewLocker(nil, nil, nil) - require.NoError(t, err) - assert.NotNil(t, l) + must.NoError(t, err) + test.NotNil(t, l) }) } @@ -38,50 +38,50 @@ func TestLocker_Acquire(T *testing.T) { t.Parallel() l := newTestLocker(t) lock, err := l.Acquire(t.Context(), "k", time.Second) - require.NoError(t, err) - require.NotNil(t, lock) - assert.Equal(t, "k", lock.Key()) - assert.Equal(t, time.Second, lock.TTL()) + must.NoError(t, err) + must.NotNil(t, lock) + test.EqOp(t, "k", lock.Key()) + test.EqOp(t, time.Second, lock.TTL()) }) T.Run("contended", func(t *testing.T) { t.Parallel() l := newTestLocker(t) _, err := l.Acquire(t.Context(), "shared", time.Minute) - require.NoError(t, err) + must.NoError(t, err) _, err = l.Acquire(t.Context(), "shared", time.Minute) - require.ErrorIs(t, err, distributedlock.ErrLockNotAcquired) + must.ErrorIs(t, err, distributedlock.ErrLockNotAcquired) }) T.Run("re-acquire after expiry", func(t *testing.T) { t.Parallel() l := newTestLocker(t) _, err := l.Acquire(t.Context(), "exp", 50*time.Millisecond) - require.NoError(t, err) + must.NoError(t, err) time.Sleep(80 * time.Millisecond) _, err = l.Acquire(t.Context(), "exp", time.Second) - require.NoError(t, err) + must.NoError(t, err) }) T.Run("rejects empty key", func(t *testing.T) { t.Parallel() l := newTestLocker(t) _, err := l.Acquire(t.Context(), "", time.Second) - require.ErrorIs(t, err, distributedlock.ErrEmptyKey) + must.ErrorIs(t, err, distributedlock.ErrEmptyKey) }) T.Run("rejects zero TTL", func(t *testing.T) { t.Parallel() l := newTestLocker(t) _, err := l.Acquire(t.Context(), "k", 0) - require.ErrorIs(t, err, distributedlock.ErrInvalidTTL) + must.ErrorIs(t, err, distributedlock.ErrInvalidTTL) }) T.Run("rejects negative TTL", func(t *testing.T) { t.Parallel() l := newTestLocker(t) _, err := l.Acquire(t.Context(), "k", -time.Second) - require.ErrorIs(t, err, distributedlock.ErrInvalidTTL) + must.ErrorIs(t, err, distributedlock.ErrInvalidTTL) }) } @@ -92,36 +92,36 @@ func TestLocker_Release(T *testing.T) { t.Parallel() l := newTestLocker(t) lock, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.NoError(t, lock.Release(t.Context())) + must.NoError(t, err) + must.NoError(t, lock.Release(t.Context())) }) T.Run("released lock can be reacquired", func(t *testing.T) { t.Parallel() l := newTestLocker(t) first, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.NoError(t, first.Release(t.Context())) + must.NoError(t, err) + must.NoError(t, first.Release(t.Context())) _, err = l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) + must.NoError(t, err) }) T.Run("double release returns ErrLockNotHeld", func(t *testing.T) { t.Parallel() l := newTestLocker(t) lock, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.NoError(t, lock.Release(t.Context())) - require.ErrorIs(t, lock.Release(t.Context()), distributedlock.ErrLockNotHeld) + must.NoError(t, err) + must.NoError(t, lock.Release(t.Context())) + must.ErrorIs(t, lock.Release(t.Context()), distributedlock.ErrLockNotHeld) }) T.Run("release after expiration returns ErrLockNotHeld", func(t *testing.T) { t.Parallel() l := newTestLocker(t) lock, err := l.Acquire(t.Context(), "k", 50*time.Millisecond) - require.NoError(t, err) + must.NoError(t, err) time.Sleep(80 * time.Millisecond) - require.ErrorIs(t, lock.Release(t.Context()), distributedlock.ErrLockNotHeld) + must.ErrorIs(t, lock.Release(t.Context()), distributedlock.ErrLockNotHeld) }) } @@ -132,29 +132,29 @@ func TestLocker_Refresh(T *testing.T) { t.Parallel() l := newTestLocker(t) lock, err := l.Acquire(t.Context(), "k", 50*time.Millisecond) - require.NoError(t, err) - require.NoError(t, lock.Refresh(t.Context(), 5*time.Second)) + must.NoError(t, err) + must.NoError(t, lock.Refresh(t.Context(), 5*time.Second)) // Even after the original TTL elapses, the lock is still held. time.Sleep(80 * time.Millisecond) _, err = l.Acquire(t.Context(), "k", time.Second) - require.ErrorIs(t, err, distributedlock.ErrLockNotAcquired) + must.ErrorIs(t, err, distributedlock.ErrLockNotAcquired) }) T.Run("refresh after expiration returns ErrLockNotHeld", func(t *testing.T) { t.Parallel() l := newTestLocker(t) lock, err := l.Acquire(t.Context(), "k", 50*time.Millisecond) - require.NoError(t, err) + must.NoError(t, err) time.Sleep(80 * time.Millisecond) - require.ErrorIs(t, lock.Refresh(t.Context(), time.Second), distributedlock.ErrLockNotHeld) + must.ErrorIs(t, lock.Refresh(t.Context(), time.Second), distributedlock.ErrLockNotHeld) }) T.Run("rejects invalid TTL", func(t *testing.T) { t.Parallel() l := newTestLocker(t) lock, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.ErrorIs(t, lock.Refresh(t.Context(), 0), distributedlock.ErrInvalidTTL) + must.NoError(t, err) + must.ErrorIs(t, lock.Refresh(t.Context(), 0), distributedlock.ErrInvalidTTL) }) } @@ -163,7 +163,7 @@ func TestLocker_Ping(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - require.NoError(t, newTestLocker(t).Ping(t.Context())) + must.NoError(t, newTestLocker(t).Ping(t.Context())) }) } @@ -174,13 +174,13 @@ func TestLocker_Close(T *testing.T) { t.Parallel() l := newTestLocker(t) lock, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.NoError(t, l.Close()) + must.NoError(t, err) + must.NoError(t, l.Close()) // The previous handle now sees the lock as not-held. - require.ErrorIs(t, lock.Release(t.Context()), distributedlock.ErrLockNotHeld) + must.ErrorIs(t, lock.Release(t.Context()), distributedlock.ErrLockNotHeld) // And the key is acquirable again. _, err = l.Acquire(t.Context(), "k", time.Second) - require.NoError(t, err) + must.NoError(t, err) }) } @@ -205,6 +205,6 @@ func TestLocker_Concurrency(T *testing.T) { } wg.Wait() - assert.Equal(t, int64(1), winners.Load()) + test.EqOp(t, int64(1), winners.Load()) }) } diff --git a/distributedlock/noop/noop_test.go b/distributedlock/noop/noop_test.go index 1f9711a..262a3c8 100644 --- a/distributedlock/noop/noop_test.go +++ b/distributedlock/noop/noop_test.go @@ -4,8 +4,8 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestNewLocker(T *testing.T) { @@ -13,7 +13,7 @@ func TestNewLocker(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewLocker()) + test.NotNil(t, NewLocker()) }) } @@ -24,19 +24,19 @@ func TestLocker_Acquire(T *testing.T) { t.Parallel() l := NewLocker() lock, err := l.Acquire(t.Context(), "k", time.Second) - require.NoError(t, err) - require.NotNil(t, lock) - assert.Equal(t, "k", lock.Key()) - assert.Equal(t, time.Second, lock.TTL()) + must.NoError(t, err) + must.NotNil(t, lock) + test.EqOp(t, "k", lock.Key()) + test.EqOp(t, time.Second, lock.TTL()) }) T.Run("contended acquires both succeed", func(t *testing.T) { t.Parallel() l := NewLocker() _, err := l.Acquire(t.Context(), "shared", time.Second) - require.NoError(t, err) + must.NoError(t, err) _, err = l.Acquire(t.Context(), "shared", time.Second) - require.NoError(t, err) + must.NoError(t, err) }) } @@ -45,7 +45,7 @@ func TestLocker_Ping(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - require.NoError(t, NewLocker().Ping(t.Context())) + must.NoError(t, NewLocker().Ping(t.Context())) }) } @@ -54,7 +54,7 @@ func TestLocker_Close(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - require.NoError(t, NewLocker().Close()) + must.NoError(t, NewLocker().Close()) }) } @@ -64,16 +64,16 @@ func TestLock_ReleaseAndRefresh(T *testing.T) { T.Run("release is a no-op", func(t *testing.T) { t.Parallel() l, err := NewLocker().Acquire(t.Context(), "k", time.Second) - require.NoError(t, err) - require.NoError(t, l.Release(t.Context())) - require.NoError(t, l.Release(t.Context())) + must.NoError(t, err) + must.NoError(t, l.Release(t.Context())) + must.NoError(t, l.Release(t.Context())) }) T.Run("refresh updates ttl", func(t *testing.T) { t.Parallel() l, err := NewLocker().Acquire(t.Context(), "k", time.Second) - require.NoError(t, err) - require.NoError(t, l.Refresh(t.Context(), 5*time.Second)) - assert.Equal(t, 5*time.Second, l.TTL()) + must.NoError(t, err) + must.NoError(t, l.Refresh(t.Context(), 5*time.Second)) + test.EqOp(t, 5*time.Second, l.TTL()) }) } diff --git a/distributedlock/postgres/config_test.go b/distributedlock/postgres/config_test.go index 0e06cca..61e6ba1 100644 --- a/distributedlock/postgres/config_test.go +++ b/distributedlock/postgres/config_test.go @@ -3,7 +3,7 @@ package postgres import ( "testing" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -12,12 +12,12 @@ func TestConfig_ValidateWithContext(T *testing.T) { T.Run("happy path zero namespace", func(t *testing.T) { t.Parallel() cfg := &Config{} - require.NoError(t, cfg.ValidateWithContext(t.Context())) + must.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("happy path explicit namespace", func(t *testing.T) { t.Parallel() cfg := &Config{Namespace: 42} - require.NoError(t, cfg.ValidateWithContext(t.Context())) + must.NoError(t, cfg.ValidateWithContext(t.Context())) }) } diff --git a/distributedlock/postgres/postgres_test.go b/distributedlock/postgres/postgres_test.go index 31863d9..1d90a3e 100644 --- a/distributedlock/postgres/postgres_test.go +++ b/distributedlock/postgres/postgres_test.go @@ -19,8 +19,8 @@ import ( "github.com/DATA-DOG/go-sqlmock" _ "github.com/jackc/pgx/v5/stdlib" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "github.com/testcontainers/testcontainers-go" postgrescontainer "github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/wait" @@ -59,17 +59,17 @@ func buildContainerBackedPostgres(t *testing.T) (client *testDBClient, shutdown postgrescontainer.WithPassword("locktest"), testcontainers.WithWaitStrategyAndDeadline(2*time.Minute, wait.ForLog("database system is ready to accept connections").WithOccurrence(2)), ) - require.NoError(t, err) - require.NotNil(t, container) + must.NoError(t, err) + must.NotNil(t, container) connStr, err := container.ConnectionString(ctx, "sslmode=disable") - require.NoError(t, err) + must.NoError(t, err) db, err := sql.Open("pgx", connStr) - require.NoError(t, err) + must.NoError(t, err) // Allow plenty of conns so the parallel subtests don't starve. db.SetMaxOpenConns(64) - require.NoError(t, db.PingContext(ctx)) + must.NoError(t, db.PingContext(ctx)) return &testDBClient{db: db}, func(ctx context.Context) error { _ = db.Close() @@ -80,8 +80,8 @@ func buildContainerBackedPostgres(t *testing.T) (client *testDBClient, shutdown func newTestLocker(t *testing.T, client database.Client) distributedlock.Locker { t.Helper() l, err := NewPostgresLocker(&Config{}, client, nil, nil, nil, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, l) + must.NoError(t, err) + must.NotNil(t, l) return l } @@ -125,15 +125,15 @@ func (p *errorAtCallProvider) NewFloat64Histogram(name string, options ...metric func buildSqlmockClient(t *testing.T) (*testDBClient, sqlmock.Sqlmock) { t.Helper() db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) - require.NoError(t, err) + must.NoError(t, err) return &testDBClient{db: db}, mock } func newTestLockerWithCB(t *testing.T, client database.Client, cb circuitbreaking.CircuitBreaker) distributedlock.Locker { t.Helper() l, err := NewPostgresLocker(&Config{}, client, nil, nil, nil, cb) - require.NoError(t, err) - require.NotNil(t, l) + must.NoError(t, err) + must.NotNil(t, l) return l } @@ -143,13 +143,13 @@ func TestNewPostgresLocker(T *testing.T) { T.Run("nil config", func(t *testing.T) { t.Parallel() _, err := NewPostgresLocker(nil, &testDBClient{}, nil, nil, nil, cbnoop.NewCircuitBreaker()) - require.ErrorIs(t, err, distributedlock.ErrNilConfig) + must.ErrorIs(t, err, distributedlock.ErrNilConfig) }) T.Run("nil database", func(t *testing.T) { t.Parallel() _, err := NewPostgresLocker(&Config{}, nil, nil, nil, nil, cbnoop.NewCircuitBreaker()) - require.ErrorIs(t, err, distributedlock.ErrNilDatabaseClient) + must.ErrorIs(t, err, distributedlock.ErrNilDatabaseClient) }) T.Run("standard happy path", func(t *testing.T) { @@ -157,8 +157,8 @@ func TestNewPostgresLocker(T *testing.T) { client, _ := buildSqlmockClient(t) t.Cleanup(func() { _ = client.Close() }) l, err := NewPostgresLocker(&Config{Namespace: 7}, client, nil, nil, nil, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, l) + must.NoError(t, err) + must.NotNil(t, l) }) // Each Int64Counter creation has its own error branch; exercise them all so @@ -170,7 +170,7 @@ func TestNewPostgresLocker(T *testing.T) { t.Cleanup(func() { _ = client.Close() }) mp := newErrorAtCallProvider(idx, false) _, err := NewPostgresLocker(&Config{}, client, nil, nil, mp, cbnoop.NewCircuitBreaker()) - require.Error(t, err) + must.Error(t, err) }) } @@ -180,7 +180,7 @@ func TestNewPostgresLocker(T *testing.T) { t.Cleanup(func() { _ = client.Close() }) mp := newErrorAtCallProvider(0, true) _, err := NewPostgresLocker(&Config{}, client, nil, nil, mp, cbnoop.NewCircuitBreaker()) - require.Error(t, err) + must.Error(t, err) }) } @@ -198,11 +198,11 @@ func TestLocker_Acquire_Unit(T *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"v"}).AddRow(true)) got, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.NotNil(t, got) - assert.Equal(t, "k", got.Key()) - assert.Equal(t, time.Minute, got.TTL()) - require.NoError(t, mock.ExpectationsWereMet()) + must.NoError(t, err) + must.NotNil(t, got) + test.EqOp(t, "k", got.Key()) + test.EqOp(t, time.Minute, got.TTL()) + must.NoError(t, mock.ExpectationsWereMet()) }) T.Run("rejects empty key", func(t *testing.T) { @@ -211,7 +211,7 @@ func TestLocker_Acquire_Unit(T *testing.T) { t.Cleanup(func() { _ = client.Close() }) l := newTestLocker(t, client) _, err := l.Acquire(t.Context(), "", time.Minute) - require.ErrorIs(t, err, distributedlock.ErrEmptyKey) + must.ErrorIs(t, err, distributedlock.ErrEmptyKey) }) T.Run("rejects zero TTL", func(t *testing.T) { @@ -220,7 +220,7 @@ func TestLocker_Acquire_Unit(T *testing.T) { t.Cleanup(func() { _ = client.Close() }) l := newTestLocker(t, client) _, err := l.Acquire(t.Context(), "k", 0) - require.ErrorIs(t, err, distributedlock.ErrInvalidTTL) + must.ErrorIs(t, err, distributedlock.ErrInvalidTTL) }) T.Run("rejects negative TTL", func(t *testing.T) { @@ -229,7 +229,7 @@ func TestLocker_Acquire_Unit(T *testing.T) { t.Cleanup(func() { _ = client.Close() }) l := newTestLocker(t, client) _, err := l.Acquire(t.Context(), "k", -time.Second) - require.ErrorIs(t, err, distributedlock.ErrInvalidTTL) + must.ErrorIs(t, err, distributedlock.ErrInvalidTTL) }) T.Run("blocked by circuit breaker", func(t *testing.T) { @@ -241,8 +241,8 @@ func TestLocker_Acquire_Unit(T *testing.T) { } l := newTestLockerWithCB(t, client, cb) _, err := l.Acquire(t.Context(), "k", time.Minute) - require.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - require.NotEmpty(t, cb.CannotProceedCalls()) + must.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + must.SliceNotEmpty(t, cb.CannotProceedCalls()) }) T.Run("Conn reservation failure", func(t *testing.T) { @@ -251,10 +251,10 @@ func TestLocker_Acquire_Unit(T *testing.T) { mock.ExpectClose() l := newTestLocker(t, client) // Close the underlying DB so Conn() returns an error. - require.NoError(t, client.Close()) + must.NoError(t, client.Close()) _, err := l.Acquire(t.Context(), "k", time.Minute) - require.Error(t, err) + must.Error(t, err) }) T.Run("pg_try_advisory_lock query failure", func(t *testing.T) { @@ -268,8 +268,8 @@ func TestLocker_Acquire_Unit(T *testing.T) { WillReturnError(errors.New("query boom")) _, err := l.Acquire(t.Context(), "k", time.Minute) - require.Error(t, err) - require.NoError(t, mock.ExpectationsWereMet()) + must.Error(t, err) + must.NoError(t, mock.ExpectationsWereMet()) }) T.Run("contention returns ErrLockNotAcquired", func(t *testing.T) { @@ -283,8 +283,8 @@ func TestLocker_Acquire_Unit(T *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"v"}).AddRow(false)) _, err := l.Acquire(t.Context(), "k", time.Minute) - require.ErrorIs(t, err, distributedlock.ErrLockNotAcquired) - require.NoError(t, mock.ExpectationsWereMet()) + must.ErrorIs(t, err, distributedlock.ErrLockNotAcquired) + must.NoError(t, mock.ExpectationsWereMet()) }) } @@ -305,9 +305,9 @@ func TestLocker_Release_Unit(T *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"v"}).AddRow(true)) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.NoError(t, h.Release(t.Context())) - require.NoError(t, mock.ExpectationsWereMet()) + must.NoError(t, err) + must.NoError(t, h.Release(t.Context())) + must.NoError(t, mock.ExpectationsWereMet()) }) T.Run("blocked by circuit breaker", func(t *testing.T) { @@ -329,10 +329,10 @@ func TestLocker_Release_Unit(T *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"v"}).AddRow(true)) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.ErrorIs(t, h.Release(t.Context()), circuitbreaking.ErrCircuitBroken) - require.Len(t, cb.CannotProceedCalls(), 2) - require.Len(t, cb.SucceededCalls(), 1) + must.NoError(t, err) + must.ErrorIs(t, h.Release(t.Context()), circuitbreaking.ErrCircuitBroken) + must.SliceLen(t, 2, cb.CannotProceedCalls()) + must.SliceLen(t, 1, cb.SucceededCalls()) }) T.Run("double release returns ErrLockNotHeld", func(t *testing.T) { @@ -349,9 +349,9 @@ func TestLocker_Release_Unit(T *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"v"}).AddRow(true)) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.NoError(t, h.Release(t.Context())) - require.ErrorIs(t, h.Release(t.Context()), distributedlock.ErrLockNotHeld) + must.NoError(t, err) + must.NoError(t, h.Release(t.Context())) + must.ErrorIs(t, h.Release(t.Context()), distributedlock.ErrLockNotHeld) }) T.Run("releaseLocked deferred conn close error tolerated", func(t *testing.T) { @@ -365,16 +365,16 @@ func TestLocker_Release_Unit(T *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"v"}).AddRow(true)) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) + must.NoError(t, err) // Force the deferred conn.Close inside releaseLocked to fail by closing // the conn here first. The QueryRowContext on the already-closed conn // will also fail (covered by the SQL failure subtest below); the value // of this case is exercising the deferred Close error branch. inner := h.(*lock) - require.NoError(t, inner.conn.Close()) + must.NoError(t, inner.conn.Close()) - require.Error(t, h.Release(t.Context())) + must.Error(t, h.Release(t.Context())) }) T.Run("releaseLocked SQL failure trips breaker", func(t *testing.T) { @@ -396,10 +396,10 @@ func TestLocker_Release_Unit(T *testing.T) { WillReturnError(errors.New("unlock boom")) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.Error(t, h.Release(t.Context())) - require.Len(t, cb.SucceededCalls(), 1) - require.Len(t, cb.FailedCalls(), 1) + must.NoError(t, err) + must.Error(t, h.Release(t.Context())) + must.SliceLen(t, 1, cb.SucceededCalls()) + must.SliceLen(t, 1, cb.FailedCalls()) }) } @@ -419,9 +419,9 @@ func TestLocker_Refresh_Unit(T *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"v"}).AddRow(1)) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.NoError(t, h.Refresh(t.Context(), 5*time.Minute)) - assert.Equal(t, 5*time.Minute, h.TTL()) + must.NoError(t, err) + must.NoError(t, h.Refresh(t.Context(), 5*time.Minute)) + test.EqOp(t, 5*time.Minute, h.TTL()) }) T.Run("rejects zero TTL", func(t *testing.T) { @@ -435,9 +435,9 @@ func TestLocker_Refresh_Unit(T *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"v"}).AddRow(true)) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.ErrorIs(t, h.Refresh(t.Context(), 0), distributedlock.ErrInvalidTTL) - assert.Equal(t, time.Minute, h.TTL()) + must.NoError(t, err) + must.ErrorIs(t, h.Refresh(t.Context(), 0), distributedlock.ErrInvalidTTL) + test.EqOp(t, time.Minute, h.TTL()) }) T.Run("blocked by circuit breaker", func(t *testing.T) { @@ -459,10 +459,10 @@ func TestLocker_Refresh_Unit(T *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"v"}).AddRow(true)) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.ErrorIs(t, h.Refresh(t.Context(), time.Minute), circuitbreaking.ErrCircuitBroken) - require.Len(t, cb.CannotProceedCalls(), 2) - require.Len(t, cb.SucceededCalls(), 1) + must.NoError(t, err) + must.ErrorIs(t, h.Refresh(t.Context(), time.Minute), circuitbreaking.ErrCircuitBroken) + must.SliceLen(t, 2, cb.CannotProceedCalls()) + must.SliceLen(t, 1, cb.SucceededCalls()) }) T.Run("refresh after release returns ErrLockNotHeld", func(t *testing.T) { @@ -479,9 +479,9 @@ func TestLocker_Refresh_Unit(T *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"v"}).AddRow(true)) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.NoError(t, h.Release(t.Context())) - require.ErrorIs(t, h.Refresh(t.Context(), time.Minute), distributedlock.ErrLockNotHeld) + must.NoError(t, err) + must.NoError(t, h.Release(t.Context())) + must.ErrorIs(t, h.Refresh(t.Context(), time.Minute), distributedlock.ErrLockNotHeld) }) T.Run("liveness check failure returns ErrLockNotHeld", func(t *testing.T) { @@ -497,10 +497,10 @@ func TestLocker_Refresh_Unit(T *testing.T) { WillReturnError(errors.New("conn dead")) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.ErrorIs(t, h.Refresh(t.Context(), 5*time.Minute), distributedlock.ErrLockNotHeld) + must.NoError(t, err) + must.ErrorIs(t, h.Refresh(t.Context(), 5*time.Minute), distributedlock.ErrLockNotHeld) // TTL must remain unchanged on failure. - assert.Equal(t, time.Minute, h.TTL()) + test.EqOp(t, time.Minute, h.TTL()) }) } @@ -514,8 +514,8 @@ func TestLocker_PingClose_Unit(T *testing.T) { l := newTestLocker(t, client) mock.ExpectPing() - require.NoError(t, l.Ping(t.Context())) - require.NoError(t, mock.ExpectationsWereMet()) + must.NoError(t, l.Ping(t.Context())) + must.NoError(t, mock.ExpectationsWereMet()) }) T.Run("ping error", func(t *testing.T) { @@ -525,7 +525,7 @@ func TestLocker_PingClose_Unit(T *testing.T) { l := newTestLocker(t, client) mock.ExpectPing().WillReturnError(errors.New("ping boom")) - require.Error(t, l.Ping(t.Context())) + must.Error(t, l.Ping(t.Context())) }) T.Run("close with no outstanding locks", func(t *testing.T) { @@ -533,7 +533,7 @@ func TestLocker_PingClose_Unit(T *testing.T) { client, _ := buildSqlmockClient(t) t.Cleanup(func() { _ = client.Close() }) l := newTestLocker(t, client) - require.NoError(t, l.Close()) + must.NoError(t, l.Close()) }) T.Run("close releases all outstanding locks", func(t *testing.T) { @@ -550,9 +550,9 @@ func TestLocker_PingClose_Unit(T *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"v"}).AddRow(true)) _, err := l.Acquire(t.Context(), "a", time.Minute) - require.NoError(t, err) - require.NoError(t, l.Close()) - require.NoError(t, mock.ExpectationsWereMet()) + must.NoError(t, err) + must.NoError(t, l.Close()) + must.NoError(t, mock.ExpectationsWereMet()) }) T.Run("close surfaces release errors", func(t *testing.T) { @@ -569,8 +569,8 @@ func TestLocker_PingClose_Unit(T *testing.T) { WillReturnError(errors.New("unlock boom")) _, err := l.Acquire(t.Context(), "a", time.Minute) - require.NoError(t, err) - require.Error(t, l.Close()) + must.NoError(t, err) + must.Error(t, l.Close()) }) } @@ -579,17 +579,17 @@ func TestHashLockID(T *testing.T) { T.Run("stable across calls", func(t *testing.T) { t.Parallel() - assert.Equal(t, hashLockID(0, "k"), hashLockID(0, "k")) + test.EqOp(t, hashLockID(0, "k"), hashLockID(0, "k")) }) T.Run("namespace changes the result", func(t *testing.T) { t.Parallel() - assert.NotEqual(t, hashLockID(0, "k"), hashLockID(1, "k")) + test.NotEq(t, hashLockID(0, "k"), hashLockID(1, "k")) }) T.Run("different keys produce different ids", func(t *testing.T) { t.Parallel() - assert.NotEqual(t, hashLockID(0, "a"), hashLockID(0, "b")) + test.NotEq(t, hashLockID(0, "a"), hashLockID(0, "b")) }) } @@ -612,11 +612,11 @@ func TestPostgresLocker_Container(T *testing.T) { key := "happy_" + identifiers.New() lock, err := l.Acquire(ctx, key, time.Minute) - require.NoError(t, err) - require.NotNil(t, lock) - assert.Equal(t, key, lock.Key()) - assert.Equal(t, time.Minute, lock.TTL()) - require.NoError(t, lock.Release(ctx)) + must.NoError(t, err) + must.NotNil(t, lock) + test.EqOp(t, key, lock.Key()) + test.EqOp(t, time.Minute, lock.TTL()) + must.NoError(t, lock.Release(ctx)) }) T.Run("Acquire contended on the same locker returns ErrLockNotAcquired", func(t *testing.T) { @@ -626,11 +626,11 @@ func TestPostgresLocker_Container(T *testing.T) { key := "contend_same_" + identifiers.New() first, err := l.Acquire(ctx, key, time.Minute) - require.NoError(t, err) + must.NoError(t, err) t.Cleanup(func() { _ = first.Release(ctx) }) _, err = l.Acquire(ctx, key, time.Minute) - require.ErrorIs(t, err, distributedlock.ErrLockNotAcquired) + must.ErrorIs(t, err, distributedlock.ErrLockNotAcquired) }) T.Run("Acquire contended across separate lockers returns ErrLockNotAcquired", func(t *testing.T) { @@ -641,25 +641,25 @@ func TestPostgresLocker_Container(T *testing.T) { key := "contend_cross_" + identifiers.New() first, err := l1.Acquire(ctx, key, time.Minute) - require.NoError(t, err) + must.NoError(t, err) t.Cleanup(func() { _ = first.Release(ctx) }) _, err = l2.Acquire(ctx, key, time.Minute) - require.ErrorIs(t, err, distributedlock.ErrLockNotAcquired) + must.ErrorIs(t, err, distributedlock.ErrLockNotAcquired) }) T.Run("Acquire rejects empty key", func(t *testing.T) { t.Parallel() l := newTestLocker(t, client) _, err := l.Acquire(t.Context(), "", time.Minute) - require.ErrorIs(t, err, distributedlock.ErrEmptyKey) + must.ErrorIs(t, err, distributedlock.ErrEmptyKey) }) T.Run("Acquire rejects zero TTL", func(t *testing.T) { t.Parallel() l := newTestLocker(t, client) _, err := l.Acquire(t.Context(), "k", 0) - require.ErrorIs(t, err, distributedlock.ErrInvalidTTL) + must.ErrorIs(t, err, distributedlock.ErrInvalidTTL) }) T.Run("Released lock can be reacquired", func(t *testing.T) { @@ -669,12 +669,12 @@ func TestPostgresLocker_Container(T *testing.T) { key := "reacquire_" + identifiers.New() first, err := l.Acquire(ctx, key, time.Minute) - require.NoError(t, err) - require.NoError(t, first.Release(ctx)) + must.NoError(t, err) + must.NoError(t, first.Release(ctx)) second, err := l.Acquire(ctx, key, time.Minute) - require.NoError(t, err) - require.NoError(t, second.Release(ctx)) + must.NoError(t, err) + must.NoError(t, second.Release(ctx)) }) T.Run("Double release returns ErrLockNotHeld on second call", func(t *testing.T) { @@ -684,9 +684,9 @@ func TestPostgresLocker_Container(T *testing.T) { key := "double_" + identifiers.New() lock, err := l.Acquire(ctx, key, time.Minute) - require.NoError(t, err) - require.NoError(t, lock.Release(ctx)) - require.ErrorIs(t, lock.Release(ctx), distributedlock.ErrLockNotHeld) + must.NoError(t, err) + must.NoError(t, lock.Release(ctx)) + must.ErrorIs(t, lock.Release(ctx), distributedlock.ErrLockNotHeld) }) T.Run("Refresh succeeds and updates local TTL", func(t *testing.T) { @@ -696,11 +696,11 @@ func TestPostgresLocker_Container(T *testing.T) { key := "refresh_" + identifiers.New() lock, err := l.Acquire(ctx, key, time.Minute) - require.NoError(t, err) + must.NoError(t, err) t.Cleanup(func() { _ = lock.Release(ctx) }) - require.NoError(t, lock.Refresh(ctx, 5*time.Minute)) - assert.Equal(t, 5*time.Minute, lock.TTL()) + must.NoError(t, lock.Refresh(ctx, 5*time.Minute)) + test.EqOp(t, 5*time.Minute, lock.TTL()) }) T.Run("Refresh after release returns ErrLockNotHeld", func(t *testing.T) { @@ -710,10 +710,10 @@ func TestPostgresLocker_Container(T *testing.T) { key := "refresh_after_release_" + identifiers.New() lock, err := l.Acquire(ctx, key, time.Minute) - require.NoError(t, err) - require.NoError(t, lock.Release(ctx)) + must.NoError(t, err) + must.NoError(t, lock.Release(ctx)) - require.ErrorIs(t, lock.Refresh(ctx, time.Minute), distributedlock.ErrLockNotHeld) + must.ErrorIs(t, lock.Refresh(ctx, time.Minute), distributedlock.ErrLockNotHeld) }) T.Run("Close releases all outstanding locks", func(t *testing.T) { @@ -724,24 +724,24 @@ func TestPostgresLocker_Container(T *testing.T) { keyB := "close_b_" + identifiers.New() _, err := l.Acquire(ctx, keyA, time.Minute) - require.NoError(t, err) + must.NoError(t, err) _, err = l.Acquire(ctx, keyB, time.Minute) - require.NoError(t, err) - require.NoError(t, l.Close()) + must.NoError(t, err) + must.NoError(t, l.Close()) // Both keys are acquirable again from a fresh locker. l2 := newTestLocker(t, client) t.Cleanup(func() { _ = l2.Close() }) _, err = l2.Acquire(ctx, keyA, time.Minute) - require.NoError(t, err) + must.NoError(t, err) _, err = l2.Acquire(ctx, keyB, time.Minute) - require.NoError(t, err) + must.NoError(t, err) }) T.Run("Ping success", func(t *testing.T) { t.Parallel() l := newTestLocker(t, client) - require.NoError(t, l.Ping(t.Context())) + must.NoError(t, l.Ping(t.Context())) }) } diff --git a/distributedlock/redis/config_test.go b/distributedlock/redis/config_test.go index ba698d6..49f46cf 100644 --- a/distributedlock/redis/config_test.go +++ b/distributedlock/redis/config_test.go @@ -5,8 +5,8 @@ import ( platformerrors "github.com/verygoodsoftwarenotvirus/platform/v5/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -15,20 +15,20 @@ func TestConfig_ValidateWithContext(T *testing.T) { T.Run("nil config", func(t *testing.T) { t.Parallel() var cfg *Config - require.ErrorIs(t, cfg.ValidateWithContext(t.Context()), platformerrors.ErrNilInputParameter) + must.ErrorIs(t, cfg.ValidateWithContext(t.Context()), platformerrors.ErrNilInputParameter) }) T.Run("missing addresses", func(t *testing.T) { t.Parallel() cfg := &Config{} err := cfg.ValidateWithContext(t.Context()) - require.Error(t, err) - assert.Contains(t, err.Error(), "addresses") + must.Error(t, err) + test.StrContains(t, err.Error(), "addresses") }) T.Run("happy path", func(t *testing.T) { t.Parallel() cfg := &Config{Addresses: []string{"localhost:6379"}} - require.NoError(t, cfg.ValidateWithContext(t.Context())) + must.NoError(t, cfg.ValidateWithContext(t.Context())) }) } diff --git a/distributedlock/redis/redis_test.go b/distributedlock/redis/redis_test.go index 21cc2b3..f04c728 100644 --- a/distributedlock/redis/redis_test.go +++ b/distributedlock/redis/redis_test.go @@ -18,8 +18,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/go-redis/redis/v8" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" rediscontainers "github.com/testcontainers/testcontainers-go/modules/redis" "go.opentelemetry.io/otel/metric" ) @@ -36,11 +36,11 @@ func buildContainerBackedRedisConfig(t *testing.T) (cfg *Config, shutdown func(c redisImage, rediscontainers.WithLogLevel(rediscontainers.LogLevelNotice), ) - require.NoError(t, err) - require.NotNil(t, container) + must.NoError(t, err) + must.NotNil(t, container) addr, err := container.ConnectionString(ctx) - require.NoError(t, err) + must.NoError(t, err) cfg = &Config{ Addresses: []string{strings.TrimPrefix(addr, "redis://")}, @@ -52,8 +52,8 @@ func buildContainerBackedRedisConfig(t *testing.T) (cfg *Config, shutdown func(c func newTestLocker(t *testing.T, cfg *Config) distributedlock.Locker { t.Helper() l, err := NewRedisLocker(cfg, nil, nil, nil, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, l) + must.NoError(t, err) + must.NotNil(t, l) return l } @@ -172,17 +172,17 @@ func newUnitLocker(t *testing.T, client redisClient, cb circuitbreaking.CircuitB t.Helper() mp := metrics.NewNoopMetricsProvider() acquireCounter, err := mp.NewInt64Counter("redis_distributed_lock_acquires") - require.NoError(t, err) + must.NoError(t, err) releaseCounter, err := mp.NewInt64Counter("redis_distributed_lock_releases") - require.NoError(t, err) + must.NoError(t, err) refreshCounter, err := mp.NewInt64Counter("redis_distributed_lock_refreshes") - require.NoError(t, err) + must.NoError(t, err) contendCounter, err := mp.NewInt64Counter("redis_distributed_lock_contended") - require.NoError(t, err) + must.NoError(t, err) errCounter, err := mp.NewInt64Counter("redis_distributed_lock_errors") - require.NoError(t, err) + must.NoError(t, err) latencyHist, err := mp.NewFloat64Histogram("redis_distributed_lock_latency_ms") - require.NoError(t, err) + must.NoError(t, err) if cb == nil { cb = cbnoop.NewCircuitBreaker() } @@ -207,25 +207,25 @@ func TestNewRedisLocker(T *testing.T) { T.Run("nil config", func(t *testing.T) { t.Parallel() _, err := NewRedisLocker(nil, nil, nil, nil, cbnoop.NewCircuitBreaker()) - require.ErrorIs(t, err, distributedlock.ErrNilConfig) + must.ErrorIs(t, err, distributedlock.ErrNilConfig) }) T.Run("standard happy path", func(t *testing.T) { t.Parallel() cfg := &Config{Addresses: []string{"localhost:0"}, KeyPrefix: "lock:"} l, err := NewRedisLocker(cfg, nil, nil, nil, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, l) - require.NoError(t, l.Close()) + must.NoError(t, err) + must.NotNil(t, l) + must.NoError(t, l.Close()) }) T.Run("cluster mode happy path", func(t *testing.T) { t.Parallel() cfg := &Config{Addresses: []string{"localhost:0", "localhost:1"}, KeyPrefix: "lock:"} l, err := NewRedisLocker(cfg, nil, nil, nil, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, l) - require.NoError(t, l.Close()) + must.NoError(t, err) + must.NotNil(t, l) + must.NoError(t, l.Close()) }) // Each metric counter creation has its own error branch; exercise them all so @@ -236,7 +236,7 @@ func TestNewRedisLocker(T *testing.T) { cfg := &Config{Addresses: []string{"localhost:0"}} mp := newErrorAtCallProvider(idx, false) _, err := NewRedisLocker(cfg, nil, nil, mp, cbnoop.NewCircuitBreaker()) - require.Error(t, err) + must.Error(t, err) }) } @@ -245,7 +245,7 @@ func TestNewRedisLocker(T *testing.T) { cfg := &Config{Addresses: []string{"localhost:0"}} mp := newErrorAtCallProvider(0, true) _, err := NewRedisLocker(cfg, nil, nil, mp, cbnoop.NewCircuitBreaker()) - require.Error(t, err) + must.Error(t, err) }) } @@ -258,33 +258,33 @@ func TestLocker_Acquire(T *testing.T) { l := newUnitLocker(t, fc, nil) got, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.NotNil(t, got) - assert.Equal(t, "k", got.Key()) - assert.Equal(t, time.Minute, got.TTL()) - assert.Equal(t, "lock:k", fc.lastSetKey) - assert.Equal(t, time.Minute, fc.lastSetTTL) + must.NoError(t, err) + must.NotNil(t, got) + test.EqOp(t, "k", got.Key()) + test.EqOp(t, time.Minute, got.TTL()) + test.EqOp(t, "lock:k", fc.lastSetKey) + test.EqOp(t, time.Minute, fc.lastSetTTL) }) T.Run("rejects empty key", func(t *testing.T) { t.Parallel() l := newUnitLocker(t, &fakeRedisClient{}, nil) _, err := l.Acquire(t.Context(), "", time.Minute) - require.ErrorIs(t, err, distributedlock.ErrEmptyKey) + must.ErrorIs(t, err, distributedlock.ErrEmptyKey) }) T.Run("rejects zero TTL", func(t *testing.T) { t.Parallel() l := newUnitLocker(t, &fakeRedisClient{}, nil) _, err := l.Acquire(t.Context(), "k", 0) - require.ErrorIs(t, err, distributedlock.ErrInvalidTTL) + must.ErrorIs(t, err, distributedlock.ErrInvalidTTL) }) T.Run("rejects negative TTL", func(t *testing.T) { t.Parallel() l := newUnitLocker(t, &fakeRedisClient{}, nil) _, err := l.Acquire(t.Context(), "k", -time.Second) - require.ErrorIs(t, err, distributedlock.ErrInvalidTTL) + must.ErrorIs(t, err, distributedlock.ErrInvalidTTL) }) T.Run("blocked by circuit breaker", func(t *testing.T) { @@ -295,8 +295,8 @@ func TestLocker_Acquire(T *testing.T) { l := newUnitLocker(t, &fakeRedisClient{}, cb) _, err := l.Acquire(t.Context(), "k", time.Minute) - require.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - require.NotEmpty(t, cb.CannotProceedCalls()) + must.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + must.SliceNotEmpty(t, cb.CannotProceedCalls()) }) T.Run("SetNX backend error trips breaker", func(t *testing.T) { @@ -309,9 +309,9 @@ func TestLocker_Acquire(T *testing.T) { l := newUnitLocker(t, fc, cb) _, err := l.Acquire(t.Context(), "k", time.Minute) - require.Error(t, err) - require.NotEmpty(t, cb.CannotProceedCalls()) - require.NotEmpty(t, cb.FailedCalls()) + must.Error(t, err) + must.SliceNotEmpty(t, cb.CannotProceedCalls()) + must.SliceNotEmpty(t, cb.FailedCalls()) }) T.Run("contention does not fail breaker", func(t *testing.T) { @@ -324,9 +324,9 @@ func TestLocker_Acquire(T *testing.T) { l := newUnitLocker(t, fc, cb) _, err := l.Acquire(t.Context(), "k", time.Minute) - require.ErrorIs(t, err, distributedlock.ErrLockNotAcquired) - require.NotEmpty(t, cb.CannotProceedCalls()) - require.NotEmpty(t, cb.SucceededCalls()) + must.ErrorIs(t, err, distributedlock.ErrLockNotAcquired) + must.SliceNotEmpty(t, cb.CannotProceedCalls()) + must.SliceNotEmpty(t, cb.SucceededCalls()) }) } @@ -339,9 +339,9 @@ func TestLocker_Release(T *testing.T) { l := newUnitLocker(t, fc, nil) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.NoError(t, h.Release(t.Context())) - assert.Equal(t, "lock:k", fc.lastEvalKey) + must.NoError(t, err) + must.NoError(t, h.Release(t.Context())) + test.EqOp(t, "lock:k", fc.lastEvalKey) }) T.Run("eval reports caller no longer holds lock", func(t *testing.T) { @@ -350,8 +350,8 @@ func TestLocker_Release(T *testing.T) { l := newUnitLocker(t, fc, nil) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.ErrorIs(t, h.Release(t.Context()), distributedlock.ErrLockNotHeld) + must.NoError(t, err) + must.ErrorIs(t, h.Release(t.Context()), distributedlock.ErrLockNotHeld) }) T.Run("eval backend error trips breaker", func(t *testing.T) { @@ -366,13 +366,13 @@ func TestLocker_Release(T *testing.T) { l := newUnitLocker(t, fc, cb) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) + must.NoError(t, err) fc.evalErr = errors.New("eval boom") - require.Error(t, h.Release(t.Context())) - require.Len(t, cb.CannotProceedCalls(), 2) - require.Len(t, cb.SucceededCalls(), 1) - require.Len(t, cb.FailedCalls(), 1) + must.Error(t, h.Release(t.Context())) + must.SliceLen(t, 2, cb.CannotProceedCalls()) + must.SliceLen(t, 1, cb.SucceededCalls()) + must.SliceLen(t, 1, cb.FailedCalls()) }) T.Run("blocked by circuit breaker", func(t *testing.T) { @@ -389,10 +389,10 @@ func TestLocker_Release(T *testing.T) { l := newUnitLocker(t, fc, cb) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.ErrorIs(t, h.Release(t.Context()), circuitbreaking.ErrCircuitBroken) - require.Len(t, cb.CannotProceedCalls(), 2) - require.Len(t, cb.SucceededCalls(), 1) + must.NoError(t, err) + must.ErrorIs(t, h.Release(t.Context()), circuitbreaking.ErrCircuitBroken) + must.SliceLen(t, 2, cb.CannotProceedCalls()) + must.SliceLen(t, 1, cb.SucceededCalls()) }) } @@ -405,10 +405,10 @@ func TestLocker_Refresh(T *testing.T) { l := newUnitLocker(t, fc, nil) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) + must.NoError(t, err) - require.NoError(t, h.Refresh(t.Context(), 5*time.Minute)) - assert.Equal(t, 5*time.Minute, h.TTL()) + must.NoError(t, h.Refresh(t.Context(), 5*time.Minute)) + test.EqOp(t, 5*time.Minute, h.TTL()) }) T.Run("rejects zero TTL", func(t *testing.T) { @@ -417,10 +417,10 @@ func TestLocker_Refresh(T *testing.T) { l := newUnitLocker(t, fc, nil) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.ErrorIs(t, h.Refresh(t.Context(), 0), distributedlock.ErrInvalidTTL) + must.NoError(t, err) + must.ErrorIs(t, h.Refresh(t.Context(), 0), distributedlock.ErrInvalidTTL) // TTL must remain unchanged on failure. - assert.Equal(t, time.Minute, h.TTL()) + test.EqOp(t, time.Minute, h.TTL()) }) T.Run("eval reports caller no longer holds lock", func(t *testing.T) { @@ -429,10 +429,10 @@ func TestLocker_Refresh(T *testing.T) { l := newUnitLocker(t, fc, nil) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) + must.NoError(t, err) // Force the refresh script to "not held" by returning 0. - require.ErrorIs(t, h.Refresh(t.Context(), 2*time.Minute), distributedlock.ErrLockNotHeld) - assert.Equal(t, time.Minute, h.TTL()) + must.ErrorIs(t, h.Refresh(t.Context(), 2*time.Minute), distributedlock.ErrLockNotHeld) + test.EqOp(t, time.Minute, h.TTL()) }) T.Run("eval backend error trips breaker", func(t *testing.T) { @@ -446,13 +446,13 @@ func TestLocker_Refresh(T *testing.T) { l := newUnitLocker(t, fc, cb) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) + must.NoError(t, err) fc.evalErr = errors.New("eval boom") - require.Error(t, h.Refresh(t.Context(), 5*time.Minute)) - require.Len(t, cb.CannotProceedCalls(), 2) - require.Len(t, cb.SucceededCalls(), 1) - require.Len(t, cb.FailedCalls(), 1) + must.Error(t, h.Refresh(t.Context(), 5*time.Minute)) + must.SliceLen(t, 2, cb.CannotProceedCalls()) + must.SliceLen(t, 1, cb.SucceededCalls()) + must.SliceLen(t, 1, cb.FailedCalls()) }) T.Run("blocked by circuit breaker", func(t *testing.T) { @@ -469,10 +469,10 @@ func TestLocker_Refresh(T *testing.T) { l := newUnitLocker(t, fc, cb) h, err := l.Acquire(t.Context(), "k", time.Minute) - require.NoError(t, err) - require.ErrorIs(t, h.Refresh(t.Context(), time.Minute), circuitbreaking.ErrCircuitBroken) - require.Len(t, cb.CannotProceedCalls(), 2) - require.Len(t, cb.SucceededCalls(), 1) + must.NoError(t, err) + must.ErrorIs(t, h.Refresh(t.Context(), time.Minute), circuitbreaking.ErrCircuitBroken) + must.SliceLen(t, 2, cb.CannotProceedCalls()) + must.SliceLen(t, 1, cb.SucceededCalls()) }) } @@ -483,30 +483,30 @@ func TestLocker_PingClose(T *testing.T) { t.Parallel() fc := &fakeRedisClient{} l := newUnitLocker(t, fc, nil) - require.NoError(t, l.Ping(t.Context())) - assert.Equal(t, 1, fc.pingCalls) + must.NoError(t, l.Ping(t.Context())) + test.EqOp(t, 1, fc.pingCalls) }) T.Run("ping error", func(t *testing.T) { t.Parallel() fc := &fakeRedisClient{pingErr: errors.New("ping boom")} l := newUnitLocker(t, fc, nil) - require.Error(t, l.Ping(t.Context())) + must.Error(t, l.Ping(t.Context())) }) T.Run("close success", func(t *testing.T) { t.Parallel() fc := &fakeRedisClient{} l := newUnitLocker(t, fc, nil) - require.NoError(t, l.Close()) - assert.Equal(t, 1, fc.closeCalls) + must.NoError(t, l.Close()) + test.EqOp(t, 1, fc.closeCalls) }) T.Run("close error", func(t *testing.T) { t.Parallel() fc := &fakeRedisClient{closeErr: errors.New("close boom")} l := newUnitLocker(t, fc, nil) - require.Error(t, l.Close()) + must.Error(t, l.Close()) }) } @@ -517,9 +517,9 @@ func TestBuildRedisClient(T *testing.T) { t.Parallel() cfg := &Config{Addresses: []string{"localhost:6379"}} c := buildRedisClient(cfg) - require.NotNil(t, c) + must.NotNil(t, c) _, ok := c.(*redis.Client) - assert.True(t, ok) + test.True(t, ok) _ = c.Close() }) @@ -527,9 +527,9 @@ func TestBuildRedisClient(T *testing.T) { t.Parallel() cfg := &Config{Addresses: []string{"localhost:6379", "localhost:6380"}} c := buildRedisClient(cfg) - require.NotNil(t, c) + must.NotNil(t, c) _, ok := c.(*redis.ClusterClient) - assert.True(t, ok) + test.True(t, ok) _ = c.Close() }) } @@ -553,12 +553,12 @@ func TestRedisLocker_Container(T *testing.T) { key := "happy_" + identifiers.New() lock, err := l.Acquire(ctx, key, time.Minute) - require.NoError(t, err) - require.NotNil(t, lock) - assert.Equal(t, key, lock.Key()) - assert.Equal(t, time.Minute, lock.TTL()) + must.NoError(t, err) + must.NotNil(t, lock) + test.EqOp(t, key, lock.Key()) + test.EqOp(t, time.Minute, lock.TTL()) - require.NoError(t, lock.Release(ctx)) + must.NoError(t, lock.Release(ctx)) }) T.Run("Acquire contended returns ErrLockNotAcquired", func(t *testing.T) { @@ -568,25 +568,25 @@ func TestRedisLocker_Container(T *testing.T) { key := "contended_" + identifiers.New() first, err := l.Acquire(ctx, key, time.Minute) - require.NoError(t, err) + must.NoError(t, err) t.Cleanup(func() { _ = first.Release(ctx) }) _, err = l.Acquire(ctx, key, time.Minute) - require.ErrorIs(t, err, distributedlock.ErrLockNotAcquired) + must.ErrorIs(t, err, distributedlock.ErrLockNotAcquired) }) T.Run("Acquire rejects empty key", func(t *testing.T) { t.Parallel() l := newTestLocker(t, cfg) _, err := l.Acquire(t.Context(), "", time.Minute) - require.ErrorIs(t, err, distributedlock.ErrEmptyKey) + must.ErrorIs(t, err, distributedlock.ErrEmptyKey) }) T.Run("Acquire rejects zero TTL", func(t *testing.T) { t.Parallel() l := newTestLocker(t, cfg) _, err := l.Acquire(t.Context(), "k", 0) - require.ErrorIs(t, err, distributedlock.ErrInvalidTTL) + must.ErrorIs(t, err, distributedlock.ErrInvalidTTL) }) T.Run("Release after expiration returns ErrLockNotHeld", func(t *testing.T) { @@ -596,10 +596,10 @@ func TestRedisLocker_Container(T *testing.T) { key := "expired_" + identifiers.New() lock, err := l.Acquire(ctx, key, 100*time.Millisecond) - require.NoError(t, err) + must.NoError(t, err) time.Sleep(250 * time.Millisecond) - require.ErrorIs(t, lock.Release(ctx), distributedlock.ErrLockNotHeld) + must.ErrorIs(t, lock.Release(ctx), distributedlock.ErrLockNotHeld) }) T.Run("Release wrong owner returns ErrLockNotHeld", func(t *testing.T) { @@ -609,14 +609,14 @@ func TestRedisLocker_Container(T *testing.T) { key := "stolen_" + identifiers.New() lock, err := l.Acquire(ctx, key, time.Minute) - require.NoError(t, err) + must.NoError(t, err) // Forge a different owner by overwriting the value out-of-band. direct := directRedisClient(t, cfg) t.Cleanup(func() { _ = direct.Close() }) - require.NoError(t, direct.Set(ctx, "lock:"+key, "someone-else", time.Minute).Err()) + must.NoError(t, direct.Set(ctx, "lock:"+key, "someone-else", time.Minute).Err()) - require.ErrorIs(t, lock.Release(ctx), distributedlock.ErrLockNotHeld) + must.ErrorIs(t, lock.Release(ctx), distributedlock.ErrLockNotHeld) }) T.Run("Refresh extends TTL", func(t *testing.T) { @@ -626,16 +626,16 @@ func TestRedisLocker_Container(T *testing.T) { key := "refresh_" + identifiers.New() lock, err := l.Acquire(ctx, key, 200*time.Millisecond) - require.NoError(t, err) - require.NoError(t, lock.Refresh(ctx, 5*time.Second)) + must.NoError(t, err) + must.NoError(t, lock.Refresh(ctx, 5*time.Second)) t.Cleanup(func() { _ = lock.Release(ctx) }) // Sleep past the original TTL; lock should still be held. time.Sleep(300 * time.Millisecond) _, err = l.Acquire(ctx, key, time.Minute) - require.ErrorIs(t, err, distributedlock.ErrLockNotAcquired) - assert.Equal(t, 5*time.Second, lock.TTL()) + must.ErrorIs(t, err, distributedlock.ErrLockNotAcquired) + test.EqOp(t, 5*time.Second, lock.TTL()) }) T.Run("Refresh rejects invalid TTL", func(t *testing.T) { @@ -645,10 +645,10 @@ func TestRedisLocker_Container(T *testing.T) { key := "refreshinv_" + identifiers.New() lock, err := l.Acquire(ctx, key, time.Minute) - require.NoError(t, err) + must.NoError(t, err) t.Cleanup(func() { _ = lock.Release(ctx) }) - require.ErrorIs(t, lock.Refresh(ctx, 0), distributedlock.ErrInvalidTTL) + must.ErrorIs(t, lock.Refresh(ctx, 0), distributedlock.ErrInvalidTTL) }) T.Run("Double release returns ErrLockNotHeld on second call", func(t *testing.T) { @@ -658,9 +658,9 @@ func TestRedisLocker_Container(T *testing.T) { key := "double_" + identifiers.New() lock, err := l.Acquire(ctx, key, time.Minute) - require.NoError(t, err) - require.NoError(t, lock.Release(ctx)) - require.ErrorIs(t, lock.Release(ctx), distributedlock.ErrLockNotHeld) + must.NoError(t, err) + must.NoError(t, lock.Release(ctx)) + must.ErrorIs(t, lock.Release(ctx), distributedlock.ErrLockNotHeld) }) T.Run("Released lock can be reacquired", func(t *testing.T) { @@ -670,17 +670,17 @@ func TestRedisLocker_Container(T *testing.T) { key := "reacquire_" + identifiers.New() first, err := l.Acquire(ctx, key, time.Minute) - require.NoError(t, err) - require.NoError(t, first.Release(ctx)) + must.NoError(t, err) + must.NoError(t, first.Release(ctx)) second, err := l.Acquire(ctx, key, time.Minute) - require.NoError(t, err) - require.NoError(t, second.Release(ctx)) + must.NoError(t, err) + must.NoError(t, second.Release(ctx)) }) T.Run("Ping success", func(t *testing.T) { t.Parallel() l := newTestLocker(t, cfg) - require.NoError(t, l.Ping(t.Context())) + must.NoError(t, l.Ping(t.Context())) }) } diff --git a/email/config/config_test.go b/email/config/config_test.go index cd20b2f..e61e9b7 100644 --- a/email/config/config_test.go +++ b/email/config/config_test.go @@ -19,8 +19,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -35,7 +34,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Sendgrid: &sendgrid.Config{APIToken: t.Name()}, } - require.NoError(t, cfg.ValidateWithContext(ctx)) + must.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with invalid token", func(t *testing.T) { @@ -46,42 +45,42 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: "sendgrid", } - require.Error(t, cfg.ValidateWithContext(ctx)) + must.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("mailgun provider requires config", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderMailgun} - require.Error(t, cfg.ValidateWithContext(t.Context())) + must.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("mailjet provider requires config", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderMailjet} - require.Error(t, cfg.ValidateWithContext(t.Context())) + must.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("resend provider requires config", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderResend} - require.Error(t, cfg.ValidateWithContext(t.Context())) + must.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("postmark provider requires config", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderPostmark} - require.Error(t, cfg.ValidateWithContext(t.Context())) + must.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("ses provider requires config", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderSES} - require.Error(t, cfg.ValidateWithContext(t.Context())) + must.Error(t, cfg.ValidateWithContext(t.Context())) }) } @@ -96,11 +95,11 @@ func TestConfig_BuildHermes(T *testing.T) { CompanyName: "Acme", LogoURL: "https://example.com/logo.png", }) - require.NotNil(t, h) - assert.Equal(t, "Acme", h.Product.Name) - assert.Equal(t, "https://example.com/logo.png", h.Product.Logo) - assert.Equal(t, "https://example.com", h.Product.Link) - assert.Contains(t, h.Product.Copyright, "Acme") + must.NotNil(t, h) + test.EqOp(t, "Acme", h.Product.Name) + test.EqOp(t, "https://example.com/logo.png", h.Product.Logo) + test.EqOp(t, "https://example.com", h.Product.Link) + test.StrContains(t, h.Product.Copyright, "Acme") }) T.Run("without branding", func(t *testing.T) { @@ -108,10 +107,10 @@ func TestConfig_BuildHermes(T *testing.T) { cfg := &Config{BaseURL: "https://example.com"} h := cfg.BuildHermes(nil) - require.NotNil(t, h) - assert.Empty(t, h.Product.Name) - assert.Empty(t, h.Product.Logo) - assert.Empty(t, h.Product.Copyright) + must.NotNil(t, h) + test.EqOp(t, "", h.Product.Name) + test.EqOp(t, "", h.Product.Logo) + test.EqOp(t, "", h.Product.Copyright) }) } @@ -123,7 +122,7 @@ func TestConfig_EnsureDefaults(T *testing.T) { cfg := &Config{} cfg.EnsureDefaults() - assert.NotEmpty(t, cfg.CircuitBreaker.Name) + test.NotEq(t, "", cfg.CircuitBreaker.Name) }) } @@ -153,8 +152,8 @@ func TestConfig_ProvideEmailer(T *testing.T) { } actual, err := cfg.ProvideEmailer(t.Context(), logger, tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) } @@ -168,8 +167,8 @@ func TestConfig_ProvideEmailer(T *testing.T) { } actual, err := cfg.ProvideEmailer(t.Context(), logger, tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with invalid provider", func(t *testing.T) { @@ -181,8 +180,8 @@ func TestConfig_ProvideEmailer(T *testing.T) { } actual, err := cfg.ProvideEmailer(t.Context(), logger, tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) } @@ -203,8 +202,8 @@ func TestProvideEmailer(T *testing.T) { metrics.NewNoopMetricsProvider(), &http.Client{}, ) - require.NoError(t, err) - assert.NotNil(t, emailer) + must.NoError(t, err) + test.NotNil(t, emailer) }) T.Run("with sendgrid provider", func(t *testing.T) { @@ -224,8 +223,8 @@ func TestProvideEmailer(T *testing.T) { metrics.NewNoopMetricsProvider(), &http.Client{}, ) - require.NoError(t, err) - assert.NotNil(t, emailer) + must.NoError(t, err) + test.NotNil(t, emailer) }) T.Run("circuit breaker init failure", func(t *testing.T) { @@ -251,8 +250,8 @@ func TestProvideEmailer(T *testing.T) { mp, &http.Client{}, ) - require.Error(t, err) - assert.Nil(t, emailer) + must.Error(t, err) + test.Nil(t, emailer) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) diff --git a/email/config/do_test.go b/email/config/do_test.go index ddc1334..67a0111 100644 --- a/email/config/do_test.go +++ b/email/config/do_test.go @@ -11,8 +11,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterEmailer(T *testing.T) { @@ -38,7 +38,7 @@ func TestRegisterEmailer(T *testing.T) { RegisterEmailer(i) emailer, err := do.Invoke[email.Emailer](i) - require.NoError(t, err) - assert.NotNil(t, emailer) + must.NoError(t, err) + test.NotNil(t, emailer) }) } diff --git a/email/mailgun/mailgun_test.go b/email/mailgun/mailgun_test.go index 7d27fa5..8fbc726 100644 --- a/email/mailgun/mailgun_test.go +++ b/email/mailgun/mailgun_test.go @@ -12,7 +12,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" ) const ( @@ -35,8 +35,8 @@ func TestNewMailgunEmailer(T *testing.T) { config := &Config{Domain: exampleDomain, PrivateAPIKey: t.Name()} client, err := NewMailgunEmailer(config, logger, tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, client) - require.NoError(t, err) + must.NotNil(t, client) + must.NoError(t, err) }) T.Run("with missing config", func(t *testing.T) { @@ -45,8 +45,8 @@ func TestNewMailgunEmailer(T *testing.T) { logger := logging.NewNoopLogger() client, err := NewMailgunEmailer(nil, logger, tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil) - require.Nil(t, client) - require.Error(t, err) + must.Nil(t, client) + must.Error(t, err) }) T.Run("with missing config domain", func(t *testing.T) { @@ -57,8 +57,8 @@ func TestNewMailgunEmailer(T *testing.T) { config := &Config{PrivateAPIKey: t.Name()} client, err := NewMailgunEmailer(config, logger, tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil) - require.Nil(t, client) - require.Error(t, err) + must.Nil(t, client) + must.Error(t, err) }) T.Run("with missing config private key", func(t *testing.T) { @@ -69,8 +69,8 @@ func TestNewMailgunEmailer(T *testing.T) { config := &Config{Domain: exampleDomain} client, err := NewMailgunEmailer(config, logger, tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil) - require.Nil(t, client) - require.Error(t, err) + must.Nil(t, client) + must.Error(t, err) }) T.Run("with missing HTTP client", func(t *testing.T) { @@ -81,8 +81,8 @@ func TestNewMailgunEmailer(T *testing.T) { config := &Config{Domain: exampleDomain, PrivateAPIKey: t.Name()} client, err := NewMailgunEmailer(config, logger, tracing.NewNoopTracerProvider(), nil, cbnoop.NewCircuitBreaker(), nil) - require.Nil(t, client) - require.Error(t, err) + must.Nil(t, client) + must.Error(t, err) }) } @@ -104,8 +104,8 @@ func TestMailgunEmailer_SendEmail(T *testing.T) { cfg := &Config{Domain: exampleDomain, PrivateAPIKey: t.Name()} c, err := NewMailgunEmailer(cfg, logger, tracing.NewNoopTracerProvider(), ts.Client(), cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, c) - require.NoError(t, err) + must.NotNil(t, c) + must.NoError(t, err) c.client.SetAPIBase(ts.URL + "/v4") @@ -119,7 +119,7 @@ func TestMailgunEmailer_SendEmail(T *testing.T) { HTMLContent: t.Name(), } - require.NoError(t, c.SendEmail(ctx, details)) + must.NoError(t, c.SendEmail(ctx, details)) }) T.Run("with error executing request", func(t *testing.T) { @@ -137,8 +137,8 @@ func TestMailgunEmailer_SendEmail(T *testing.T) { cfg := &Config{Domain: exampleDomain, PrivateAPIKey: t.Name()} c, err := NewMailgunEmailer(cfg, logger, tracing.NewNoopTracerProvider(), client, cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, c) - require.NoError(t, err) + must.NotNil(t, c) + must.NoError(t, err) ctx := t.Context() details := &email.OutboundEmailMessage{ ToAddress: t.Name(), @@ -150,7 +150,7 @@ func TestMailgunEmailer_SendEmail(T *testing.T) { } err = c.SendEmail(ctx, details) - require.Error(t, err) + must.Error(t, err) }) T.Run("with invalid response code", func(t *testing.T) { @@ -165,8 +165,8 @@ func TestMailgunEmailer_SendEmail(T *testing.T) { cfg := &Config{Domain: exampleDomain, PrivateAPIKey: t.Name()} c, err := NewMailgunEmailer(cfg, logger, tracing.NewNoopTracerProvider(), ts.Client(), cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, c) - require.NoError(t, err) + must.NotNil(t, c) + must.NoError(t, err) ctx := t.Context() details := &email.OutboundEmailMessage{ @@ -179,6 +179,6 @@ func TestMailgunEmailer_SendEmail(T *testing.T) { } err = c.SendEmail(ctx, details) - require.Error(t, err) + must.Error(t, err) }) } diff --git a/email/mailjet/mailjet_test.go b/email/mailjet/mailjet_test.go index f6fb5ed..6b48640 100644 --- a/email/mailjet/mailjet_test.go +++ b/email/mailjet/mailjet_test.go @@ -13,7 +13,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/mailjet/mailjet-apiv3-go/v4" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" ) func TestNewMailjetEmailer(T *testing.T) { @@ -27,8 +27,8 @@ func TestNewMailjetEmailer(T *testing.T) { config := &Config{SecretKey: t.Name(), APIKey: t.Name()} client, err := NewMailjetEmailer(config, logger, tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, client) - require.NoError(t, err) + must.NotNil(t, client) + must.NoError(t, err) }) T.Run("with missing config", func(t *testing.T) { @@ -37,8 +37,8 @@ func TestNewMailjetEmailer(T *testing.T) { logger := logging.NewNoopLogger() client, err := NewMailjetEmailer(nil, logger, tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil) - require.Nil(t, client) - require.Error(t, err) + must.Nil(t, client) + must.Error(t, err) }) T.Run("with missing config secret key", func(t *testing.T) { @@ -49,8 +49,8 @@ func TestNewMailjetEmailer(T *testing.T) { config := &Config{APIKey: t.Name()} client, err := NewMailjetEmailer(config, logger, tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil) - require.Nil(t, client) - require.Error(t, err) + must.Nil(t, client) + must.Error(t, err) }) T.Run("with missing config public key", func(t *testing.T) { @@ -61,8 +61,8 @@ func TestNewMailjetEmailer(T *testing.T) { config := &Config{SecretKey: t.Name()} client, err := NewMailjetEmailer(config, logger, tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil) - require.Nil(t, client) - require.Error(t, err) + must.Nil(t, client) + must.Error(t, err) }) T.Run("with missing HTTP client", func(t *testing.T) { @@ -73,8 +73,8 @@ func TestNewMailjetEmailer(T *testing.T) { config := &Config{SecretKey: t.Name(), APIKey: t.Name()} client, err := NewMailjetEmailer(config, logger, tracing.NewNoopTracerProvider(), nil, cbnoop.NewCircuitBreaker(), nil) - require.Nil(t, client) - require.Error(t, err) + must.Nil(t, client) + must.Error(t, err) }) } @@ -93,8 +93,8 @@ func TestMailjetEmailer_SendEmail(T *testing.T) { config := &Config{SecretKey: t.Name(), APIKey: t.Name()} c, err := NewMailjetEmailer(config, logger, tracing.NewNoopTracerProvider(), ts.Client(), cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, c) - require.NoError(t, err) + must.NotNil(t, c) + must.NoError(t, err) c.client.(*mailjet.Client).SetBaseURL(ts.URL + "/") @@ -108,7 +108,7 @@ func TestMailjetEmailer_SendEmail(T *testing.T) { HTMLContent: t.Name(), } - require.NoError(t, c.SendEmail(ctx, details)) + must.NoError(t, c.SendEmail(ctx, details)) }) T.Run("with error executing request", func(t *testing.T) { @@ -124,8 +124,8 @@ func TestMailjetEmailer_SendEmail(T *testing.T) { client := ts.Client() c, err := NewMailjetEmailer(config, logger, tracing.NewNoopTracerProvider(), client, cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, c) - require.NoError(t, err) + must.NotNil(t, c) + must.NoError(t, err) c.client.(*mailjet.Client).SetBaseURL(ts.URL + "/") client.Timeout = time.Millisecond @@ -140,6 +140,6 @@ func TestMailjetEmailer_SendEmail(T *testing.T) { HTMLContent: t.Name(), } - require.Error(t, c.SendEmail(ctx, details)) + must.Error(t, c.SendEmail(ctx, details)) }) } diff --git a/email/noop/noop_test.go b/email/noop/noop_test.go index 37b6079..6f786d6 100644 --- a/email/noop/noop_test.go +++ b/email/noop/noop_test.go @@ -6,8 +6,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/email" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestNewEmailer(T *testing.T) { @@ -17,8 +17,8 @@ func TestNewEmailer(T *testing.T) { t.Parallel() e, err := NewEmailer() - require.NoError(t, err) - assert.NotNil(t, e) + must.NoError(t, err) + test.NotNil(t, e) }) } @@ -29,23 +29,23 @@ func TestEmailer_SendEmail(T *testing.T) { t.Parallel() e, err := NewEmailer() - require.NoError(t, err) + must.NoError(t, err) err = e.SendEmail(context.Background(), &email.OutboundEmailMessage{ ToAddress: "test@example.com", Subject: "Test", HTMLContent: "

hello

", }) - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("with nil message", func(t *testing.T) { t.Parallel() e, err := NewEmailer() - require.NoError(t, err) + must.NoError(t, err) err = e.SendEmail(context.Background(), nil) - assert.NoError(t, err) + test.NoError(t, err) }) } diff --git a/email/postmark/postmark_test.go b/email/postmark/postmark_test.go index 94e4a77..f76c3d3 100644 --- a/email/postmark/postmark_test.go +++ b/email/postmark/postmark_test.go @@ -12,7 +12,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" ) type emailResponse struct { @@ -34,8 +34,8 @@ func TestNewPostmarkEmailer(T *testing.T) { config := &Config{ServerToken: t.Name()} client, err := NewPostmarkEmailer(config, logger, tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, client) - require.NoError(t, err) + must.NotNil(t, client) + must.NoError(t, err) }) T.Run("with missing config", func(t *testing.T) { @@ -44,8 +44,8 @@ func TestNewPostmarkEmailer(T *testing.T) { logger := logging.NewNoopLogger() client, err := NewPostmarkEmailer(nil, logger, tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil) - require.Nil(t, client) - require.Error(t, err) + must.Nil(t, client) + must.Error(t, err) }) T.Run("with missing server token", func(t *testing.T) { @@ -56,8 +56,8 @@ func TestNewPostmarkEmailer(T *testing.T) { config := &Config{} client, err := NewPostmarkEmailer(config, logger, tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil) - require.Nil(t, client) - require.Error(t, err) + must.Nil(t, client) + must.Error(t, err) }) T.Run("with missing HTTP client", func(t *testing.T) { @@ -68,8 +68,8 @@ func TestNewPostmarkEmailer(T *testing.T) { config := &Config{ServerToken: t.Name()} client, err := NewPostmarkEmailer(config, logger, tracing.NewNoopTracerProvider(), nil, cbnoop.NewCircuitBreaker(), nil) - require.Nil(t, client) - require.Error(t, err) + must.Nil(t, client) + must.Error(t, err) }) } @@ -94,8 +94,8 @@ func TestPostmarkEmailer_SendEmail(T *testing.T) { cfg := &Config{ServerToken: t.Name(), BaseURL: ts.URL} c, err := NewPostmarkEmailer(cfg, logger, tracing.NewNoopTracerProvider(), ts.Client(), cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, c) - require.NoError(t, err) + must.NotNil(t, c) + must.NoError(t, err) ctx := t.Context() details := &email.OutboundEmailMessage{ @@ -107,7 +107,7 @@ func TestPostmarkEmailer_SendEmail(T *testing.T) { HTMLContent: t.Name(), } - require.NoError(t, c.SendEmail(ctx, details)) + must.NoError(t, c.SendEmail(ctx, details)) }) T.Run("with error executing request", func(t *testing.T) { @@ -125,8 +125,8 @@ func TestPostmarkEmailer_SendEmail(T *testing.T) { cfg := &Config{ServerToken: t.Name(), BaseURL: ts.URL} c, err := NewPostmarkEmailer(cfg, logger, tracing.NewNoopTracerProvider(), client, cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, c) - require.NoError(t, err) + must.NotNil(t, c) + must.NoError(t, err) ctx := t.Context() details := &email.OutboundEmailMessage{ @@ -139,7 +139,7 @@ func TestPostmarkEmailer_SendEmail(T *testing.T) { } err = c.SendEmail(ctx, details) - require.Error(t, err) + must.Error(t, err) }) T.Run("with invalid response code", func(t *testing.T) { @@ -154,8 +154,8 @@ func TestPostmarkEmailer_SendEmail(T *testing.T) { cfg := &Config{ServerToken: t.Name(), BaseURL: ts.URL} c, err := NewPostmarkEmailer(cfg, logger, tracing.NewNoopTracerProvider(), ts.Client(), cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, c) - require.NoError(t, err) + must.NotNil(t, c) + must.NoError(t, err) ctx := t.Context() details := &email.OutboundEmailMessage{ @@ -168,6 +168,6 @@ func TestPostmarkEmailer_SendEmail(T *testing.T) { } err = c.SendEmail(ctx, details) - require.Error(t, err) + must.Error(t, err) }) } diff --git a/email/resend/resend_test.go b/email/resend/resend_test.go index 8bcb343..3a5262a 100644 --- a/email/resend/resend_test.go +++ b/email/resend/resend_test.go @@ -13,7 +13,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" ) type sendEmailResponse struct { @@ -31,8 +31,8 @@ func TestNewResendEmailer(T *testing.T) { config := &Config{APIToken: t.Name()} client, err := NewResendEmailer(config, logger, tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, client) - require.NoError(t, err) + must.NotNil(t, client) + must.NoError(t, err) }) T.Run("with missing config", func(t *testing.T) { @@ -41,8 +41,8 @@ func TestNewResendEmailer(T *testing.T) { logger := logging.NewNoopLogger() client, err := NewResendEmailer(nil, logger, tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil) - require.Nil(t, client) - require.Error(t, err) + must.Nil(t, client) + must.Error(t, err) }) T.Run("with missing config API token", func(t *testing.T) { @@ -53,8 +53,8 @@ func TestNewResendEmailer(T *testing.T) { config := &Config{} client, err := NewResendEmailer(config, logger, tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil) - require.Nil(t, client) - require.Error(t, err) + must.Nil(t, client) + must.Error(t, err) }) T.Run("with missing HTTP client", func(t *testing.T) { @@ -65,8 +65,8 @@ func TestNewResendEmailer(T *testing.T) { config := &Config{APIToken: t.Name()} client, err := NewResendEmailer(config, logger, tracing.NewNoopTracerProvider(), nil, cbnoop.NewCircuitBreaker(), nil) - require.Nil(t, client) - require.Error(t, err) + must.Nil(t, client) + must.Error(t, err) }) } @@ -85,11 +85,11 @@ func TestResendEmailer_SendEmail(T *testing.T) { cfg := &Config{APIToken: t.Name()} c, err := NewResendEmailer(cfg, logger, tracing.NewNoopTracerProvider(), ts.Client(), cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, c) - require.NoError(t, err) + must.NotNil(t, c) + must.NoError(t, err) baseURL, err := url.Parse(ts.URL + "/") - require.NoError(t, err) + must.NoError(t, err) c.client.BaseURL = baseURL ctx := t.Context() @@ -102,7 +102,7 @@ func TestResendEmailer_SendEmail(T *testing.T) { HTMLContent: t.Name(), } - require.NoError(t, c.SendEmail(ctx, details)) + must.NoError(t, c.SendEmail(ctx, details)) }) T.Run("with error executing request", func(t *testing.T) { @@ -120,11 +120,11 @@ func TestResendEmailer_SendEmail(T *testing.T) { cfg := &Config{APIToken: t.Name()} c, err := NewResendEmailer(cfg, logger, tracing.NewNoopTracerProvider(), client, cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, c) - require.NoError(t, err) + must.NotNil(t, c) + must.NoError(t, err) baseURL, err := url.Parse(ts.URL + "/") - require.NoError(t, err) + must.NoError(t, err) c.client.BaseURL = baseURL ctx := t.Context() @@ -138,7 +138,7 @@ func TestResendEmailer_SendEmail(T *testing.T) { } err = c.SendEmail(ctx, details) - require.Error(t, err) + must.Error(t, err) }) T.Run("with invalid response code", func(t *testing.T) { @@ -153,11 +153,11 @@ func TestResendEmailer_SendEmail(T *testing.T) { cfg := &Config{APIToken: t.Name()} c, err := NewResendEmailer(cfg, logger, tracing.NewNoopTracerProvider(), ts.Client(), cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, c) - require.NoError(t, err) + must.NotNil(t, c) + must.NoError(t, err) baseURL, err := url.Parse(ts.URL + "/") - require.NoError(t, err) + must.NoError(t, err) c.client.BaseURL = baseURL ctx := t.Context() @@ -171,6 +171,6 @@ func TestResendEmailer_SendEmail(T *testing.T) { } err = c.SendEmail(ctx, details) - require.Error(t, err) + must.Error(t, err) }) } diff --git a/email/sendgrid/sendgrid_test.go b/email/sendgrid/sendgrid_test.go index 7a92077..19c283d 100644 --- a/email/sendgrid/sendgrid_test.go +++ b/email/sendgrid/sendgrid_test.go @@ -13,7 +13,7 @@ import ( "github.com/sendgrid/sendgrid-go" "github.com/sendgrid/sendgrid-go/helpers/mail" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" ) func TestNewSendGridEmailer(T *testing.T) { @@ -25,8 +25,8 @@ func TestNewSendGridEmailer(T *testing.T) { logger := logging.NewNoopLogger() client, err := NewSendGridEmailer(&Config{APIToken: t.Name()}, logger, tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, client) - require.NoError(t, err) + must.NotNil(t, client) + must.NoError(t, err) }) } @@ -43,8 +43,8 @@ func TestSendGridEmailer_SendEmail(T *testing.T) { })) c, err := NewSendGridEmailer(&Config{APIToken: t.Name()}, logger, tracing.NewNoopTracerProvider(), ts.Client(), cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, c) - require.NoError(t, err) + must.NotNil(t, c) + must.NoError(t, err) c.client.BaseURL = ts.URL @@ -58,7 +58,7 @@ func TestSendGridEmailer_SendEmail(T *testing.T) { HTMLContent: t.Name(), } - require.NoError(t, c.SendEmail(ctx, details)) + must.NoError(t, c.SendEmail(ctx, details)) }) T.Run("with error executing request", func(t *testing.T) { @@ -73,8 +73,8 @@ func TestSendGridEmailer_SendEmail(T *testing.T) { client.Timeout = time.Millisecond c, err := NewSendGridEmailer(&Config{APIToken: t.Name()}, logger, tracing.NewNoopTracerProvider(), client, cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, c) - require.NoError(t, err) + must.NotNil(t, c) + must.NoError(t, err) c.client.BaseURL = ts.URL @@ -89,7 +89,7 @@ func TestSendGridEmailer_SendEmail(T *testing.T) { } err = c.SendEmail(ctx, details) - require.Error(t, err) + must.Error(t, err) }) T.Run("with invalid response code", func(t *testing.T) { @@ -102,8 +102,8 @@ func TestSendGridEmailer_SendEmail(T *testing.T) { })) c, err := NewSendGridEmailer(&Config{APIToken: t.Name()}, logger, tracing.NewNoopTracerProvider(), ts.Client(), cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, c) - require.NoError(t, err) + must.NotNil(t, c) + must.NoError(t, err) c.client.BaseURL = ts.URL @@ -118,7 +118,7 @@ func TestSendGridEmailer_SendEmail(T *testing.T) { } err = c.SendEmail(ctx, details) - require.Error(t, err) + must.Error(t, err) }) } @@ -135,8 +135,8 @@ func TestSendGridEmailer_sendDynamicTemplateEmail(T *testing.T) { })) c, err := NewSendGridEmailer(&Config{APIToken: t.Name()}, logger, tracing.NewNoopTracerProvider(), ts.Client(), cbnoop.NewCircuitBreaker(), nil) - require.NotNil(t, c) - require.NoError(t, err) + must.NotNil(t, c) + must.NoError(t, err) c.client.BaseURL = ts.URL @@ -146,6 +146,6 @@ func TestSendGridEmailer_sendDynamicTemplateEmail(T *testing.T) { request := sendgrid.GetRequest(c.config.APIToken, "/v3/mail/send", ts.URL) - require.NoError(t, c.sendDynamicTemplateEmail(ctx, to, from, t.Name(), map[string]any{"things": "stuff"}, request)) + must.NoError(t, c.sendDynamicTemplateEmail(ctx, to, from, t.Name(), map[string]any{"things": "stuff"}, request)) }) } diff --git a/email/ses/config_test.go b/email/ses/config_test.go index c9c7d69..b9d7808 100644 --- a/email/ses/config_test.go +++ b/email/ses/config_test.go @@ -3,8 +3,8 @@ package ses import ( "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -17,7 +17,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Region: "us-east-1", } - require.NoError(t, cfg.ValidateWithContext(t.Context())) + must.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("with missing region", func(t *testing.T) { @@ -26,7 +26,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { cfg := &Config{} err := cfg.ValidateWithContext(t.Context()) - require.Error(t, err) - assert.Contains(t, err.Error(), "region") + must.Error(t, err) + test.StrContains(t, err.Error(), "region") }) } diff --git a/email/ses/ses_test.go b/email/ses/ses_test.go index 8a29368..84e8c24 100644 --- a/email/ses/ses_test.go +++ b/email/ses/ses_test.go @@ -12,8 +12,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/aws/aws-sdk-go-v2/service/sesv2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) type mockSESClient struct { @@ -35,17 +35,17 @@ func TestNewSESEmailer(T *testing.T) { mock := &mockSESClient{} client, err := NewSESEmailer(t.Context(), cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cbnoop.NewCircuitBreaker(), nil, mock) - require.NoError(t, err) - require.NotNil(t, client) + must.NoError(t, err) + must.NotNil(t, client) }) T.Run("with nil config", func(t *testing.T) { t.Parallel() client, err := NewSESEmailer(t.Context(), nil, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cbnoop.NewCircuitBreaker(), nil, &mockSESClient{}) - require.Error(t, err) - assert.Nil(t, client) - assert.ErrorIs(t, err, ErrNilConfig) + must.Error(t, err) + test.Nil(t, client) + test.ErrorIs(t, err, ErrNilConfig) }) T.Run("with empty region", func(t *testing.T) { @@ -54,9 +54,9 @@ func TestNewSESEmailer(T *testing.T) { cfg := &Config{} client, err := NewSESEmailer(t.Context(), cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cbnoop.NewCircuitBreaker(), nil, &mockSESClient{}) - require.Error(t, err) - assert.Nil(t, client) - assert.ErrorIs(t, err, ErrEmptyRegion) + must.Error(t, err) + test.Nil(t, client) + test.ErrorIs(t, err, ErrEmptyRegion) }) T.Run("with nil HTTP client and nil SES client", func(t *testing.T) { @@ -65,9 +65,9 @@ func TestNewSESEmailer(T *testing.T) { cfg := &Config{Region: "us-east-1"} client, err := NewSESEmailer(t.Context(), cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cbnoop.NewCircuitBreaker(), nil, nil) - require.Error(t, err) - assert.Nil(t, client) - assert.ErrorIs(t, err, ErrNilHTTPClient) + must.Error(t, err) + test.Nil(t, client) + test.ErrorIs(t, err, ErrNilHTTPClient) }) T.Run("with HTTP client and nil SES client", func(t *testing.T) { @@ -76,8 +76,8 @@ func TestNewSESEmailer(T *testing.T) { cfg := &Config{Region: "us-east-1"} client, err := NewSESEmailer(t.Context(), cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), &http.Client{}, cbnoop.NewCircuitBreaker(), nil, nil) - require.NoError(t, err) - require.NotNil(t, client) + must.NoError(t, err) + must.NotNil(t, client) }) } @@ -91,7 +91,7 @@ func TestEmailer_SendEmail(T *testing.T) { cfg := &Config{Region: "us-east-1"} e, err := NewSESEmailer(t.Context(), cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cbnoop.NewCircuitBreaker(), nil, mock) - require.NoError(t, err) + must.NoError(t, err) details := &email.OutboundEmailMessage{ ToAddress: "to@example.com", @@ -102,7 +102,7 @@ func TestEmailer_SendEmail(T *testing.T) { HTMLContent: t.Name(), } - require.NoError(t, e.SendEmail(t.Context(), details)) + must.NoError(t, e.SendEmail(t.Context(), details)) }) T.Run("without names", func(t *testing.T) { @@ -112,7 +112,7 @@ func TestEmailer_SendEmail(T *testing.T) { cfg := &Config{Region: "us-east-1"} e, err := NewSESEmailer(t.Context(), cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cbnoop.NewCircuitBreaker(), nil, mock) - require.NoError(t, err) + must.NoError(t, err) details := &email.OutboundEmailMessage{ ToAddress: "to@example.com", @@ -121,7 +121,7 @@ func TestEmailer_SendEmail(T *testing.T) { HTMLContent: t.Name(), } - require.NoError(t, e.SendEmail(t.Context(), details)) + must.NoError(t, e.SendEmail(t.Context(), details)) }) T.Run("with error from SES", func(t *testing.T) { @@ -131,7 +131,7 @@ func TestEmailer_SendEmail(T *testing.T) { cfg := &Config{Region: "us-east-1"} e, err := NewSESEmailer(t.Context(), cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cbnoop.NewCircuitBreaker(), nil, mock) - require.NoError(t, err) + must.NoError(t, err) details := &email.OutboundEmailMessage{ ToAddress: "to@example.com", @@ -143,7 +143,7 @@ func TestEmailer_SendEmail(T *testing.T) { } err = e.SendEmail(t.Context(), details) - require.Error(t, err) + must.Error(t, err) }) T.Run("with broken circuit breaker", func(t *testing.T) { @@ -153,7 +153,7 @@ func TestEmailer_SendEmail(T *testing.T) { cfg := &Config{Region: "us-east-1"} e, err := NewSESEmailer(t.Context(), cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cbnoop.NewCircuitBreaker(), nil, mock) - require.NoError(t, err) + must.NoError(t, err) e.circuitBreaker = &brokenCircuitBreaker{} @@ -167,7 +167,7 @@ func TestEmailer_SendEmail(T *testing.T) { } err = e.SendEmail(t.Context(), details) - require.Error(t, err) + must.Error(t, err) }) } diff --git a/embeddings/cohere/cohere_test.go b/embeddings/cohere/cohere_test.go index 5380030..4bfe9f5 100644 --- a/embeddings/cohere/cohere_test.go +++ b/embeddings/cohere/cohere_test.go @@ -14,8 +14,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) type roundTripFunc func(*http.Request) (*http.Response, error) @@ -40,24 +40,24 @@ func TestNewEmbedder(T *testing.T) { t.Parallel() emb, err := NewEmbedder(t.Context(), nil, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.Error(t, err) - require.Nil(t, emb) + must.Error(t, err) + must.Nil(t, emb) }) T.Run("with missing API key", func(t *testing.T) { t.Parallel() emb, err := NewEmbedder(t.Context(), &Config{}, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.Error(t, err) - require.Nil(t, emb) + must.Error(t, err) + must.Nil(t, emb) }) T.Run("standard", func(t *testing.T) { t.Parallel() emb, err := NewEmbedder(t.Context(), &Config{APIKey: "test-key"}, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) - require.NotNil(t, emb) + must.NoError(t, err) + must.NotNil(t, emb) }) T.Run("with timeout", func(t *testing.T) { @@ -67,8 +67,8 @@ func TestNewEmbedder(T *testing.T) { APIKey: "test-key", Timeout: 5 * time.Second, }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) - require.NotNil(t, emb) + must.NoError(t, err) + must.NotNil(t, emb) }) } @@ -89,11 +89,11 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { t.Parallel() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, "/v2/embed", r.URL.Path) - require.Equal(t, http.MethodPost, r.Method) - require.Equal(t, "Bearer test-key", r.Header.Get("Authorization")) + must.EqOp(t, "/v2/embed", r.URL.Path) + must.EqOp(t, http.MethodPost, r.Method) + must.EqOp(t, "Bearer test-key", r.Header.Get("Authorization")) w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(cohereEmbeddingResponse)) + must.NoError(t, json.NewEncoder(w).Encode(cohereEmbeddingResponse)) })) t.Cleanup(ts.Close) @@ -101,21 +101,21 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { APIKey: "test-key", BaseURL: ts.URL, }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ Content: "hello world", }) - require.NoError(t, err) - require.NotNil(t, result) - assert.Equal(t, "hello world", result.SourceText) - assert.Equal(t, "embed-english-v3.0", result.Model) - assert.Equal(t, "cohere", result.Provider) - assert.Equal(t, 5, result.Dimensions) - assert.Len(t, result.Vector, 5) - assert.False(t, result.GeneratedAt.IsZero()) + must.NoError(t, err) + must.NotNil(t, result) + test.EqOp(t, "hello world", result.SourceText) + test.EqOp(t, "embed-english-v3.0", result.Model) + test.EqOp(t, "cohere", result.Provider) + test.EqOp(t, 5, result.Dimensions) + test.SliceLen(t, 5, result.Vector) + test.False(t, result.GeneratedAt.IsZero()) }) T.Run("uses input model override", func(t *testing.T) { @@ -123,10 +123,10 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var reqBody embeddingRequest - require.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody)) - require.Equal(t, "embed-multilingual-v3.0", reqBody.Model) + must.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody)) + must.EqOp(t, "embed-multilingual-v3.0", reqBody.Model) w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(cohereEmbeddingResponse)) + must.NoError(t, json.NewEncoder(w).Encode(cohereEmbeddingResponse)) })) t.Cleanup(ts.Close) @@ -135,7 +135,7 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { BaseURL: ts.URL, DefaultModel: "embed-english-v3.0", }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ @@ -143,8 +143,8 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { Model: "embed-multilingual-v3.0", }) - require.NoError(t, err) - require.NotNil(t, result) + must.NoError(t, err) + must.NotNil(t, result) }) T.Run("with non-200 response", func(t *testing.T) { @@ -160,15 +160,15 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { APIKey: "bad-key", BaseURL: ts.URL, }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ Content: "hello", }) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) T.Run("with malformed JSON response", func(t *testing.T) { @@ -184,15 +184,15 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { APIKey: "test-key", BaseURL: ts.URL, }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ Content: "hello", }) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) T.Run("with empty embeddings response", func(t *testing.T) { @@ -200,7 +200,7 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + must.NoError(t, json.NewEncoder(w).Encode(map[string]any{ "embeddings": map[string]any{ "float": [][]float64{}, }, @@ -212,15 +212,15 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { APIKey: "test-key", BaseURL: ts.URL, }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ Content: "hello", }) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) T.Run("with connection error", func(t *testing.T) { @@ -233,15 +233,15 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { APIKey: "test-key", BaseURL: ts.URL, }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ Content: "hello", }) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) T.Run("uses config default model", func(t *testing.T) { @@ -249,10 +249,10 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var reqBody embeddingRequest - require.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody)) - require.Equal(t, "embed-multilingual-v3.0", reqBody.Model) + must.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody)) + must.EqOp(t, "embed-multilingual-v3.0", reqBody.Model) w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(cohereEmbeddingResponse)) + must.NoError(t, json.NewEncoder(w).Encode(cohereEmbeddingResponse)) })) t.Cleanup(ts.Close) @@ -261,16 +261,16 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { BaseURL: ts.URL, DefaultModel: "embed-multilingual-v3.0", }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ Content: "hello", }) - require.NoError(t, err) - require.NotNil(t, result) - assert.Equal(t, "embed-multilingual-v3.0", result.Model) + must.NoError(t, err) + must.NotNil(t, result) + test.EqOp(t, "embed-multilingual-v3.0", result.Model) }) T.Run("with default base URL", func(t *testing.T) { @@ -282,7 +282,7 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { tracer: tracing.NewTracerForTest("test"), client: &http.Client{ Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { - assert.Contains(t, r.URL.String(), defaultBaseURL) + test.StrContains(t, r.URL.String(), defaultBaseURL) body := `{"embeddings":{"float":[[0.1,0.2]]}}` return &http.Response{ StatusCode: http.StatusOK, @@ -294,8 +294,8 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { result, err := e.GenerateEmbedding(t.Context(), &embeddings.Input{Content: "hello"}) - require.NoError(t, err) - require.NotNil(t, result) + must.NoError(t, err) + must.NotNil(t, result) }) T.Run("with request building error", func(t *testing.T) { @@ -310,8 +310,8 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { result, err := e.GenerateEmbedding(t.Context(), &embeddings.Input{Content: "hello"}) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) T.Run("with response body close error", func(t *testing.T) { @@ -334,8 +334,8 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { result, err := e.GenerateEmbedding(t.Context(), &embeddings.Input{Content: "hello"}) - require.NoError(t, err) - require.NotNil(t, result) + must.NoError(t, err) + must.NotNil(t, result) }) T.Run("with error reading error response body", func(t *testing.T) { @@ -357,7 +357,7 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { result, err := e.GenerateEmbedding(t.Context(), &embeddings.Input{Content: "hello"}) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) } diff --git a/embeddings/config/config_test.go b/embeddings/config/config_test.go index 14ceb57..4334624 100644 --- a/embeddings/config/config_test.go +++ b/embeddings/config/config_test.go @@ -9,8 +9,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -20,14 +20,14 @@ func TestConfig_ValidateWithContext(T *testing.T) { t.Parallel() cfg := &Config{Provider: ""} - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("with invalid provider", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: "invalid"} - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("openai provider with config", func(t *testing.T) { @@ -37,14 +37,14 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderOpenAI, OpenAI: &openai.Config{APIKey: t.Name()}, } - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("openai provider requires config", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderOpenAI} - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("ollama provider with config", func(t *testing.T) { @@ -54,14 +54,14 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderOllama, Ollama: &ollama.Config{}, } - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("ollama provider requires config", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderOllama} - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("cohere provider with config", func(t *testing.T) { @@ -71,14 +71,14 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderCohere, Cohere: &cohere.Config{APIKey: t.Name()}, } - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("cohere provider requires config", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderCohere} - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) } @@ -93,8 +93,8 @@ func TestConfig_ProvideEmbedder_Empty(T *testing.T) { tracer := tracing.NewTracerForTest("test") embedder, err := cfg.ProvideEmbedder(t.Context(), logger, tracer) - require.NoError(t, err) - require.NotNil(t, embedder, "expected non-nil embedder (noop)") + must.NoError(t, err) + must.NotNil(t, embedder, must.Sprintf("expected non-nil embedder (noop)")) }) } @@ -114,8 +114,8 @@ func TestConfig_ProvideEmbedder_OpenAI(T *testing.T) { tracer := tracing.NewTracerForTest("test") embedder, err := cfg.ProvideEmbedder(t.Context(), logger, tracer) - require.NoError(t, err) - require.NotNil(t, embedder) + must.NoError(t, err) + must.NotNil(t, embedder) }) } @@ -133,8 +133,8 @@ func TestConfig_ProvideEmbedder_Ollama(T *testing.T) { tracer := tracing.NewTracerForTest("test") embedder, err := cfg.ProvideEmbedder(t.Context(), logger, tracer) - require.NoError(t, err) - require.NotNil(t, embedder) + must.NoError(t, err) + must.NotNil(t, embedder) }) } @@ -154,7 +154,7 @@ func TestConfig_ProvideEmbedder_Cohere(T *testing.T) { tracer := tracing.NewTracerForTest("test") embedder, err := cfg.ProvideEmbedder(t.Context(), logger, tracer) - require.NoError(t, err) - require.NotNil(t, embedder) + must.NoError(t, err) + must.NotNil(t, embedder) }) } diff --git a/embeddings/config/do_test.go b/embeddings/config/do_test.go index 4351cea..20e9bde 100644 --- a/embeddings/config/do_test.go +++ b/embeddings/config/do_test.go @@ -8,8 +8,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterEmbedder(T *testing.T) { @@ -26,7 +26,7 @@ func TestRegisterEmbedder(T *testing.T) { RegisterEmbedder(i) embedder, err := do.Invoke[embeddings.Embedder](i) - require.NoError(t, err) - assert.NotNil(t, embedder) + must.NoError(t, err) + test.NotNil(t, embedder) }) } diff --git a/embeddings/embeddings_test.go b/embeddings/embeddings_test.go index 275d71d..ad9a7f2 100644 --- a/embeddings/embeddings_test.go +++ b/embeddings/embeddings_test.go @@ -3,8 +3,8 @@ package embeddings import ( "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestNoopEmbedder_GenerateEmbedding(T *testing.T) { @@ -20,13 +20,13 @@ func TestNoopEmbedder_GenerateEmbedding(T *testing.T) { Content: "hello world", }) - require.NoError(t, err) - require.NotNil(t, result) - assert.Equal(t, "hello world", result.SourceText) - assert.Equal(t, "noop", result.Model) - assert.Equal(t, "noop", result.Provider) - assert.Equal(t, 0, result.Dimensions) - assert.Empty(t, result.Vector) - assert.False(t, result.GeneratedAt.IsZero()) + must.NoError(t, err) + must.NotNil(t, result) + test.EqOp(t, "hello world", result.SourceText) + test.EqOp(t, "noop", result.Model) + test.EqOp(t, "noop", result.Provider) + test.EqOp(t, 0, result.Dimensions) + test.SliceEmpty(t, result.Vector) + test.False(t, result.GeneratedAt.IsZero()) }) } diff --git a/embeddings/ollama/ollama_test.go b/embeddings/ollama/ollama_test.go index d15885f..6e9d293 100644 --- a/embeddings/ollama/ollama_test.go +++ b/embeddings/ollama/ollama_test.go @@ -14,8 +14,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) type roundTripFunc func(*http.Request) (*http.Response, error) @@ -40,16 +40,16 @@ func TestNewEmbedder(T *testing.T) { t.Parallel() emb, err := NewEmbedder(t.Context(), nil, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.Error(t, err) - require.Nil(t, emb) + must.Error(t, err) + must.Nil(t, emb) }) T.Run("standard", func(t *testing.T) { t.Parallel() emb, err := NewEmbedder(t.Context(), &Config{}, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) - require.NotNil(t, emb) + must.NoError(t, err) + must.NotNil(t, emb) }) T.Run("with custom base URL", func(t *testing.T) { @@ -58,8 +58,8 @@ func TestNewEmbedder(T *testing.T) { emb, err := NewEmbedder(t.Context(), &Config{ BaseURL: "http://custom:11434", }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) - require.NotNil(t, emb) + must.NoError(t, err) + must.NotNil(t, emb) }) T.Run("with timeout", func(t *testing.T) { @@ -68,8 +68,8 @@ func TestNewEmbedder(T *testing.T) { emb, err := NewEmbedder(t.Context(), &Config{ Timeout: 5 * time.Second, }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) - require.NotNil(t, emb) + must.NoError(t, err) + must.NotNil(t, emb) }) } @@ -86,31 +86,31 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { t.Parallel() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, "/api/embed", r.URL.Path) - require.Equal(t, http.MethodPost, r.Method) + must.EqOp(t, "/api/embed", r.URL.Path) + must.EqOp(t, http.MethodPost, r.Method) w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(ollamaEmbeddingResponse)) + must.NoError(t, json.NewEncoder(w).Encode(ollamaEmbeddingResponse)) })) t.Cleanup(ts.Close) emb, err := NewEmbedder(t.Context(), &Config{ BaseURL: ts.URL, }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ Content: "hello world", }) - require.NoError(t, err) - require.NotNil(t, result) - assert.Equal(t, "hello world", result.SourceText) - assert.Equal(t, "nomic-embed-text", result.Model) - assert.Equal(t, "ollama", result.Provider) - assert.Equal(t, 4, result.Dimensions) - assert.Len(t, result.Vector, 4) - assert.False(t, result.GeneratedAt.IsZero()) + must.NoError(t, err) + must.NotNil(t, result) + test.EqOp(t, "hello world", result.SourceText) + test.EqOp(t, "nomic-embed-text", result.Model) + test.EqOp(t, "ollama", result.Provider) + test.EqOp(t, 4, result.Dimensions) + test.SliceLen(t, 4, result.Vector) + test.False(t, result.GeneratedAt.IsZero()) }) T.Run("uses input model override", func(t *testing.T) { @@ -118,10 +118,10 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var reqBody embeddingRequest - require.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody)) - require.Equal(t, "mxbai-embed-large", reqBody.Model) + must.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody)) + must.EqOp(t, "mxbai-embed-large", reqBody.Model) w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(ollamaEmbeddingResponse)) + must.NoError(t, json.NewEncoder(w).Encode(ollamaEmbeddingResponse)) })) t.Cleanup(ts.Close) @@ -129,7 +129,7 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { BaseURL: ts.URL, DefaultModel: "nomic-embed-text", }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ @@ -137,8 +137,8 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { Model: "mxbai-embed-large", }) - require.NoError(t, err) - require.NotNil(t, result) + must.NoError(t, err) + must.NotNil(t, result) }) T.Run("with non-200 response", func(t *testing.T) { @@ -153,15 +153,15 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { emb, err := NewEmbedder(t.Context(), &Config{ BaseURL: ts.URL, }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ Content: "hello", }) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) T.Run("with malformed JSON response", func(t *testing.T) { @@ -176,15 +176,15 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { emb, err := NewEmbedder(t.Context(), &Config{ BaseURL: ts.URL, }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ Content: "hello", }) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) T.Run("with empty embeddings response", func(t *testing.T) { @@ -192,7 +192,7 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + must.NoError(t, json.NewEncoder(w).Encode(map[string]any{ "embeddings": [][]float64{}, })) })) @@ -201,15 +201,15 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { emb, err := NewEmbedder(t.Context(), &Config{ BaseURL: ts.URL, }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ Content: "hello", }) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) T.Run("with connection error", func(t *testing.T) { @@ -221,15 +221,15 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { emb, err := NewEmbedder(t.Context(), &Config{ BaseURL: ts.URL, }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ Content: "hello", }) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) T.Run("uses config default model", func(t *testing.T) { @@ -237,10 +237,10 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var reqBody embeddingRequest - require.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody)) - require.Equal(t, "mxbai-embed-large", reqBody.Model) + must.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody)) + must.EqOp(t, "mxbai-embed-large", reqBody.Model) w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(ollamaEmbeddingResponse)) + must.NoError(t, json.NewEncoder(w).Encode(ollamaEmbeddingResponse)) })) t.Cleanup(ts.Close) @@ -248,16 +248,16 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { BaseURL: ts.URL, DefaultModel: "mxbai-embed-large", }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ Content: "hello", }) - require.NoError(t, err) - require.NotNil(t, result) - assert.Equal(t, "mxbai-embed-large", result.Model) + must.NoError(t, err) + must.NotNil(t, result) + test.EqOp(t, "mxbai-embed-large", result.Model) }) T.Run("with request building error", func(t *testing.T) { @@ -272,8 +272,8 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { result, err := e.GenerateEmbedding(t.Context(), &embeddings.Input{Content: "hello"}) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) T.Run("with response body close error", func(t *testing.T) { @@ -296,8 +296,8 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { result, err := e.GenerateEmbedding(t.Context(), &embeddings.Input{Content: "hello"}) - require.NoError(t, err) - require.NotNil(t, result) + must.NoError(t, err) + must.NotNil(t, result) }) T.Run("with error reading error response body", func(t *testing.T) { @@ -319,7 +319,7 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { result, err := e.GenerateEmbedding(t.Context(), &embeddings.Input{Content: "hello"}) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) } diff --git a/embeddings/openai/openai_test.go b/embeddings/openai/openai_test.go index b9b5e6d..2601828 100644 --- a/embeddings/openai/openai_test.go +++ b/embeddings/openai/openai_test.go @@ -14,8 +14,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) type roundTripFunc func(*http.Request) (*http.Response, error) @@ -40,24 +40,24 @@ func TestNewEmbedder(T *testing.T) { t.Parallel() emb, err := NewEmbedder(t.Context(), nil, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.Error(t, err) - require.Nil(t, emb) + must.Error(t, err) + must.Nil(t, emb) }) T.Run("with missing API key", func(t *testing.T) { t.Parallel() emb, err := NewEmbedder(t.Context(), &Config{}, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.Error(t, err) - require.Nil(t, emb) + must.Error(t, err) + must.Nil(t, emb) }) T.Run("standard", func(t *testing.T) { t.Parallel() emb, err := NewEmbedder(t.Context(), &Config{APIKey: "test-key"}, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) - require.NotNil(t, emb) + must.NoError(t, err) + must.NotNil(t, emb) }) T.Run("with timeout", func(t *testing.T) { @@ -67,8 +67,8 @@ func TestNewEmbedder(T *testing.T) { APIKey: "test-key", Timeout: 5 * time.Second, }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) - require.NotNil(t, emb) + must.NoError(t, err) + must.NotNil(t, emb) }) } @@ -95,11 +95,11 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { t.Parallel() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, "/v1/embeddings", r.URL.Path) - require.Equal(t, http.MethodPost, r.Method) - require.Equal(t, "Bearer test-key", r.Header.Get("Authorization")) + must.EqOp(t, "/v1/embeddings", r.URL.Path) + must.EqOp(t, http.MethodPost, r.Method) + must.EqOp(t, "Bearer test-key", r.Header.Get("Authorization")) w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(openAIEmbeddingResponse)) + must.NoError(t, json.NewEncoder(w).Encode(openAIEmbeddingResponse)) })) t.Cleanup(ts.Close) @@ -107,21 +107,21 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { APIKey: "test-key", BaseURL: ts.URL, }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ Content: "hello world", }) - require.NoError(t, err) - require.NotNil(t, result) - assert.Equal(t, "hello world", result.SourceText) - assert.Equal(t, "text-embedding-3-small", result.Model) - assert.Equal(t, "openai", result.Provider) - assert.Equal(t, 3, result.Dimensions) - assert.Len(t, result.Vector, 3) - assert.False(t, result.GeneratedAt.IsZero()) + must.NoError(t, err) + must.NotNil(t, result) + test.EqOp(t, "hello world", result.SourceText) + test.EqOp(t, "text-embedding-3-small", result.Model) + test.EqOp(t, "openai", result.Provider) + test.EqOp(t, 3, result.Dimensions) + test.SliceLen(t, 3, result.Vector) + test.False(t, result.GeneratedAt.IsZero()) }) T.Run("uses input model override", func(t *testing.T) { @@ -129,10 +129,10 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var reqBody embeddingRequest - require.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody)) - require.Equal(t, "text-embedding-3-large", reqBody.Model) + must.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody)) + must.EqOp(t, "text-embedding-3-large", reqBody.Model) w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(openAIEmbeddingResponse)) + must.NoError(t, json.NewEncoder(w).Encode(openAIEmbeddingResponse)) })) t.Cleanup(ts.Close) @@ -141,7 +141,7 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { BaseURL: ts.URL, DefaultModel: "text-embedding-3-small", }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ @@ -149,8 +149,8 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { Model: "text-embedding-3-large", }) - require.NoError(t, err) - require.NotNil(t, result) + must.NoError(t, err) + must.NotNil(t, result) }) T.Run("with non-200 response", func(t *testing.T) { @@ -166,15 +166,15 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { APIKey: "test-key", BaseURL: ts.URL, }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ Content: "hello", }) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) T.Run("with malformed JSON response", func(t *testing.T) { @@ -190,15 +190,15 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { APIKey: "test-key", BaseURL: ts.URL, }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ Content: "hello", }) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) T.Run("with empty data response", func(t *testing.T) { @@ -206,7 +206,7 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(map[string]any{ + must.NoError(t, json.NewEncoder(w).Encode(map[string]any{ "data": []map[string]any{}, })) })) @@ -216,15 +216,15 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { APIKey: "test-key", BaseURL: ts.URL, }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ Content: "hello", }) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) T.Run("with connection error", func(t *testing.T) { @@ -237,15 +237,15 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { APIKey: "test-key", BaseURL: ts.URL, }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ Content: "hello", }) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) T.Run("uses config default model", func(t *testing.T) { @@ -253,10 +253,10 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var reqBody embeddingRequest - require.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody)) - require.Equal(t, "text-embedding-3-large", reqBody.Model) + must.NoError(t, json.NewDecoder(r.Body).Decode(&reqBody)) + must.EqOp(t, "text-embedding-3-large", reqBody.Model) w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(openAIEmbeddingResponse)) + must.NoError(t, json.NewEncoder(w).Encode(openAIEmbeddingResponse)) })) t.Cleanup(ts.Close) @@ -265,16 +265,16 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { BaseURL: ts.URL, DefaultModel: "text-embedding-3-large", }, logging.NewNoopLogger(), tracing.NewTracerForTest("test")) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := emb.GenerateEmbedding(ctx, &embeddings.Input{ Content: "hello", }) - require.NoError(t, err) - require.NotNil(t, result) - assert.Equal(t, "text-embedding-3-large", result.Model) + must.NoError(t, err) + must.NotNil(t, result) + test.EqOp(t, "text-embedding-3-large", result.Model) }) T.Run("with default base URL", func(t *testing.T) { @@ -286,7 +286,7 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { tracer: tracing.NewTracerForTest("test"), client: &http.Client{ Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { - assert.Contains(t, r.URL.String(), defaultBaseURL) + test.StrContains(t, r.URL.String(), defaultBaseURL) body := `{"data":[{"embedding":[0.1,0.2]}]}` return &http.Response{ StatusCode: http.StatusOK, @@ -298,8 +298,8 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { result, err := e.GenerateEmbedding(t.Context(), &embeddings.Input{Content: "hello"}) - require.NoError(t, err) - require.NotNil(t, result) + must.NoError(t, err) + must.NotNil(t, result) }) T.Run("with request building error", func(t *testing.T) { @@ -314,8 +314,8 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { result, err := e.GenerateEmbedding(t.Context(), &embeddings.Input{Content: "hello"}) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) T.Run("with response body close error", func(t *testing.T) { @@ -338,8 +338,8 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { result, err := e.GenerateEmbedding(t.Context(), &embeddings.Input{Content: "hello"}) - require.NoError(t, err) - require.NotNil(t, result) + must.NoError(t, err) + must.NotNil(t, result) }) T.Run("with error reading error response body", func(t *testing.T) { @@ -361,7 +361,7 @@ func TestEmbedder_GenerateEmbedding(T *testing.T) { result, err := e.GenerateEmbedding(t.Context(), &embeddings.Input{Content: "hello"}) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) } diff --git a/encoding/client_encoder_test.go b/encoding/client_encoder_test.go index 797f2cd..6759c2a 100644 --- a/encoding/client_encoder_test.go +++ b/encoding/client_encoder_test.go @@ -12,8 +12,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/keith-turner/ecoji/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestProvideClientEncoder(T *testing.T) { @@ -22,7 +22,7 @@ func TestProvideClientEncoder(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotNil(t, ProvideClientEncoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), ContentTypeJSON)) + test.NotNil(t, ProvideClientEncoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), ContentTypeJSON)) }) } @@ -65,8 +65,8 @@ func Test_clientEncoder_Unmarshal(T *testing.T) { expected := &example{Name: "name"} actual := &example{} - assert.NoError(t, e.Unmarshal(ctx, []byte(tc.expected), &actual)) - assert.Equal(t, expected, actual) + test.NoError(t, e.Unmarshal(ctx, []byte(tc.expected), &actual)) + test.Eq(t, expected, actual) }) } @@ -78,8 +78,8 @@ func Test_clientEncoder_Unmarshal(T *testing.T) { actual := &example{} - assert.Error(t, e.Unmarshal(ctx, []byte(`{"name" `), &actual)) - assert.Empty(t, actual.Name) + test.Error(t, e.Unmarshal(ctx, []byte(`{"name" `), &actual)) + test.EqOp(t, "", actual.Name) }) } @@ -95,7 +95,7 @@ func Test_clientEncoder_Encode(T *testing.T) { res := httptest.NewRecorder() - assert.NoError(t, e.Encode(ctx, res, &example{Name: t.Name()})) + test.NoError(t, e.Encode(ctx, res, &example{Name: t.Name()})) }) } @@ -112,7 +112,7 @@ func Test_clientEncoder_Encode(T *testing.T) { }, } - assert.Error(t, e.Encode(ctx, mw, &example{Name: t.Name()})) + test.Error(t, e.Encode(ctx, mw, &example{Name: t.Name()})) }) } @@ -122,7 +122,7 @@ func Test_clientEncoder_Encode(T *testing.T) { ctx := t.Context() e := ProvideClientEncoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), ContentTypeJSON) - assert.Error(t, e.Encode(ctx, nil, &broken{Name: json.Number(t.Name())})) + test.Error(t, e.Encode(ctx, nil, &broken{Name: json.Number(t.Name())})) }) T.Run("with emoji encode error", func(t *testing.T) { @@ -132,7 +132,7 @@ func Test_clientEncoder_Encode(T *testing.T) { e := ProvideClientEncoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), ContentTypeEmoji) var b bytes.Buffer - assert.Error(t, e.Encode(ctx, &b, make(chan int))) + test.Error(t, e.Encode(ctx, &b, make(chan int))) }) } @@ -147,8 +147,8 @@ func Test_clientEncoder_EncodeReader(T *testing.T) { e := ProvideClientEncoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), ct) actual, err := e.EncodeReader(ctx, &example{Name: t.Name()}) - assert.NoError(t, err) - assert.NotNil(t, actual) + test.NoError(t, err) + test.NotNil(t, actual) }) } @@ -159,8 +159,8 @@ func Test_clientEncoder_EncodeReader(T *testing.T) { e := ProvideClientEncoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), ContentTypeJSON) actual, err := e.EncodeReader(ctx, &broken{Name: json.Number(t.Name())}) - assert.Error(t, err) - assert.Nil(t, actual) + test.Error(t, err) + test.Nil(t, actual) }) } @@ -171,7 +171,7 @@ func Test_marshalEmoji(T *testing.T) { t.Parallel() _, err := marshalEmoji(make(chan int)) - assert.Error(t, err) + test.Error(t, err) }) } @@ -182,17 +182,17 @@ func Test_unmarshalEmoji(T *testing.T) { t.Parallel() var dest example - assert.Error(t, unmarshalEmoji([]byte("not valid ecoji data"), &dest)) + test.Error(t, unmarshalEmoji([]byte("not valid ecoji data"), &dest)) }) T.Run("with valid ecoji but invalid gob data", func(t *testing.T) { t.Parallel() var buf bytes.Buffer - require.NoError(t, ecoji.EncodeV2(bytes.NewReader([]byte("not valid gob data")), &buf, 76)) + must.NoError(t, ecoji.EncodeV2(bytes.NewReader([]byte("not valid gob data")), &buf, 76)) var dest example - assert.Error(t, unmarshalEmoji(buf.Bytes(), &dest)) + test.Error(t, unmarshalEmoji(buf.Bytes(), &dest)) }) } @@ -203,6 +203,6 @@ func Test_tomlMarshalFunc(T *testing.T) { t.Parallel() _, err := tomlMarshalFunc(make(chan int)) - assert.Error(t, err) + test.Error(t, err) }) } diff --git a/encoding/config_test.go b/encoding/config_test.go index 6f84afc..8974a99 100644 --- a/encoding/config_test.go +++ b/encoding/config_test.go @@ -3,7 +3,7 @@ package encoding import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -17,7 +17,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { ContentType: contentTypeJSON, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with invalid config", func(t *testing.T) { @@ -26,6 +26,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { ctx := t.Context() cfg := &Config{} - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } diff --git a/encoding/content_type_test.go b/encoding/content_type_test.go index 0832365..6198d17 100644 --- a/encoding/content_type_test.go +++ b/encoding/content_type_test.go @@ -6,7 +6,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func Test_clientEncoder_ContentType(T *testing.T) { @@ -17,7 +17,7 @@ func Test_clientEncoder_ContentType(T *testing.T) { e := ProvideClientEncoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), ContentTypeJSON) - assert.NotEmpty(t, e.ContentType()) + test.NotEq(t, "", e.ContentType()) }) } @@ -27,7 +27,7 @@ func Test_buildContentType(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotNil(t, buildContentType("test")) + test.NotNil(t, buildContentType("test")) }) } @@ -37,25 +37,25 @@ func TestContentTypeToString(T *testing.T) { T.Run("with JSON", func(t *testing.T) { t.Parallel() - assert.NotEmpty(t, ContentTypeToString(ContentTypeJSON)) + test.NotEq(t, "", ContentTypeToString(ContentTypeJSON)) }) T.Run("with XML", func(t *testing.T) { t.Parallel() - assert.NotEmpty(t, ContentTypeToString(ContentTypeXML)) + test.NotEq(t, "", ContentTypeToString(ContentTypeXML)) }) T.Run("with Emoji", func(t *testing.T) { t.Parallel() - assert.NotEmpty(t, ContentTypeToString(ContentTypeEmoji)) + test.NotEq(t, "", ContentTypeToString(ContentTypeEmoji)) }) T.Run("with invalid input", func(t *testing.T) { t.Parallel() - assert.Empty(t, ContentTypeToString(nil)) + test.EqOp(t, "", ContentTypeToString(nil)) }) } @@ -65,36 +65,36 @@ func Test_contentTypeFromString(T *testing.T) { T.Run("with JSON", func(t *testing.T) { t.Parallel() - assert.Equal(t, ContentTypeJSON, contentTypeFromString(contentTypeJSON)) + test.EqOp(t, ContentTypeJSON, contentTypeFromString(contentTypeJSON)) }) T.Run("with XML", func(t *testing.T) { t.Parallel() - assert.Equal(t, ContentTypeXML, contentTypeFromString(contentTypeXML)) + test.EqOp(t, ContentTypeXML, contentTypeFromString(contentTypeXML)) }) T.Run("with TOML", func(t *testing.T) { t.Parallel() - assert.Equal(t, ContentTypeTOML, contentTypeFromString(contentTypeTOML)) + test.EqOp(t, ContentTypeTOML, contentTypeFromString(contentTypeTOML)) }) T.Run("with YAML", func(t *testing.T) { t.Parallel() - assert.Equal(t, ContentTypeYAML, contentTypeFromString(contentTypeYAML)) + test.EqOp(t, ContentTypeYAML, contentTypeFromString(contentTypeYAML)) }) T.Run("with Emoji", func(t *testing.T) { t.Parallel() - assert.Equal(t, ContentTypeEmoji, contentTypeFromString(contentTypeEmoji)) + test.EqOp(t, ContentTypeEmoji, contentTypeFromString(contentTypeEmoji)) }) T.Run("with unknown defaults to JSON", func(t *testing.T) { t.Parallel() - assert.Equal(t, ContentTypeJSON, contentTypeFromString("unknown")) + test.EqOp(t, ContentTypeJSON, contentTypeFromString("unknown")) }) } diff --git a/encoding/do_test.go b/encoding/do_test.go index 2538687..77b35c7 100644 --- a/encoding/do_test.go +++ b/encoding/do_test.go @@ -7,8 +7,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterServerEncoderDecoder(T *testing.T) { @@ -25,11 +25,11 @@ func TestRegisterServerEncoderDecoder(T *testing.T) { RegisterServerEncoderDecoder(i) ct, err := do.Invoke[ContentType](i) - require.NoError(t, err) - assert.NotNil(t, ct) + must.NoError(t, err) + test.NotNil(t, ct) sed, err := do.Invoke[ServerEncoderDecoder](i) - require.NoError(t, err) - assert.NotNil(t, sed) + must.NoError(t, err) + test.NotNil(t, sed) }) } diff --git a/encoding/providers_test.go b/encoding/providers_test.go index e142157..0e79c9d 100644 --- a/encoding/providers_test.go +++ b/encoding/providers_test.go @@ -3,7 +3,7 @@ package encoding import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestProvideContentType(T *testing.T) { @@ -12,6 +12,6 @@ func TestProvideContentType(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.Equal(t, ContentTypeJSON, ProvideContentType(Config{ContentType: "application/json"})) + test.EqOp(t, ContentTypeJSON, ProvideContentType(Config{ContentType: "application/json"})) }) } diff --git a/encoding/server_encoder_decoder_test.go b/encoding/server_encoder_decoder_test.go index 81a499c..2292cd6 100644 --- a/encoding/server_encoder_decoder_test.go +++ b/encoding/server_encoder_decoder_test.go @@ -15,8 +15,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "gopkg.in/yaml.v3" ) @@ -84,7 +84,7 @@ func TestServerEncoderDecoder_encodeResponse(T *testing.T) { ex := &example{Name: "name"} encoderDecoder, ok := ProvideServerEncoderDecoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), tc.contentType).(*serverEncoderDecoder) - require.True(t, ok) + must.True(t, ok) ctx := t.Context() res := httptest.NewRecorder() @@ -92,7 +92,7 @@ func TestServerEncoderDecoder_encodeResponse(T *testing.T) { encoderDecoder.encodeResponse(ctx, res, ex, http.StatusOK) actual := res.Body.String() - assert.Equal(t, tc.expectedResponse, actual) + test.EqOp(t, tc.expectedResponse, actual) }) } @@ -101,7 +101,7 @@ func TestServerEncoderDecoder_encodeResponse(T *testing.T) { ex := &example{Name: "name"} encoderDecoder, ok := ProvideServerEncoderDecoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), ContentTypeEmoji).(*serverEncoderDecoder) - require.True(t, ok) + must.True(t, ok) ctx := t.Context() res := httptest.NewRecorder() @@ -109,7 +109,7 @@ func TestServerEncoderDecoder_encodeResponse(T *testing.T) { encoderDecoder.encodeResponse(ctx, res, ex, http.StatusOK) actual := res.Body.String() - assert.NotEmpty(t, actual) + test.NotEq(t, "", actual) }) T.Run("defaults to JSON", func(t *testing.T) { @@ -117,13 +117,13 @@ func TestServerEncoderDecoder_encodeResponse(T *testing.T) { expectation := "name" ex := &example{Name: expectation} encoderDecoder, ok := ProvideServerEncoderDecoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), ContentTypeJSON).(*serverEncoderDecoder) - require.True(t, ok) + must.True(t, ok) ctx := t.Context() res := httptest.NewRecorder() encoderDecoder.encodeResponse(ctx, res, ex, http.StatusOK) - assert.Equal(t, res.Body.String(), fmt.Sprintf("{%q:%q}\n", "name", ex.Name)) + test.EqOp(t, fmt.Sprintf("{%q:%q}\n", "name", ex.Name), res.Body.String()) }) T.Run("with broken structure", func(t *testing.T) { @@ -131,13 +131,13 @@ func TestServerEncoderDecoder_encodeResponse(T *testing.T) { expectation := "name" ex := &broken{Name: json.Number(expectation)} encoderDecoder, ok := ProvideServerEncoderDecoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), ContentTypeJSON).(*serverEncoderDecoder) - require.True(t, ok) + must.True(t, ok) ctx := t.Context() res := httptest.NewRecorder() encoderDecoder.encodeResponse(ctx, res, ex, http.StatusOK) - assert.Empty(t, res.Body.String()) + test.EqOp(t, "", res.Body.String()) }) } @@ -154,7 +154,7 @@ func TestServerEncoderDecoder_MustEncodeJSON(T *testing.T) { ` actual := string(encoderDecoder.MustEncodeJSON(ctx, &example{Name: t.Name()})) - assert.Equal(t, expected, actual) + test.EqOp(t, expected, actual) }) T.Run("with panic", func(t *testing.T) { @@ -164,7 +164,7 @@ func TestServerEncoderDecoder_MustEncodeJSON(T *testing.T) { encoderDecoder := ProvideServerEncoderDecoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), ContentTypeJSON) defer func() { - assert.NotNil(t, recover()) + test.NotNil(t, recover()) }() encoderDecoder.MustEncodeJSON(ctx, &broken{Name: json.Number(t.Name())}) @@ -205,7 +205,7 @@ func TestServerEncoderDecoder_MustEncode(T *testing.T) { actual := string(encoderDecoder.MustEncode(ctx, &example{Name: t.Name()})) - assert.Equal(t, tc.expected, actual) + test.EqOp(t, tc.expected, actual) }) } @@ -216,7 +216,7 @@ func TestServerEncoderDecoder_MustEncode(T *testing.T) { encoderDecoder := ProvideServerEncoderDecoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), ContentTypeEmoji) actual := string(encoderDecoder.MustEncode(ctx, &example{Name: t.Name()})) - assert.NotEmpty(t, actual) + test.NotEq(t, "", actual) }) T.Run("with broken struct", func(t *testing.T) { @@ -224,10 +224,10 @@ func TestServerEncoderDecoder_MustEncode(T *testing.T) { ctx := t.Context() encoderDecoder, ok := ProvideServerEncoderDecoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), ContentTypeJSON).(*serverEncoderDecoder) - require.True(t, ok) + must.True(t, ok) defer func() { - assert.NotNil(t, recover()) + test.NotNil(t, recover()) }() encoderDecoder.MustEncode(ctx, &broken{Name: json.Number(t.Name())}) @@ -249,8 +249,8 @@ func TestServerEncoderDecoder_EncodeResponseWithStatus(T *testing.T) { expected := 666 encoderDecoder.EncodeResponseWithStatus(ctx, res, ex, expected) - assert.Equal(t, expected, res.Code, "expected code to be %d, but got %d", expected, res.Code) - assert.Equal(t, res.Body.String(), fmt.Sprintf("{%q:%q}\n", "name", ex.Name)) + test.EqOp(t, expected, res.Code, test.Sprintf("expected code to be %d, but got %d", expected, res.Code)) + test.EqOp(t, fmt.Sprintf("{%q:%q}\n", "name", ex.Name), res.Body.String()) }) } @@ -299,7 +299,7 @@ func TestServerEncoderDecoder_DecodeRequest(T *testing.T) { encoderDecoder := ProvideServerEncoderDecoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), tc.contentType) bs, err := tc.marshaller(e) - require.NoError(t, err) + must.NoError(t, err) req, err := http.NewRequestWithContext( ctx, @@ -307,12 +307,12 @@ func TestServerEncoderDecoder_DecodeRequest(T *testing.T) { "https://whatever.whocares.gov", bytes.NewReader(bs), ) - require.NoError(t, err) + must.NoError(t, err) req.Header.Set(ContentTypeHeaderKey, ContentTypeToString(tc.contentType)) var x example - assert.NoError(t, encoderDecoder.DecodeRequest(ctx, req, &x)) - assert.Equal(t, x.Name, e.Name) + test.NoError(t, encoderDecoder.DecodeRequest(ctx, req, &x)) + test.EqOp(t, e.Name, x.Name) }) } } @@ -355,9 +355,9 @@ func Test_serverEncoderDecoder_DecodeBytes(T *testing.T) { encoderDecoder := ProvideServerEncoderDecoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), tc.contentType) var dest *example - assert.NoError(t, encoderDecoder.DecodeBytes(ctx, tc.data, &dest)) + test.NoError(t, encoderDecoder.DecodeBytes(ctx, tc.data, &dest)) - assert.Equal(t, goodDataExpectation, dest) + test.Eq(t, goodDataExpectation, dest) }) } } @@ -370,14 +370,14 @@ func TestServerEncoderDecoder_RespondWithData(T *testing.T) { ctx := t.Context() encoderDecoder, ok := ProvideServerEncoderDecoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), ContentTypeJSON).(*serverEncoderDecoder) - require.True(t, ok) + must.True(t, ok) res := httptest.NewRecorder() encoderDecoder.RespondWithData(ctx, res, &example{Name: t.Name()}) - assert.Equal(t, http.StatusOK, res.Code) - assert.NotEmpty(t, res.Body.String()) + test.EqOp(t, http.StatusOK, res.Code) + test.NotEq(t, "", res.Body.String()) }) } @@ -390,7 +390,7 @@ func Test_tomlDecoder_Decode(T *testing.T) { d := newTomlDecoder(&errReader{}) var dest example - assert.Error(t, d.Decode(&dest)) + test.Error(t, d.Decode(&dest)) }) } @@ -401,14 +401,14 @@ func Test_emojiEncoder_Encode(T *testing.T) { t.Parallel() enc := newEmojiEncoder(&bytes.Buffer{}) - assert.Error(t, enc.Encode(make(chan int))) + test.Error(t, enc.Encode(make(chan int))) }) T.Run("with write error", func(t *testing.T) { t.Parallel() enc := newEmojiEncoder(&errWriter{}) - assert.Error(t, enc.Encode(&example{Name: "test"})) + test.Error(t, enc.Encode(&example{Name: "test"})) }) } @@ -421,7 +421,7 @@ func Test_emojiDecoder_Decode(T *testing.T) { d := newEmojiDecoder(&errReader{}) var dest example - assert.Error(t, d.Decode(&dest)) + test.Error(t, d.Decode(&dest)) }) } @@ -435,14 +435,14 @@ func TestServerEncoderDecoder_DecodeRequest_bodyCloseError(T *testing.T) { encoderDecoder := ProvideServerEncoderDecoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), ContentTypeJSON) data, err := json.Marshal(&example{Name: "test"}) - require.NoError(t, err) + must.NoError(t, err) req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://whatever.whocares.gov", &errorCloser{Reader: bytes.NewReader(data)}) - require.NoError(t, err) + must.NoError(t, err) req.Header.Set(ContentTypeHeaderKey, contentTypeJSON) var dest example - assert.NoError(t, encoderDecoder.DecodeRequest(ctx, req, &dest)) - assert.Equal(t, "test", dest.Name) + test.NoError(t, encoderDecoder.DecodeRequest(ctx, req, &dest)) + test.EqOp(t, "test", dest.Name) }) } diff --git a/encoding/utils_test.go b/encoding/utils_test.go index 15c3964..a9a912d 100644 --- a/encoding/utils_test.go +++ b/encoding/utils_test.go @@ -5,8 +5,8 @@ import ( "io" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestDecode(T *testing.T) { @@ -16,21 +16,21 @@ func TestDecode(T *testing.T) { t.Parallel() var dest example - assert.NoError(t, Decode([]byte(`{"name":"test"}`), nil, &dest)) + test.NoError(t, Decode([]byte(`{"name":"test"}`), nil, &dest)) }) T.Run("with explicit content type", func(t *testing.T) { t.Parallel() var dest example - assert.NoError(t, Decode([]byte(`test`), ContentTypeXML, &dest)) + test.NoError(t, Decode([]byte(`test`), ContentTypeXML, &dest)) }) T.Run("with invalid data", func(t *testing.T) { t.Parallel() var dest example - assert.Error(t, Decode([]byte(`{invalid`), nil, &dest)) + test.Error(t, Decode([]byte(`{invalid`), nil, &dest)) }) } @@ -41,21 +41,21 @@ func TestMustEncode(T *testing.T) { t.Parallel() result := MustEncode(&example{Name: t.Name()}, nil) - assert.NotEmpty(t, result) + test.SliceNotEmpty(t, result) }) T.Run("with explicit content type", func(t *testing.T) { t.Parallel() result := MustEncode(&example{Name: t.Name()}, ContentTypeXML) - assert.NotEmpty(t, result) + test.SliceNotEmpty(t, result) }) T.Run("panics with un-encodable data", func(t *testing.T) { t.Parallel() defer func() { - assert.NotNil(t, recover()) + test.NotNil(t, recover()) }() MustEncode(&broken{Name: json.Number(t.Name())}, nil) @@ -83,7 +83,7 @@ func TestMustDecode(T *testing.T) { t.Parallel() defer func() { - assert.NotNil(t, recover()) + test.NotNil(t, recover()) }() var dest example @@ -98,7 +98,7 @@ func TestMustEncodeJSON(T *testing.T) { t.Parallel() result := MustEncodeJSON(&example{Name: t.Name()}) - assert.NotEmpty(t, result) + test.SliceNotEmpty(t, result) }) } @@ -109,14 +109,14 @@ func TestDecodeJSON(T *testing.T) { t.Parallel() var dest example - assert.NoError(t, DecodeJSON([]byte(`{"name":"test"}`), &dest)) + test.NoError(t, DecodeJSON([]byte(`{"name":"test"}`), &dest)) }) T.Run("with invalid data", func(t *testing.T) { t.Parallel() var dest example - assert.Error(t, DecodeJSON([]byte(`{invalid`), &dest)) + test.Error(t, DecodeJSON([]byte(`{invalid`), &dest)) }) } @@ -138,10 +138,10 @@ func TestMustJSONIntoReader(T *testing.T) { t.Parallel() reader := MustJSONIntoReader(&example{Name: t.Name()}) - require.NotNil(t, reader) + must.NotNil(t, reader) data, err := io.ReadAll(reader) - require.NoError(t, err) - assert.NotEmpty(t, data) + must.NoError(t, err) + test.SliceNotEmpty(t, data) }) } diff --git a/errors/errors_test.go b/errors/errors_test.go index 9f91026..5dad694 100644 --- a/errors/errors_test.go +++ b/errors/errors_test.go @@ -5,8 +5,8 @@ import ( "fmt" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestSentinelErrors(T *testing.T) { @@ -14,39 +14,39 @@ func TestSentinelErrors(T *testing.T) { T.Run("ErrNilInputParameter", func(t *testing.T) { t.Parallel() - assert.NotNil(t, ErrNilInputParameter) - assert.Contains(t, ErrNilInputParameter.Error(), "nil") + test.NotNil(t, ErrNilInputParameter) + test.StrContains(t, ErrNilInputParameter.Error(), "nil") }) T.Run("ErrEmptyInputParameter", func(t *testing.T) { t.Parallel() - assert.NotNil(t, ErrEmptyInputParameter) - assert.Contains(t, ErrEmptyInputParameter.Error(), "empty") + test.NotNil(t, ErrEmptyInputParameter) + test.StrContains(t, ErrEmptyInputParameter.Error(), "empty") }) T.Run("ErrNilInputProvided", func(t *testing.T) { t.Parallel() - assert.NotNil(t, ErrNilInputProvided) - assert.Contains(t, ErrNilInputProvided.Error(), "nil input") + test.NotNil(t, ErrNilInputProvided) + test.StrContains(t, ErrNilInputProvided.Error(), "nil input") }) T.Run("ErrInvalidIDProvided", func(t *testing.T) { t.Parallel() - assert.NotNil(t, ErrInvalidIDProvided) - assert.Contains(t, ErrInvalidIDProvided.Error(), "ID") + test.NotNil(t, ErrInvalidIDProvided) + test.StrContains(t, ErrInvalidIDProvided.Error(), "ID") }) T.Run("ErrEmptyInputProvided", func(t *testing.T) { t.Parallel() - assert.NotNil(t, ErrEmptyInputProvided) - assert.Contains(t, ErrEmptyInputProvided.Error(), "empty") + test.NotNil(t, ErrEmptyInputProvided) + test.StrContains(t, ErrEmptyInputProvided.Error(), "empty") }) T.Run("sentinels are distinct", func(t *testing.T) { t.Parallel() - assert.False(t, errors.Is(ErrNilInputParameter, ErrEmptyInputParameter)) - assert.False(t, errors.Is(ErrNilInputProvided, ErrInvalidIDProvided)) - assert.False(t, errors.Is(ErrEmptyInputProvided, ErrNilInputProvided)) + test.False(t, errors.Is(ErrNilInputParameter, ErrEmptyInputParameter)) + test.False(t, errors.Is(ErrNilInputProvided, ErrInvalidIDProvided)) + test.False(t, errors.Is(ErrEmptyInputProvided, ErrNilInputProvided)) }) } @@ -56,8 +56,8 @@ func TestNew(T *testing.T) { T.Run("creates error with message", func(t *testing.T) { t.Parallel() err := New("test error") - require.Error(t, err) - assert.Equal(t, "test error", err.Error()) + must.Error(t, err) + test.EqError(t, err, "test error") }) } @@ -67,9 +67,9 @@ func TestNewf(T *testing.T) { T.Run("creates formatted error", func(t *testing.T) { t.Parallel() err := Newf("error %d: %s", 42, "details") - require.Error(t, err) - assert.Contains(t, err.Error(), "42") - assert.Contains(t, err.Error(), "details") + must.Error(t, err) + test.StrContains(t, err.Error(), "42") + test.StrContains(t, err.Error(), "details") }) } @@ -79,8 +79,8 @@ func TestErrorf(T *testing.T) { T.Run("creates formatted error", func(t *testing.T) { t.Parallel() err := Errorf("something %s", "failed") - require.Error(t, err) - assert.Contains(t, err.Error(), "something failed") + must.Error(t, err) + test.StrContains(t, err.Error(), "something failed") }) } @@ -91,14 +91,14 @@ func TestWrap(T *testing.T) { t.Parallel() inner := fmt.Errorf("inner") wrapped := Wrap(inner, "outer") - require.Error(t, wrapped) - assert.True(t, errors.Is(wrapped, inner)) - assert.Contains(t, wrapped.Error(), "outer") + must.Error(t, wrapped) + test.ErrorIs(t, wrapped, inner) + test.StrContains(t, wrapped.Error(), "outer") }) T.Run("nil error returns nil", func(t *testing.T) { t.Parallel() - assert.Nil(t, Wrap(nil, "outer")) + test.Nil(t, Wrap(nil, "outer")) }) } @@ -109,8 +109,8 @@ func TestWrapf(T *testing.T) { t.Parallel() inner := fmt.Errorf("inner") wrapped := Wrapf(inner, "outer %d", 1) - require.Error(t, wrapped) - assert.True(t, errors.Is(wrapped, inner)) - assert.Contains(t, wrapped.Error(), "outer 1") + must.Error(t, wrapped) + test.ErrorIs(t, wrapped, inner) + test.StrContains(t, wrapped.Error(), "outer 1") }) } diff --git a/errors/grpc/grpc_interceptor_test.go b/errors/grpc/grpc_interceptor_test.go index 655cafe..93990e7 100644 --- a/errors/grpc/grpc_interceptor_test.go +++ b/errors/grpc/grpc_interceptor_test.go @@ -8,8 +8,8 @@ import ( platformerrors "github.com/verygoodsoftwarenotvirus/platform/v5/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -21,14 +21,14 @@ func TestDecodeErrorFromStatus(T *testing.T) { T.Run("nil error returns nil", func(t *testing.T) { t.Parallel() - assert.Nil(t, DecodeErrorFromStatus(context.Background(), nil)) + test.Nil(t, DecodeErrorFromStatus(context.Background(), nil)) }) T.Run("non-status error returned as-is", func(t *testing.T) { t.Parallel() original := errors.New("plain error") result := DecodeErrorFromStatus(context.Background(), original) - assert.Equal(t, original, result) + test.ErrorIs(t, result, original) }) T.Run("status error without details returns original", func(t *testing.T) { @@ -36,7 +36,7 @@ func TestDecodeErrorFromStatus(T *testing.T) { st := status.New(codes.NotFound, "not found") err := st.Err() result := DecodeErrorFromStatus(context.Background(), err) - assert.Error(t, result) + test.Error(t, result) }) T.Run("round-trips a platform sentinel error through encode/decode", func(t *testing.T) { @@ -47,17 +47,17 @@ func TestDecodeErrorFromStatus(T *testing.T) { // Encode using the interceptor helper detail := encodeErrorToDetails(ctx, original) - require.NotNil(t, detail) + must.NotNil(t, detail) // Build a status with details st := status.New(codes.InvalidArgument, original.Error()) stWithDetails, err := st.WithDetails(detail) - require.NoError(t, err) + must.NoError(t, err) // Decode - the decoded error should contain the original message decoded := DecodeErrorFromStatus(ctx, stWithDetails.Err()) - require.Error(t, decoded) - assert.Contains(t, decoded.Error(), "nil") + must.Error(t, decoded) + test.StrContains(t, decoded.Error(), "nil") }) } @@ -67,22 +67,22 @@ func TestEncodeErrorToDetails(T *testing.T) { T.Run("encodes a platform error", func(t *testing.T) { t.Parallel() detail := encodeErrorToDetails(context.Background(), platformerrors.ErrNilInputParameter) - assert.NotNil(t, detail) - assert.Equal(t, encodedErrorTypeURL, detail.TypeUrl) + test.NotNil(t, detail) + test.EqOp(t, encodedErrorTypeURL, detail.TypeUrl) }) T.Run("encodes a wrapped error", func(t *testing.T) { t.Parallel() wrapped := platformerrors.Wrap(platformerrors.ErrInvalidIDProvided, "context") detail := encodeErrorToDetails(context.Background(), wrapped) - assert.NotNil(t, detail) + test.NotNil(t, detail) }) T.Run("encodes a simple error", func(t *testing.T) { t.Parallel() detail := encodeErrorToDetails(context.Background(), errors.New("simple")) // Even simple errors should encode (cockroachdb/errors handles them) - assert.NotNil(t, detail) + test.NotNil(t, detail) }) } @@ -98,8 +98,8 @@ func TestUnaryErrorEncodingInterceptor(T *testing.T) { } resp, err := interceptor(context.Background(), "req", &grpc.UnaryServerInfo{}, handler) - assert.NoError(t, err) - assert.Equal(t, "ok", resp) + test.NoError(t, err) + test.Eq[any](t, "ok", resp) }) T.Run("encodes platform error into status details", func(t *testing.T) { @@ -111,13 +111,13 @@ func TestUnaryErrorEncodingInterceptor(T *testing.T) { } resp, err := interceptor(context.Background(), "req", &grpc.UnaryServerInfo{}, handler) - assert.Nil(t, resp) - require.Error(t, err) + test.Nil(t, resp) + must.Error(t, err) st, ok := status.FromError(err) - require.True(t, ok) - assert.Equal(t, codes.InvalidArgument, st.Code()) - assert.NotEmpty(t, st.Details()) + must.True(t, ok) + test.EqOp(t, codes.InvalidArgument, st.Code()) + test.SliceNotEmpty(t, st.Details()) }) T.Run("preserves existing status code for known errors", func(t *testing.T) { @@ -129,11 +129,11 @@ func TestUnaryErrorEncodingInterceptor(T *testing.T) { } _, err := interceptor(context.Background(), "req", &grpc.UnaryServerInfo{}, handler) - require.Error(t, err) + must.Error(t, err) st, ok := status.FromError(err) - require.True(t, ok) - assert.Equal(t, codes.NotFound, st.Code()) + must.True(t, ok) + test.EqOp(t, codes.NotFound, st.Code()) }) T.Run("handler returning status error preserves message", func(t *testing.T) { @@ -145,11 +145,11 @@ func TestUnaryErrorEncodingInterceptor(T *testing.T) { } _, err := interceptor(context.Background(), "req", &grpc.UnaryServerInfo{}, handler) - require.Error(t, err) + must.Error(t, err) st, ok := status.FromError(err) - require.True(t, ok) - assert.Equal(t, "custom message", st.Message()) + must.True(t, ok) + test.EqOp(t, "custom message", st.Message()) }) T.Run("unknown error uses codes.Unknown", func(t *testing.T) { @@ -161,11 +161,11 @@ func TestUnaryErrorEncodingInterceptor(T *testing.T) { } _, err := interceptor(context.Background(), "req", &grpc.UnaryServerInfo{}, handler) - require.Error(t, err) + must.Error(t, err) st, ok := status.FromError(err) - require.True(t, ok) - assert.Equal(t, codes.Unknown, st.Code()) + must.True(t, ok) + test.EqOp(t, codes.Unknown, st.Code()) }) } @@ -194,7 +194,7 @@ func TestStreamErrorEncodingInterceptor(T *testing.T) { ss := &mockServerStream{ctx: context.Background()} err := interceptor(nil, ss, &grpc.StreamServerInfo{}, handler) - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("encodes platform error into status details", func(t *testing.T) { @@ -207,12 +207,12 @@ func TestStreamErrorEncodingInterceptor(T *testing.T) { ss := &mockServerStream{ctx: context.Background()} err := interceptor(nil, ss, &grpc.StreamServerInfo{}, handler) - require.Error(t, err) + must.Error(t, err) st, ok := status.FromError(err) - require.True(t, ok) - assert.Equal(t, codes.InvalidArgument, st.Code()) - assert.NotEmpty(t, st.Details()) + must.True(t, ok) + test.EqOp(t, codes.InvalidArgument, st.Code()) + test.SliceNotEmpty(t, st.Details()) }) T.Run("unknown error uses codes.Unknown", func(t *testing.T) { @@ -225,11 +225,11 @@ func TestStreamErrorEncodingInterceptor(T *testing.T) { ss := &mockServerStream{ctx: context.Background()} err := interceptor(nil, ss, &grpc.StreamServerInfo{}, handler) - require.Error(t, err) + must.Error(t, err) st, ok := status.FromError(err) - require.True(t, ok) - assert.Equal(t, codes.Unknown, st.Code()) + must.True(t, ok) + test.EqOp(t, codes.Unknown, st.Code()) }) T.Run("handler returning status error preserves message", func(t *testing.T) { @@ -242,10 +242,10 @@ func TestStreamErrorEncodingInterceptor(T *testing.T) { ss := &mockServerStream{ctx: context.Background()} err := interceptor(nil, ss, &grpc.StreamServerInfo{}, handler) - require.Error(t, err) + must.Error(t, err) st, ok := status.FromError(err) - require.True(t, ok) - assert.Equal(t, "not authed", st.Message()) + must.True(t, ok) + test.EqOp(t, "not authed", st.Message()) }) } diff --git a/errors/grpc/grpc_test.go b/errors/grpc/grpc_test.go index 3b59e8f..ff60514 100644 --- a/errors/grpc/grpc_test.go +++ b/errors/grpc/grpc_test.go @@ -9,7 +9,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/database" platformerrors "github.com/verygoodsoftwarenotvirus/platform/v5/errors" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" "google.golang.org/grpc/codes" ) @@ -19,69 +19,69 @@ func TestPlatformMapper_Map(T *testing.T) { T.Run("nil error returns ok=false", func(t *testing.T) { t.Parallel() _, ok := PlatformMapper.Map(nil) - assert.False(t, ok) + test.False(t, ok) }) T.Run("ErrUserAlreadyExists maps to AlreadyExists", func(t *testing.T) { t.Parallel() code, ok := PlatformMapper.Map(database.ErrUserAlreadyExists) - assert.True(t, ok) - assert.Equal(t, codes.AlreadyExists, code) + test.True(t, ok) + test.EqOp(t, codes.AlreadyExists, code) }) T.Run("sql.ErrNoRows maps to NotFound", func(t *testing.T) { t.Parallel() code, ok := PlatformMapper.Map(sql.ErrNoRows) - assert.True(t, ok) - assert.Equal(t, codes.NotFound, code) + test.True(t, ok) + test.EqOp(t, codes.NotFound, code) }) T.Run("ErrCircuitBroken maps to Unavailable", func(t *testing.T) { t.Parallel() code, ok := PlatformMapper.Map(circuitbreaking.ErrCircuitBroken) - assert.True(t, ok) - assert.Equal(t, codes.Unavailable, code) + test.True(t, ok) + test.EqOp(t, codes.Unavailable, code) }) T.Run("ErrNilInputParameter maps to InvalidArgument", func(t *testing.T) { t.Parallel() code, ok := PlatformMapper.Map(platformerrors.ErrNilInputParameter) - assert.True(t, ok) - assert.Equal(t, codes.InvalidArgument, code) + test.True(t, ok) + test.EqOp(t, codes.InvalidArgument, code) }) T.Run("ErrEmptyInputParameter maps to InvalidArgument", func(t *testing.T) { t.Parallel() code, ok := PlatformMapper.Map(platformerrors.ErrEmptyInputParameter) - assert.True(t, ok) - assert.Equal(t, codes.InvalidArgument, code) + test.True(t, ok) + test.EqOp(t, codes.InvalidArgument, code) }) T.Run("ErrNilInputProvided maps to InvalidArgument", func(t *testing.T) { t.Parallel() code, ok := PlatformMapper.Map(platformerrors.ErrNilInputProvided) - assert.True(t, ok) - assert.Equal(t, codes.InvalidArgument, code) + test.True(t, ok) + test.EqOp(t, codes.InvalidArgument, code) }) T.Run("ErrInvalidIDProvided maps to InvalidArgument", func(t *testing.T) { t.Parallel() code, ok := PlatformMapper.Map(platformerrors.ErrInvalidIDProvided) - assert.True(t, ok) - assert.Equal(t, codes.InvalidArgument, code) + test.True(t, ok) + test.EqOp(t, codes.InvalidArgument, code) }) T.Run("ErrEmptyInputProvided maps to InvalidArgument", func(t *testing.T) { t.Parallel() code, ok := PlatformMapper.Map(platformerrors.ErrEmptyInputProvided) - assert.True(t, ok) - assert.Equal(t, codes.InvalidArgument, code) + test.True(t, ok) + test.EqOp(t, codes.InvalidArgument, code) }) T.Run("unknown error returns ok=false", func(t *testing.T) { t.Parallel() _, ok := PlatformMapper.Map(errors.New("nope")) - assert.False(t, ok) + test.False(t, ok) }) } @@ -90,12 +90,12 @@ func TestMapToGRPC(T *testing.T) { T.Run("nil error returns OK", func(t *testing.T) { t.Parallel() - assert.Equal(t, codes.OK, MapToGRPC(nil, codes.Internal)) + test.EqOp(t, codes.OK, MapToGRPC(nil, codes.Internal)) }) T.Run("known platform error uses PlatformMapper", func(t *testing.T) { t.Parallel() - assert.Equal(t, codes.NotFound, MapToGRPC(sql.ErrNoRows, codes.Internal)) + test.EqOp(t, codes.NotFound, MapToGRPC(sql.ErrNoRows, codes.Internal)) }) T.Run("unknown error with no domain mappers returns default", func(t *testing.T) { @@ -104,7 +104,7 @@ func TestMapToGRPC(T *testing.T) { // so we test PlatformMapper directly for "unknown returns default" behavior above. code := MapToGRPC(errors.New("truly unknown error that no mapper handles"), codes.Aborted) // If a domain mapper catches it, that's fine; we just verify no panic. - assert.NotEqual(t, codes.OK, code) + test.NotEq(t, codes.OK, code) }) T.Run("domain mapper is consulted when platform mapper does not match", func(t *testing.T) { @@ -116,8 +116,8 @@ func TestMapToGRPC(T *testing.T) { // so we test the mapper interface directly to verify the flow. mapper := testGRPCMapper{err: customErr, code: codes.PermissionDenied} code, ok := mapper.Map(customErr) - assert.True(t, ok) - assert.Equal(t, codes.PermissionDenied, code) + test.True(t, ok) + test.EqOp(t, codes.PermissionDenied, code) }) } @@ -147,7 +147,7 @@ func TestRegisterGRPCErrorMapper(T *testing.T) { // After registration, MapToGRPC should find it code := MapToGRPC(customErr, codes.Internal) - assert.Equal(t, codes.ResourceExhausted, code) + test.EqOp(t, codes.ResourceExhausted, code) }) } @@ -158,7 +158,7 @@ func TestPrepareAndLogGRPCStatus(T *testing.T) { t.Parallel() err := PrepareAndLogGRPCStatus(sql.ErrNoRows, nil, nil, codes.Internal, "fetching thing %s", "abc") - assert.Error(t, err) + test.Error(t, err) }) T.Run("with nil error", func(t *testing.T) { @@ -166,13 +166,13 @@ func TestPrepareAndLogGRPCStatus(T *testing.T) { err := PrepareAndLogGRPCStatus(nil, nil, nil, codes.Internal, "something") // nil error maps to codes.OK, which may produce nil or a status with OK - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("with unknown error uses default code", func(t *testing.T) { t.Parallel() err := PrepareAndLogGRPCStatus(errors.New("unknown"), nil, nil, codes.DataLoss, "oops") - assert.Error(t, err) + test.Error(t, err) }) } diff --git a/errors/http/map_test.go b/errors/http/map_test.go index 49720dd..66bd744 100644 --- a/errors/http/map_test.go +++ b/errors/http/map_test.go @@ -10,7 +10,7 @@ import ( platformerrors "github.com/verygoodsoftwarenotvirus/platform/v5/errors" "github.com/verygoodsoftwarenotvirus/platform/v5/types" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestPlatformMapper_Map(T *testing.T) { @@ -19,72 +19,72 @@ func TestPlatformMapper_Map(T *testing.T) { T.Run("nil error returns ok=false", func(t *testing.T) { t.Parallel() _, _, ok := PlatformMapper.Map(nil) - assert.False(t, ok) + test.False(t, ok) }) T.Run("sql.ErrNoRows maps to ErrDataNotFound", func(t *testing.T) { t.Parallel() code, msg, ok := PlatformMapper.Map(sql.ErrNoRows) - assert.True(t, ok) - assert.Equal(t, types.ErrDataNotFound, code) - assert.Equal(t, "data not found", msg) + test.True(t, ok) + test.EqOp(t, types.ErrDataNotFound, code) + test.EqOp(t, "data not found", msg) }) T.Run("ErrUserAlreadyExists maps to ErrValidatingRequestInput", func(t *testing.T) { t.Parallel() code, msg, ok := PlatformMapper.Map(database.ErrUserAlreadyExists) - assert.True(t, ok) - assert.Equal(t, types.ErrValidatingRequestInput, code) - assert.Equal(t, "user already exists", msg) + test.True(t, ok) + test.EqOp(t, types.ErrValidatingRequestInput, code) + test.EqOp(t, "user already exists", msg) }) T.Run("ErrCircuitBroken maps to ErrCircuitBroken", func(t *testing.T) { t.Parallel() code, msg, ok := PlatformMapper.Map(circuitbreaking.ErrCircuitBroken) - assert.True(t, ok) - assert.Equal(t, types.ErrCircuitBroken, code) - assert.Equal(t, "service temporarily unavailable", msg) + test.True(t, ok) + test.EqOp(t, types.ErrCircuitBroken, code) + test.EqOp(t, "service temporarily unavailable", msg) }) T.Run("ErrNilInputParameter maps to ErrValidatingRequestInput", func(t *testing.T) { t.Parallel() code, _, ok := PlatformMapper.Map(platformerrors.ErrNilInputParameter) - assert.True(t, ok) - assert.Equal(t, types.ErrValidatingRequestInput, code) + test.True(t, ok) + test.EqOp(t, types.ErrValidatingRequestInput, code) }) T.Run("ErrEmptyInputParameter maps to ErrValidatingRequestInput", func(t *testing.T) { t.Parallel() code, _, ok := PlatformMapper.Map(platformerrors.ErrEmptyInputParameter) - assert.True(t, ok) - assert.Equal(t, types.ErrValidatingRequestInput, code) + test.True(t, ok) + test.EqOp(t, types.ErrValidatingRequestInput, code) }) T.Run("ErrNilInputProvided maps to ErrValidatingRequestInput", func(t *testing.T) { t.Parallel() code, _, ok := PlatformMapper.Map(platformerrors.ErrNilInputProvided) - assert.True(t, ok) - assert.Equal(t, types.ErrValidatingRequestInput, code) + test.True(t, ok) + test.EqOp(t, types.ErrValidatingRequestInput, code) }) T.Run("ErrInvalidIDProvided maps to ErrValidatingRequestInput", func(t *testing.T) { t.Parallel() code, _, ok := PlatformMapper.Map(platformerrors.ErrInvalidIDProvided) - assert.True(t, ok) - assert.Equal(t, types.ErrValidatingRequestInput, code) + test.True(t, ok) + test.EqOp(t, types.ErrValidatingRequestInput, code) }) T.Run("ErrEmptyInputProvided maps to ErrValidatingRequestInput", func(t *testing.T) { t.Parallel() code, _, ok := PlatformMapper.Map(platformerrors.ErrEmptyInputProvided) - assert.True(t, ok) - assert.Equal(t, types.ErrValidatingRequestInput, code) + test.True(t, ok) + test.EqOp(t, types.ErrValidatingRequestInput, code) }) T.Run("unknown error returns ok=false", func(t *testing.T) { t.Parallel() _, _, ok := PlatformMapper.Map(errors.New("nope")) - assert.False(t, ok) + test.False(t, ok) }) } @@ -94,43 +94,43 @@ func TestToAPIError(T *testing.T) { T.Run("nil error", func(t *testing.T) { t.Parallel() code, msg := ToAPIError(nil) - assert.Equal(t, types.ErrNothingSpecific, code) - assert.Empty(t, msg) + test.EqOp(t, types.ErrNothingSpecific, code) + test.EqOp(t, "", msg) }) T.Run("known platform error uses PlatformMapper", func(t *testing.T) { t.Parallel() code, msg := ToAPIError(sql.ErrNoRows) - assert.Equal(t, types.ErrDataNotFound, code) - assert.Equal(t, "data not found", msg) + test.EqOp(t, types.ErrDataNotFound, code) + test.EqOp(t, "data not found", msg) }) T.Run("unknown error returns fallback", func(t *testing.T) { t.Parallel() code, msg := ToAPIError(errors.New("totally unknown error that no mapper handles")) - assert.Equal(t, types.ErrTalkingToDatabase, code) - assert.Equal(t, "an error occurred", msg) + test.EqOp(t, types.ErrTalkingToDatabase, code) + test.EqOp(t, "an error occurred", msg) }) T.Run("circuit broken error", func(t *testing.T) { t.Parallel() code, msg := ToAPIError(circuitbreaking.ErrCircuitBroken) - assert.Equal(t, types.ErrCircuitBroken, code) - assert.Equal(t, "service temporarily unavailable", msg) + test.EqOp(t, types.ErrCircuitBroken, code) + test.EqOp(t, "service temporarily unavailable", msg) }) T.Run("ErrNilInputParameter", func(t *testing.T) { t.Parallel() code, msg := ToAPIError(platformerrors.ErrNilInputParameter) - assert.Equal(t, types.ErrValidatingRequestInput, code) - assert.Equal(t, "invalid input", msg) + test.EqOp(t, types.ErrValidatingRequestInput, code) + test.EqOp(t, "invalid input", msg) }) T.Run("ErrUserAlreadyExists", func(t *testing.T) { t.Parallel() code, msg := ToAPIError(database.ErrUserAlreadyExists) - assert.Equal(t, types.ErrValidatingRequestInput, code) - assert.Equal(t, "user already exists", msg) + test.EqOp(t, types.ErrValidatingRequestInput, code) + test.EqOp(t, "user already exists", msg) }) } @@ -159,7 +159,7 @@ func TestRegisterHTTPErrorMapper(T *testing.T) { RegisterHTTPErrorMapper(mapper) code, msg := ToAPIError(customErr) - assert.Equal(t, types.ErrorCode("E_CUSTOM"), code) - assert.Equal(t, "custom message", msg) + test.EqOp(t, types.ErrorCode("E_CUSTOM"), code) + test.EqOp(t, "custom message", msg) }) } diff --git a/eventstream/config/config_test.go b/eventstream/config/config_test.go index a9ea5e3..b3e0660 100644 --- a/eventstream/config/config_test.go +++ b/eventstream/config/config_test.go @@ -5,8 +5,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -20,7 +20,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderSSE, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("WebSocket provider", func(t *testing.T) { @@ -31,7 +31,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderWebSocket, } - assert.Error(t, cfg.ValidateWithContext(ctx), "websocket provider requires websocket config") + test.Error(t, cfg.ValidateWithContext(ctx), test.Sprintf("websocket provider requires websocket config")) }) T.Run("invalid provider", func(t *testing.T) { @@ -42,7 +42,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: "invalid", } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } @@ -56,8 +56,8 @@ func TestProvideEventStreamUpgrader(T *testing.T) { Provider: ProviderSSE, }) - require.NoError(t, err) - assert.NotNil(t, upgrader) + must.NoError(t, err) + test.NotNil(t, upgrader) }) T.Run("WebSocket", func(t *testing.T) { @@ -67,8 +67,8 @@ func TestProvideEventStreamUpgrader(T *testing.T) { Provider: ProviderWebSocket, }) - require.NoError(t, err) - assert.NotNil(t, upgrader) + must.NoError(t, err) + test.NotNil(t, upgrader) }) T.Run("invalid provider", func(t *testing.T) { @@ -76,7 +76,7 @@ func TestProvideEventStreamUpgrader(T *testing.T) { _, err := ProvideEventStreamUpgrader(nil, tracing.NewNoopTracerProvider(), &Config{}) - assert.Error(t, err) + test.Error(t, err) }) } @@ -90,8 +90,8 @@ func TestProvideBidirectionalEventStreamUpgrader(T *testing.T) { Provider: ProviderSSE, }) - assert.Error(t, err) - assert.Contains(t, err.Error(), "SSE does not support bidirectional") + test.Error(t, err) + test.StrContains(t, err.Error(), "SSE does not support bidirectional") }) T.Run("WebSocket", func(t *testing.T) { @@ -101,8 +101,8 @@ func TestProvideBidirectionalEventStreamUpgrader(T *testing.T) { Provider: ProviderWebSocket, }) - require.NoError(t, err) - assert.NotNil(t, upgrader) + must.NoError(t, err) + test.NotNil(t, upgrader) }) T.Run("invalid provider", func(t *testing.T) { @@ -110,6 +110,6 @@ func TestProvideBidirectionalEventStreamUpgrader(T *testing.T) { _, err := ProvideBidirectionalEventStreamUpgrader(nil, tracing.NewNoopTracerProvider(), &Config{}) - assert.Error(t, err) + test.Error(t, err) }) } diff --git a/eventstream/config/do_test.go b/eventstream/config/do_test.go index 594f042..c533478 100644 --- a/eventstream/config/do_test.go +++ b/eventstream/config/do_test.go @@ -8,8 +8,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterEventStreamUpgrader(T *testing.T) { @@ -26,8 +26,8 @@ func TestRegisterEventStreamUpgrader(T *testing.T) { RegisterEventStreamUpgrader(i) upgrader, err := do.Invoke[eventstream.EventStreamUpgrader](i) - require.NoError(t, err) - assert.NotNil(t, upgrader) + must.NoError(t, err) + test.NotNil(t, upgrader) }) } @@ -45,7 +45,7 @@ func TestRegisterBidirectionalEventStreamUpgrader(T *testing.T) { RegisterBidirectionalEventStreamUpgrader(i) upgrader, err := do.Invoke[eventstream.BidirectionalEventStreamUpgrader](i) - require.NoError(t, err) - assert.NotNil(t, upgrader) + must.NoError(t, err) + test.NotNil(t, upgrader) }) } diff --git a/eventstream/manager_test.go b/eventstream/manager_test.go index 0c10675..8658e3d 100644 --- a/eventstream/manager_test.go +++ b/eventstream/manager_test.go @@ -3,14 +3,17 @@ package eventstream import ( "context" "encoding/json" + "errors" "sync" "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) +var errStub = errors.New("stub error") + // mockStream implements EventStream for testing. type mockStream struct { done chan struct{} @@ -61,12 +64,12 @@ func TestNewStreamManager(T *testing.T) { ctx := t.Context() m := NewStreamManager[EventStream](nil, nil) - require.NotNil(t, m) + must.NotNil(t, m) - assert.False(t, m.GroupHasStreams(ctx, "any")) - assert.Equal(t, 0, m.GetStreamCount(ctx, "any")) - assert.Nil(t, m.Get(ctx, "any", "any")) - assert.Empty(t, m.GetGroupStreams(ctx, "any")) + test.False(t, m.GroupHasStreams(ctx, "any")) + test.EqOp(t, 0, m.GetStreamCount(ctx, "any")) + test.Nil(t, m.Get(ctx, "any", "any")) + test.SliceEmpty(t, m.GetGroupStreams(ctx, "any")) }) } @@ -82,16 +85,16 @@ func TestStreamManager_Add_Get_Remove(T *testing.T) { m := NewStreamManager[EventStream](nil, nil) m.Add(ctx, "g1", "m1", stream) - assert.True(t, m.GroupHasStreams(ctx, "g1")) - assert.Equal(t, 1, m.GetStreamCount(ctx, "g1")) - assert.Equal(t, stream, m.Get(ctx, "g1", "m1")) - assert.Len(t, m.GetGroupStreams(ctx, "g1"), 1) + test.True(t, m.GroupHasStreams(ctx, "g1")) + test.EqOp(t, 1, m.GetStreamCount(ctx, "g1")) + test.True(t, EventStream(stream) == m.Get(ctx, "g1", "m1")) + test.SliceLen(t, 1, m.GetGroupStreams(ctx, "g1")) m.Remove(ctx, "g1", "m1") - assert.False(t, m.GroupHasStreams(ctx, "g1")) - assert.Equal(t, 0, m.GetStreamCount(ctx, "g1")) - assert.Nil(t, m.Get(ctx, "g1", "m1")) - assert.Empty(t, m.GetGroupStreams(ctx, "g1")) + test.False(t, m.GroupHasStreams(ctx, "g1")) + test.EqOp(t, 0, m.GetStreamCount(ctx, "g1")) + test.Nil(t, m.Get(ctx, "g1", "m1")) + test.SliceEmpty(t, m.GetGroupStreams(ctx, "g1")) }) } @@ -106,15 +109,15 @@ func TestStreamManager_Remove_empties_group(T *testing.T) { m := NewStreamManager[EventStream](nil, nil) m.Add(ctx, "g1", "m1", newMockStream()) m.Add(ctx, "g1", "m2", newMockStream()) - assert.Equal(t, 2, m.GetStreamCount(ctx, "g1")) + test.EqOp(t, 2, m.GetStreamCount(ctx, "g1")) m.Remove(ctx, "g1", "m1") - assert.Equal(t, 1, m.GetStreamCount(ctx, "g1")) - assert.NotNil(t, m.Get(ctx, "g1", "m2")) + test.EqOp(t, 1, m.GetStreamCount(ctx, "g1")) + test.NotNil(t, m.Get(ctx, "g1", "m2")) m.Remove(ctx, "g1", "m2") - assert.False(t, m.GroupHasStreams(ctx, "g1")) - assert.Equal(t, 0, m.GetStreamCount(ctx, "g1")) + test.False(t, m.GroupHasStreams(ctx, "g1")) + test.EqOp(t, 0, m.GetStreamCount(ctx, "g1")) }) } @@ -127,8 +130,8 @@ func TestStreamManager_Get_nonexistent(T *testing.T) { ctx := t.Context() m := NewStreamManager[EventStream](nil, nil) - assert.Nil(t, m.Get(ctx, "g1", "m1")) - assert.Nil(t, m.Get(ctx, "", "")) + test.Nil(t, m.Get(ctx, "g1", "m1")) + test.Nil(t, m.Get(ctx, "", "")) }) } @@ -152,10 +155,10 @@ func TestStreamManager_BroadcastToGroup(T *testing.T) { } m.BroadcastToGroup(ctx, "g1", event) - assert.Len(t, s1.sentEvents(), 1) - assert.Equal(t, "test", s1.sentEvents()[0].Type) - assert.Len(t, s2.sentEvents(), 1) - assert.Equal(t, "test", s2.sentEvents()[0].Type) + test.SliceLen(t, 1, s1.sentEvents()) + test.EqOp(t, "test", s1.sentEvents()[0].Type) + test.SliceLen(t, 1, s2.sentEvents()) + test.EqOp(t, "test", s2.sentEvents()[0].Type) }) T.Run("empty group", func(t *testing.T) { @@ -193,9 +196,9 @@ func TestStreamManager_BroadcastToGroupFiltered(T *testing.T) { return memberID == "m2" }) - assert.Empty(t, s1.sentEvents()) - assert.Len(t, s2.sentEvents(), 1) - assert.Equal(t, "filtered", s2.sentEvents()[0].Type) + test.SliceEmpty(t, s1.sentEvents()) + test.SliceLen(t, 1, s2.sentEvents()) + test.EqOp(t, "filtered", s2.sentEvents()[0].Type) }) T.Run("none match", func(t *testing.T) { @@ -209,7 +212,7 @@ func TestStreamManager_BroadcastToGroupFiltered(T *testing.T) { m.BroadcastToGroupFiltered(ctx, "g1", &Event{Type: "x"}, func(string) bool { return false }) - assert.Empty(t, s1.sentEvents()) + test.SliceEmpty(t, s1.sentEvents()) }) } @@ -229,11 +232,11 @@ func TestStreamManager_SendToMember(T *testing.T) { event := &Event{Type: "direct", Payload: json.RawMessage(`"hi"`)} err := m.SendToMember(ctx, "g1", "m1", event) - require.NoError(t, err) + must.NoError(t, err) - assert.Len(t, s1.sentEvents(), 1) - assert.Equal(t, "direct", s1.sentEvents()[0].Type) - assert.Empty(t, s2.sentEvents()) + test.SliceLen(t, 1, s1.sentEvents()) + test.EqOp(t, "direct", s1.sentEvents()[0].Type) + test.SliceEmpty(t, s2.sentEvents()) }) T.Run("nonexistent member returns nil", func(t *testing.T) { @@ -243,7 +246,7 @@ func TestStreamManager_SendToMember(T *testing.T) { m := NewStreamManager[EventStream](nil, nil) err := m.SendToMember(ctx, "g1", "m1", &Event{Type: "x"}) - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("nonexistent group returns nil", func(t *testing.T) { @@ -253,7 +256,7 @@ func TestStreamManager_SendToMember(T *testing.T) { m := NewStreamManager[EventStream](nil, nil) err := m.SendToMember(ctx, "g999", "m1", &Event{Type: "x"}) - assert.NoError(t, err) + test.NoError(t, err) }) } @@ -266,10 +269,10 @@ func TestStreamManager_GroupHasStreams(T *testing.T) { ctx := t.Context() m := NewStreamManager[EventStream](nil, nil) - assert.False(t, m.GroupHasStreams(ctx, "g1")) + test.False(t, m.GroupHasStreams(ctx, "g1")) m.Add(ctx, "g1", "m1", newMockStream()) - assert.True(t, m.GroupHasStreams(ctx, "g1")) + test.True(t, m.GroupHasStreams(ctx, "g1")) }) } @@ -282,13 +285,13 @@ func TestStreamManager_GetStreamCount(T *testing.T) { ctx := t.Context() m := NewStreamManager[EventStream](nil, nil) - assert.Equal(t, 0, m.GetStreamCount(ctx, "g1")) + test.EqOp(t, 0, m.GetStreamCount(ctx, "g1")) m.Add(ctx, "g1", "m1", newMockStream()) - assert.Equal(t, 1, m.GetStreamCount(ctx, "g1")) + test.EqOp(t, 1, m.GetStreamCount(ctx, "g1")) m.Add(ctx, "g1", "m2", newMockStream()) - assert.Equal(t, 2, m.GetStreamCount(ctx, "g1")) + test.EqOp(t, 2, m.GetStreamCount(ctx, "g1")) }) } @@ -321,7 +324,7 @@ func TestStreamManager_GetGroupStreams(T *testing.T) { m.Add(ctx, "g1", "m2", s2) streams := m.GetGroupStreams(ctx, "g1") - assert.Len(t, streams, 2) + test.SliceLen(t, 2, streams) }) T.Run("nonexistent group", func(t *testing.T) { @@ -331,7 +334,7 @@ func TestStreamManager_GetGroupStreams(T *testing.T) { m := NewStreamManager[EventStream](nil, nil) streams := m.GetGroupStreams(ctx, "g1") - assert.Empty(t, streams) + test.SliceEmpty(t, streams) }) } @@ -356,13 +359,13 @@ func TestStreamManager_BroadcastToGroup_with_failing_stream(T *testing.T) { // (we can't guarantee order due to map iteration, but we can check that // at least the non-failing stream received it) time.Sleep(10 * time.Millisecond) - assert.Len(t, s2.sentEvents(), 1) + test.SliceLen(t, 1, s2.sentEvents()) }) } // failingStream is a stream that always returns an error on Send. type failingStream struct{} -func (f *failingStream) Send(context.Context, *Event) error { return assert.AnError } +func (f *failingStream) Send(context.Context, *Event) error { return errStub } func (f *failingStream) Done() <-chan struct{} { return make(chan struct{}) } func (f *failingStream) Close() error { return nil } diff --git a/eventstream/sse/sse_test.go b/eventstream/sse/sse_test.go index 17174ba..9de0b57 100644 --- a/eventstream/sse/sse_test.go +++ b/eventstream/sse/sse_test.go @@ -11,8 +11,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/eventstream" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestNewUpgrader(T *testing.T) { @@ -22,7 +22,7 @@ func TestNewUpgrader(T *testing.T) { t.Parallel() u := NewUpgrader(tracing.NewNoopTracerProvider()) - assert.NotNil(t, u) + test.NotNil(t, u) }) } @@ -46,18 +46,18 @@ func TestUpgrader_UpgradeToEventStream(T *testing.T) { defer server.Close() req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, server.URL, http.NoBody) - require.NoError(t, err) + must.NoError(t, err) resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + must.NoError(t, err) defer resp.Body.Close() stream := <-streamReady - require.NotNil(t, stream) + must.NotNil(t, stream) defer stream.Close() - assert.Equal(t, "text/event-stream", resp.Header.Get("Content-Type")) - assert.Equal(t, "no-cache", resp.Header.Get("Cache-Control")) - assert.Equal(t, "keep-alive", resp.Header.Get("Connection")) + test.EqOp(t, "text/event-stream", resp.Header.Get("Content-Type")) + test.EqOp(t, "no-cache", resp.Header.Get("Cache-Control")) + test.EqOp(t, "keep-alive", resp.Header.Get("Connection")) }) T.Run("response writer does not support flushing", func(t *testing.T) { @@ -69,9 +69,9 @@ func TestUpgrader_UpgradeToEventStream(T *testing.T) { r := httptest.NewRequestWithContext(ctx, http.MethodGet, "/", http.NoBody) stream, err := u.UpgradeToEventStream(w, r) - assert.Nil(t, stream) - assert.Error(t, err) - assert.Contains(t, err.Error(), "streaming not supported") + test.Nil(t, stream) + test.Error(t, err) + test.StrContains(t, err.Error(), "streaming not supported") }) } @@ -95,34 +95,34 @@ func TestSSEStream_Send(T *testing.T) { defer server.Close() req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, server.URL, http.NoBody) - require.NoError(t, err) + must.NoError(t, err) resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + must.NoError(t, err) defer resp.Body.Close() stream := <-streamReady - require.NotNil(t, stream) + must.NotNil(t, stream) defer stream.Close() sendErr := stream.Send(t.Context(), &eventstream.Event{ Type: "test_event", Payload: json.RawMessage(`{"msg":"hello"}`), }) - require.NoError(t, sendErr) + must.NoError(t, sendErr) scanner := bufio.NewScanner(resp.Body) // Read "event: test_event" - require.True(t, scanner.Scan()) - assert.Equal(t, "event: test_event", scanner.Text()) + must.True(t, scanner.Scan()) + test.EqOp(t, "event: test_event", scanner.Text()) // Read "data: {\"msg\":\"hello\"}" - require.True(t, scanner.Scan()) - assert.Equal(t, `data: {"msg":"hello"}`, scanner.Text()) + must.True(t, scanner.Scan()) + test.EqOp(t, `data: {"msg":"hello"}`, scanner.Text()) // Read empty line (event separator) - require.True(t, scanner.Scan()) - assert.Equal(t, "", scanner.Text()) + must.True(t, scanner.Scan()) + test.EqOp(t, "", scanner.Text()) }) T.Run("event without type", func(t *testing.T) { @@ -142,29 +142,29 @@ func TestSSEStream_Send(T *testing.T) { defer server.Close() req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, server.URL, http.NoBody) - require.NoError(t, err) + must.NoError(t, err) resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + must.NoError(t, err) defer resp.Body.Close() stream := <-streamReady - require.NotNil(t, stream) + must.NotNil(t, stream) defer stream.Close() sendErr := stream.Send(t.Context(), &eventstream.Event{ Payload: json.RawMessage(`{"x":1}`), }) - require.NoError(t, sendErr) + must.NoError(t, sendErr) scanner := bufio.NewScanner(resp.Body) // No "event:" line, just data - require.True(t, scanner.Scan()) - assert.Equal(t, `data: {"x":1}`, scanner.Text()) + must.True(t, scanner.Scan()) + test.EqOp(t, `data: {"x":1}`, scanner.Text()) // Empty line (event separator) - require.True(t, scanner.Scan()) - assert.Equal(t, "", scanner.Text()) + must.True(t, scanner.Scan()) + test.EqOp(t, "", scanner.Text()) }) T.Run("multiple events", func(t *testing.T) { @@ -184,13 +184,13 @@ func TestSSEStream_Send(T *testing.T) { defer server.Close() req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, server.URL, http.NoBody) - require.NoError(t, err) + must.NoError(t, err) resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + must.NoError(t, err) defer resp.Body.Close() stream := <-streamReady - require.NotNil(t, stream) + must.NotNil(t, stream) defer stream.Close() for i, name := range []string{"first", "second", "third"} { @@ -198,19 +198,19 @@ func TestSSEStream_Send(T *testing.T) { Type: "msg", Payload: json.RawMessage(`"` + name + `"`), }) - require.NoError(t, sendErr, "send %d", i) + must.NoError(t, sendErr, must.Sprintf("send %d", i)) } scanner := bufio.NewScanner(resp.Body) for _, name := range []string{"first", "second", "third"} { - require.True(t, scanner.Scan()) - assert.Equal(t, "event: msg", scanner.Text()) + must.True(t, scanner.Scan()) + test.EqOp(t, "event: msg", scanner.Text()) - require.True(t, scanner.Scan()) - assert.Equal(t, `data: "`+name+`"`, scanner.Text()) + must.True(t, scanner.Scan()) + test.EqOp(t, `data: "`+name+`"`, scanner.Text()) - require.True(t, scanner.Scan()) - assert.Equal(t, "", scanner.Text()) + must.True(t, scanner.Scan()) + test.EqOp(t, "", scanner.Text()) } }) @@ -231,22 +231,22 @@ func TestSSEStream_Send(T *testing.T) { defer server.Close() req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, server.URL, http.NoBody) - require.NoError(t, err) + must.NoError(t, err) resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + must.NoError(t, err) defer resp.Body.Close() stream := <-streamReady - require.NotNil(t, stream) + must.NotNil(t, stream) - require.NoError(t, stream.Close()) + must.NoError(t, stream.Close()) sendErr := stream.Send(t.Context(), &eventstream.Event{ Type: "test", Payload: json.RawMessage(`{}`), }) - assert.Error(t, sendErr) - assert.Contains(t, sendErr.Error(), "stream closed") + test.Error(t, sendErr) + test.StrContains(t, sendErr.Error(), "stream closed") }) } @@ -270,22 +270,22 @@ func TestSSEStream_Done(T *testing.T) { defer server.Close() req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, server.URL, http.NoBody) - require.NoError(t, err) + must.NoError(t, err) resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + must.NoError(t, err) defer resp.Body.Close() stream := <-streamReady - require.NotNil(t, stream) + must.NotNil(t, stream) done := stream.Done() - require.NoError(t, stream.Close()) + must.NoError(t, stream.Close()) select { case <-done: // expected: channel closed case <-time.After(time.Second): - require.Fail(t, "Done() channel was not closed after Close()") + t.Fatalf("Done() channel was not closed after Close()") } }) @@ -306,12 +306,12 @@ func TestSSEStream_Done(T *testing.T) { defer server.Close() req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, server.URL, http.NoBody) - require.NoError(t, err) + must.NoError(t, err) resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + must.NoError(t, err) stream := <-streamReady - require.NotNil(t, stream) + must.NotNil(t, stream) // Close the client connection, which cancels the request context resp.Body.Close() @@ -321,7 +321,7 @@ func TestSSEStream_Done(T *testing.T) { case <-stream.Done(): // expected case <-time.After(2 * time.Second): - require.Fail(t, "Done() channel was not closed after client disconnect") + t.Fatalf("Done() channel was not closed after client disconnect") } }) } @@ -346,17 +346,17 @@ func TestSSEStream_Close(T *testing.T) { defer server.Close() req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, server.URL, http.NoBody) - require.NoError(t, err) + must.NoError(t, err) resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + must.NoError(t, err) defer resp.Body.Close() stream := <-streamReady - require.NotNil(t, stream) + must.NotNil(t, stream) // Close should be idempotent (context.CancelFunc is safe to call multiple times) - assert.NoError(t, stream.Close()) - assert.NoError(t, stream.Close()) + test.NoError(t, stream.Close()) + test.NoError(t, stream.Close()) }) } @@ -389,28 +389,28 @@ func TestSSEStream_Send_verifies_SSE_format(T *testing.T) { defer server.Close() req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, server.URL, http.NoBody) - require.NoError(t, err) + must.NoError(t, err) resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + must.NoError(t, err) defer resp.Body.Close() stream := <-streamReady - require.NotNil(t, stream) + must.NotNil(t, stream) defer stream.Close() sendErr := stream.Send(t.Context(), &eventstream.Event{ Type: "update", Payload: json.RawMessage(`{"id":"abc","status":"done"}`), }) - require.NoError(t, sendErr) + must.NoError(t, sendErr) // Read raw bytes and verify the exact SSE format buf := make([]byte, 4096) n, readErr := resp.Body.Read(buf) - require.NoError(t, readErr) + must.NoError(t, readErr) output := string(buf[:n]) expected := "event: update\ndata: {\"id\":\"abc\",\"status\":\"done\"}\n\n" - assert.Equal(t, expected, output) + test.EqOp(t, expected, output) }) } diff --git a/eventstream/websocket/config_test.go b/eventstream/websocket/config_test.go index d5d1c03..fa887f0 100644 --- a/eventstream/websocket/config_test.go +++ b/eventstream/websocket/config_test.go @@ -3,7 +3,7 @@ package websocket import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -15,6 +15,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { ctx := t.Context() cfg := &Config{} - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) } diff --git a/eventstream/websocket/websocket_test.go b/eventstream/websocket/websocket_test.go index 6800a46..669ab1a 100644 --- a/eventstream/websocket/websocket_test.go +++ b/eventstream/websocket/websocket_test.go @@ -11,8 +11,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" gorillawebsocket "github.com/gorilla/websocket" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestNewUpgrader(T *testing.T) { @@ -22,10 +22,10 @@ func TestNewUpgrader(T *testing.T) { t.Parallel() u := NewUpgrader(nil, tracing.NewNoopTracerProvider(), nil) - require.NotNil(t, u) - assert.Equal(t, defaultHeartbeatInterval, u.heartbeatInterval) - assert.Equal(t, defaultBufferSize, u.wsUpgrader.ReadBufferSize) - assert.Equal(t, defaultBufferSize, u.wsUpgrader.WriteBufferSize) + must.NotNil(t, u) + test.EqOp(t, defaultHeartbeatInterval, u.heartbeatInterval) + test.EqOp(t, defaultBufferSize, u.wsUpgrader.ReadBufferSize) + test.EqOp(t, defaultBufferSize, u.wsUpgrader.WriteBufferSize) }) T.Run("custom config", func(t *testing.T) { @@ -36,10 +36,10 @@ func TestNewUpgrader(T *testing.T) { ReadBufferSize: 2048, WriteBufferSize: 4096, }) - require.NotNil(t, u) - assert.Equal(t, 10*time.Second, u.heartbeatInterval) - assert.Equal(t, 2048, u.wsUpgrader.ReadBufferSize) - assert.Equal(t, 4096, u.wsUpgrader.WriteBufferSize) + must.NotNil(t, u) + test.EqOp(t, 10*time.Second, u.heartbeatInterval) + test.EqOp(t, 2048, u.wsUpgrader.ReadBufferSize) + test.EqOp(t, 4096, u.wsUpgrader.WriteBufferSize) }) } @@ -63,11 +63,11 @@ func TestUpgrader_UpgradeToEventStream(T *testing.T) { defer server.Close() conn, _, err := gorillawebsocket.DefaultDialer.Dial("ws"+server.URL[4:], http.Header{"Origin": {server.URL}}) - require.NoError(t, err) + must.NoError(t, err) defer conn.Close() stream := <-streamReady - require.NotNil(t, stream) + must.NotNil(t, stream) defer stream.Close() }) } @@ -92,14 +92,14 @@ func TestUpgrader_UpgradeToBidirectionalStream(T *testing.T) { defer server.Close() conn, _, err := gorillawebsocket.DefaultDialer.Dial("ws"+server.URL[4:], http.Header{"Origin": {server.URL}}) - require.NoError(t, err) + must.NoError(t, err) defer conn.Close() stream := <-streamReady - require.NotNil(t, stream) + must.NotNil(t, stream) defer stream.Close() - assert.NotNil(t, stream.Receive()) + test.NotNil(t, stream.Receive()) }) } @@ -128,7 +128,7 @@ func TestWSStream_Send(T *testing.T) { defer server.Close() conn, _, err := gorillawebsocket.DefaultDialer.Dial("ws"+server.URL[4:], http.Header{"Origin": {server.URL}}) - require.NoError(t, err) + must.NoError(t, err) defer conn.Close() go func() { @@ -140,10 +140,10 @@ func TestWSStream_Send(T *testing.T) { select { case event := <-received: - assert.Equal(t, "test", event.Type) - assert.JSONEq(t, `{"msg":"hello"}`, string(event.Payload)) + test.EqOp(t, "test", event.Type) + test.EqOp(t, `{"msg":"hello"}`, string(event.Payload)) case <-time.After(2 * time.Second): - require.Fail(t, "did not receive event") + t.Fatalf("did not receive event") } }) @@ -163,15 +163,15 @@ func TestWSStream_Send(T *testing.T) { defer server.Close() conn, _, err := gorillawebsocket.DefaultDialer.Dial("ws"+server.URL[4:], http.Header{"Origin": {server.URL}}) - require.NoError(t, err) + must.NoError(t, err) defer conn.Close() stream := <-streamReady - require.NoError(t, stream.Close()) + must.NoError(t, stream.Close()) sendErr := stream.Send(t.Context(), &eventstream.Event{Type: "x"}) - assert.Error(t, sendErr) - assert.Contains(t, sendErr.Error(), "stream closed") + test.Error(t, sendErr) + test.StrContains(t, sendErr.Error(), "stream closed") }) } @@ -194,18 +194,18 @@ func TestWSStream_Done(T *testing.T) { defer server.Close() conn, _, err := gorillawebsocket.DefaultDialer.Dial("ws"+server.URL[4:], http.Header{"Origin": {server.URL}}) - require.NoError(t, err) + must.NoError(t, err) defer conn.Close() stream := <-streamReady done := stream.Done() - require.NoError(t, stream.Close()) + must.NoError(t, stream.Close()) select { case <-done: // expected case <-time.After(time.Second): - require.Fail(t, "Done() channel was not closed after Close()") + t.Fatalf("Done() channel was not closed after Close()") } }) } @@ -229,12 +229,12 @@ func TestWSStream_Close(T *testing.T) { defer server.Close() conn, _, err := gorillawebsocket.DefaultDialer.Dial("ws"+server.URL[4:], http.Header{"Origin": {server.URL}}) - require.NoError(t, err) + must.NoError(t, err) defer conn.Close() stream := <-streamReady - assert.NoError(t, stream.Close()) - assert.NoError(t, stream.Close()) + test.NoError(t, stream.Close()) + test.NoError(t, stream.Close()) }) } @@ -257,7 +257,7 @@ func TestBidirectionalWSStream_Receive(T *testing.T) { defer server.Close() conn, _, err := gorillawebsocket.DefaultDialer.Dial("ws"+server.URL[4:], http.Header{"Origin": {server.URL}}) - require.NoError(t, err) + must.NoError(t, err) defer conn.Close() stream := <-streamReady @@ -268,15 +268,15 @@ func TestBidirectionalWSStream_Receive(T *testing.T) { Type: "ping", Payload: json.RawMessage(`{"seq":1}`), } - require.NoError(t, conn.WriteJSON(outgoing)) + must.NoError(t, conn.WriteJSON(outgoing)) select { case event := <-stream.Receive(): - require.NotNil(t, event) - assert.Equal(t, "ping", event.Type) - assert.JSONEq(t, `{"seq":1}`, string(event.Payload)) + must.NotNil(t, event) + test.EqOp(t, "ping", event.Type) + test.EqOp(t, `{"seq":1}`, string(event.Payload)) case <-time.After(2 * time.Second): - require.Fail(t, "did not receive event from client") + t.Fatalf("did not receive event from client") } }) @@ -296,19 +296,19 @@ func TestBidirectionalWSStream_Receive(T *testing.T) { defer server.Close() conn, _, err := gorillawebsocket.DefaultDialer.Dial("ws"+server.URL[4:], http.Header{"Origin": {server.URL}}) - require.NoError(t, err) + must.NoError(t, err) defer conn.Close() stream := <-streamReady incoming := stream.Receive() - require.NoError(t, stream.Close()) + must.NoError(t, stream.Close()) select { case _, open := <-incoming: - assert.False(t, open, "Receive channel should be closed") + test.False(t, open, test.Sprintf("Receive channel should be closed")) case <-time.After(2 * time.Second): - require.Fail(t, "Receive channel was not closed after stream.Close()") + t.Fatalf("Receive channel was not closed after stream.Close()") } }) } diff --git a/fake/fake.go b/fake/fake.go index 757f133..b50e0fe 100644 --- a/fake/fake.go +++ b/fake/fake.go @@ -7,7 +7,7 @@ import ( fake "github.com/brianvoe/gofakeit/v7" "github.com/go-faker/faker/v4" "github.com/go-faker/faker/v4/pkg/options" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" ) // BuildFakeTime builds a fake time. @@ -18,7 +18,7 @@ func BuildFakeTime() time.Time { // BuildFakeForTest builds a fake instance of insert-struct-here for a test. func BuildFakeForTest[X any](t *testing.T) (x *X) { t.Helper() - require.NoError(t, faker.FakeData(&x, options.WithRecursionMaxDepth(0))) + must.NoError(t, faker.FakeData(&x, options.WithRecursionMaxDepth(0))) return x } diff --git a/fake/fake_test.go b/fake/fake_test.go index f23722c..b0d084e 100644 --- a/fake/fake_test.go +++ b/fake/fake_test.go @@ -3,7 +3,7 @@ package fake import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestBuildFkaeTime(T *testing.T) { @@ -14,7 +14,7 @@ func TestBuildFkaeTime(T *testing.T) { actual := BuildFakeTime() - assert.NotNil(t, actual) + test.False(t, actual.IsZero()) }) } @@ -30,7 +30,7 @@ func TestBuildFakeForTest(T *testing.T) { t.Parallel() actual := BuildFakeForTest[*example](t) - assert.NotNil(t, actual) + test.NotNil(t, actual) }) } @@ -40,17 +40,17 @@ func TestMustBuildFake(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotPanics(t, func() { + test.NotPanic(t, func() { actual := MustBuildFake[example]() - assert.NotEmpty(t, actual.Name) - assert.NotEmpty(t, actual.Age) + test.NotEq(t, "", actual.Name) + test.NotEq(t, 0, actual.Age) }) }) T.Run("standard", func(t *testing.T) { t.Parallel() - assert.Panics(t, func() { + test.Panic(t, func() { MustBuildFake[any]() }) }) @@ -63,15 +63,15 @@ func TestBuildFake(T *testing.T) { t.Parallel() actual, err := BuildFake[string]() - assert.NoError(t, err) - assert.NotEmpty(t, actual) + test.NoError(t, err) + test.NotNil(t, actual) }) T.Run("with error", func(t *testing.T) { t.Parallel() actual, err := BuildFake[any]() - assert.Error(t, err) - assert.Empty(t, actual) + test.Error(t, err) + test.Nil(t, actual) }) } diff --git a/featureflags/config/config_test.go b/featureflags/config/config_test.go index f756dde..fe70cdd 100644 --- a/featureflags/config/config_test.go +++ b/featureflags/config/config_test.go @@ -16,8 +16,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -36,7 +35,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderLaunchDarkly, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with empty provider for noop", func(t *testing.T) { @@ -47,7 +46,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: "", } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with invalid provider", func(t *testing.T) { @@ -58,7 +57,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: "invalid_provider", } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with posthog provider", func(t *testing.T) { @@ -73,7 +72,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderPostHog, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with launchdarkly provider missing config", func(t *testing.T) { @@ -84,7 +83,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderLaunchDarkly, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with posthog provider missing config", func(t *testing.T) { @@ -95,7 +94,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderPostHog, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } @@ -121,8 +120,8 @@ func TestConfig_ProvideFeatureFlagManager(T *testing.T) { } ffm, err := cfg.ProvideFeatureFlagManager(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, http.DefaultClient, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, ffm) + must.NoError(t, err) + must.NotNil(t, ffm) }) T.Run("with unknown provider returns noop", func(t *testing.T) { @@ -133,8 +132,8 @@ func TestConfig_ProvideFeatureFlagManager(T *testing.T) { } ffm, err := cfg.ProvideFeatureFlagManager(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, http.DefaultClient, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, ffm) + must.NoError(t, err) + must.NotNil(t, ffm) }) T.Run("with launchdarkly provider but nil config", func(t *testing.T) { @@ -145,8 +144,8 @@ func TestConfig_ProvideFeatureFlagManager(T *testing.T) { } ffm, err := cfg.ProvideFeatureFlagManager(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, http.DefaultClient, cbnoop.NewCircuitBreaker()) - require.Error(t, err) - require.Nil(t, ffm) + must.Error(t, err) + must.Nil(t, ffm) }) T.Run("with posthog provider but nil config", func(t *testing.T) { @@ -157,8 +156,8 @@ func TestConfig_ProvideFeatureFlagManager(T *testing.T) { } ffm, err := cfg.ProvideFeatureFlagManager(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, http.DefaultClient, cbnoop.NewCircuitBreaker()) - require.Error(t, err) - require.Nil(t, ffm) + must.Error(t, err) + must.Nil(t, ffm) }) T.Run("with provider string that has whitespace and mixed case", func(t *testing.T) { @@ -170,8 +169,8 @@ func TestConfig_ProvideFeatureFlagManager(T *testing.T) { // Will fail because LaunchDarkly config is nil, but proves the normalization works ffm, err := cfg.ProvideFeatureFlagManager(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, http.DefaultClient, cbnoop.NewCircuitBreaker()) - require.Error(t, err) - require.Nil(t, ffm) + must.Error(t, err) + must.Nil(t, ffm) }) } @@ -187,8 +186,8 @@ func TestProvideFeatureFlagManager(T *testing.T) { } ffm, err := ProvideFeatureFlagManager(ctx, cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider(), http.DefaultClient) - require.NoError(t, err) - require.NotNil(t, ffm) + must.NoError(t, err) + must.NotNil(t, ffm) }) T.Run("with circuit breaker error", func(t *testing.T) { @@ -209,8 +208,8 @@ func TestProvideFeatureFlagManager(T *testing.T) { } ffm, err := ProvideFeatureFlagManager(ctx, cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, http.DefaultClient) - require.Error(t, err) - require.Nil(t, ffm) + must.Error(t, err) + must.Nil(t, ffm) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) diff --git a/featureflags/config/do_test.go b/featureflags/config/do_test.go index de8fdf0..633b15e 100644 --- a/featureflags/config/do_test.go +++ b/featureflags/config/do_test.go @@ -10,8 +10,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterFeatureFlagManager(T *testing.T) { @@ -31,7 +31,7 @@ func TestRegisterFeatureFlagManager(T *testing.T) { RegisterFeatureFlagManager(i) ffm, err := do.Invoke[featureflags.FeatureFlagManager](i) - require.NoError(t, err) - assert.NotNil(t, ffm) + must.NoError(t, err) + test.NotNil(t, ffm) }) } diff --git a/featureflags/launchdarkly/feature_flag_manager_test.go b/featureflags/launchdarkly/feature_flag_manager_test.go index 487c0ea..14a5862 100644 --- a/featureflags/launchdarkly/feature_flag_manager_test.go +++ b/featureflags/launchdarkly/feature_flag_manager_test.go @@ -24,8 +24,8 @@ import ( "github.com/launchdarkly/go-server-sdk/v6/subsystems/ldstoretypes" ofld "github.com/open-feature/go-sdk-contrib/providers/launchdarkly/pkg" "github.com/open-feature/go-sdk/openfeature" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func evalCtx(targetingKey string) featureflags.EvaluationContext { @@ -112,8 +112,8 @@ func buildTestManager(t *testing.T, cb circuitbreaking.CircuitBreaker) *featureF config.DataSource = &fakeLaunchDarklyDataSourceBuilder{} return config }) - require.NoError(t, err) - require.NotNil(t, ffm) + must.NoError(t, err) + must.NotNil(t, ffm) return ffm.(*featureFlagManager) } @@ -127,24 +127,24 @@ func buildTestManagerWithFlags(t *testing.T, flags []ldstoretypes.KeyedItemDescr } client, err := ld.MakeCustomClient(t.Name(), ldConfig, 5*time.Second) - require.NoError(t, err) + must.NoError(t, err) t.Cleanup(func() { client.Close() }) // Use a unique domain per test to avoid global OpenFeature provider conflicts. domain := "test_" + strings.ReplaceAll(t.Name(), "/", "_") provider := ofld.NewProvider(client) err = openfeature.SetNamedProviderAndWait(domain, provider) - require.NoError(t, err) + must.NoError(t, err) ofClient := openfeature.NewClient(domain) mp := metrics.EnsureMetricsProvider(nil) evalCounter, err := mp.NewInt64Counter(fmt.Sprintf("%s_evaluations", serviceName)) - require.NoError(t, err) + must.NoError(t, err) errorCounter, err := mp.NewInt64Counter(fmt.Sprintf("%s_errors", serviceName)) - require.NoError(t, err) + must.NoError(t, err) latencyHist, err := mp.NewFloat64Histogram(fmt.Sprintf("%s_latency_ms", serviceName)) - require.NoError(t, err) + must.NoError(t, err) return &featureFlagManager{ ldClient: client, @@ -170,8 +170,8 @@ func TestNewFeatureFlagManager(T *testing.T) { config.DataSource = &fakeLaunchDarklyDataSourceBuilder{} return config }) - require.NoError(t, err) - require.NotNil(t, actual) + must.NoError(t, err) + must.NotNil(t, actual) }) T.Run("with missing http client", func(t *testing.T) { @@ -180,16 +180,16 @@ func TestNewFeatureFlagManager(T *testing.T) { cfg := &Config{SDKKey: t.Name()} actual, err := NewFeatureFlagManager(cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, nil, cbnoop.NewCircuitBreaker()) - require.Error(t, err) - require.Nil(t, actual) + must.Error(t, err) + must.Nil(t, actual) }) T.Run("with nil config", func(t *testing.T) { t.Parallel() actual, err := NewFeatureFlagManager(nil, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, http.DefaultClient, cbnoop.NewCircuitBreaker()) - require.Error(t, err) - require.Nil(t, actual) + must.Error(t, err) + must.Nil(t, actual) }) T.Run("with missing SDK key", func(t *testing.T) { @@ -201,8 +201,8 @@ func TestNewFeatureFlagManager(T *testing.T) { config.DataSource = &fakeLaunchDarklyDataSourceBuilder{} return config }) - require.Error(t, err) - require.Nil(t, actual) + must.Error(t, err) + must.Nil(t, actual) }) T.Run("with zero init timeout gets default", func(t *testing.T) { @@ -214,8 +214,8 @@ func TestNewFeatureFlagManager(T *testing.T) { config.DataSource = &fakeLaunchDarklyDataSourceBuilder{} return config }) - require.NoError(t, err) - require.NotNil(t, actual) + must.NoError(t, err) + must.NotNil(t, actual) }) } @@ -232,9 +232,9 @@ func TestToOpenFeatureContext(T *testing.T) { result := toOpenFeatureContext(ec) - assert.Equal(t, "user123", result.TargetingKey()) - assert.Equal(t, "pro", result.Attribute("plan")) - assert.Equal(t, "us-east", result.Attribute("region")) + test.EqOp(t, "user123", result.TargetingKey()) + test.Eq(t, "pro", result.Attribute("plan")) + test.Eq(t, "us-east", result.Attribute("region")) }) T.Run("with nil attributes", func(t *testing.T) { @@ -246,7 +246,7 @@ func TestToOpenFeatureContext(T *testing.T) { result := toOpenFeatureContext(ec) - assert.Equal(t, "user456", result.TargetingKey()) + test.EqOp(t, "user456", result.TargetingKey()) }) } @@ -260,8 +260,8 @@ func TestFeatureFlagManager_CanUseFeature(T *testing.T) { ffm := buildTestManagerWithFlags(t, testFlagItems()) result, err := ffm.CanUseFeature(ctx, "bool-flag", evalCtx("user123")) - assert.NoError(t, err) - assert.True(t, result) + test.NoError(t, err) + test.True(t, result) }) T.Run("with flag not found", func(t *testing.T) { @@ -277,8 +277,8 @@ func TestFeatureFlagManager_CanUseFeature(T *testing.T) { ffm := buildTestManager(t, cb) result, err := ffm.CanUseFeature(ctx, "nonexistent-flag", evalCtx("user123")) - assert.Error(t, err) - assert.False(t, result) + test.Error(t, err) + test.False(t, result) }) T.Run("with broken circuit", func(t *testing.T) { @@ -292,8 +292,8 @@ func TestFeatureFlagManager_CanUseFeature(T *testing.T) { ffm := buildTestManager(t, cb) result, err := ffm.CanUseFeature(ctx, "some-flag", evalCtx("user123")) - assert.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - assert.False(t, result) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.False(t, result) }) } @@ -307,8 +307,8 @@ func TestFeatureFlagManager_GetStringValue(T *testing.T) { ffm := buildTestManagerWithFlags(t, testFlagItems()) result, err := ffm.GetStringValue(ctx, "string-flag", "fallback", evalCtx("user123")) - assert.NoError(t, err) - assert.Equal(t, "hello-world", result) + test.NoError(t, err) + test.EqOp(t, "hello-world", result) }) T.Run("with flag not found", func(t *testing.T) { @@ -324,8 +324,8 @@ func TestFeatureFlagManager_GetStringValue(T *testing.T) { ffm := buildTestManager(t, cb) result, err := ffm.GetStringValue(ctx, "nonexistent-flag", "fallback", evalCtx("user123")) - assert.Error(t, err) - assert.Equal(t, "fallback", result) + test.Error(t, err) + test.EqOp(t, "fallback", result) }) T.Run("with broken circuit", func(t *testing.T) { @@ -339,8 +339,8 @@ func TestFeatureFlagManager_GetStringValue(T *testing.T) { ffm := buildTestManager(t, cb) result, err := ffm.GetStringValue(ctx, "some-flag", "fallback", evalCtx("user123")) - assert.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - assert.Equal(t, "fallback", result) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.EqOp(t, "fallback", result) }) } @@ -354,8 +354,8 @@ func TestFeatureFlagManager_GetInt64Value(T *testing.T) { ffm := buildTestManagerWithFlags(t, testFlagItems()) result, err := ffm.GetInt64Value(ctx, "int-flag", int64(0), evalCtx("user123")) - assert.NoError(t, err) - assert.Equal(t, int64(42), result) + test.NoError(t, err) + test.EqOp(t, int64(42), result) }) T.Run("with flag not found", func(t *testing.T) { @@ -371,8 +371,8 @@ func TestFeatureFlagManager_GetInt64Value(T *testing.T) { ffm := buildTestManager(t, cb) result, err := ffm.GetInt64Value(ctx, "nonexistent-flag", int64(42), evalCtx("user123")) - assert.Error(t, err) - assert.Equal(t, int64(42), result) + test.Error(t, err) + test.EqOp(t, int64(42), result) }) T.Run("with broken circuit", func(t *testing.T) { @@ -386,8 +386,8 @@ func TestFeatureFlagManager_GetInt64Value(T *testing.T) { ffm := buildTestManager(t, cb) result, err := ffm.GetInt64Value(ctx, "some-flag", int64(42), evalCtx("user123")) - assert.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - assert.Equal(t, int64(42), result) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.EqOp(t, int64(42), result) }) } @@ -401,8 +401,8 @@ func TestFeatureFlagManager_GetFloat64Value(T *testing.T) { ffm := buildTestManagerWithFlags(t, testFlagItems()) result, err := ffm.GetFloat64Value(ctx, "float-flag", 0.0, evalCtx("user123")) - assert.NoError(t, err) - assert.InDelta(t, 3.14, result, 1e-9) + test.NoError(t, err) + test.InDelta(t, 3.14, result, 1e-9) }) T.Run("with flag not found", func(t *testing.T) { @@ -418,8 +418,8 @@ func TestFeatureFlagManager_GetFloat64Value(T *testing.T) { ffm := buildTestManager(t, cb) result, err := ffm.GetFloat64Value(ctx, "nonexistent-flag", 3.14, evalCtx("user123")) - assert.Error(t, err) - assert.InDelta(t, 3.14, result, 1e-9) + test.Error(t, err) + test.InDelta(t, 3.14, result, 1e-9) }) T.Run("with broken circuit", func(t *testing.T) { @@ -433,8 +433,8 @@ func TestFeatureFlagManager_GetFloat64Value(T *testing.T) { ffm := buildTestManager(t, cb) result, err := ffm.GetFloat64Value(ctx, "some-flag", 3.14, evalCtx("user123")) - assert.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - assert.InDelta(t, 3.14, result, 1e-9) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.InDelta(t, 3.14, result, 1e-9) }) } @@ -449,8 +449,8 @@ func TestFeatureFlagManager_GetObjectValue(T *testing.T) { def := map[string]any{"default": true} result, err := ffm.GetObjectValue(ctx, "object-flag", def, evalCtx("user123")) - assert.NoError(t, err) - assert.Equal(t, map[string]any{"key": "value"}, result) + test.NoError(t, err) + test.Eq[any](t, map[string]any{"key": "value"}, result) }) T.Run("with flag not found", func(t *testing.T) { @@ -467,8 +467,8 @@ func TestFeatureFlagManager_GetObjectValue(T *testing.T) { def := map[string]any{"k": "v"} result, err := ffm.GetObjectValue(ctx, "nonexistent-flag", def, evalCtx("user123")) - assert.Error(t, err) - assert.Equal(t, def, result) + test.Error(t, err) + test.Eq[any](t, def, result) }) T.Run("with broken circuit", func(t *testing.T) { @@ -483,8 +483,8 @@ func TestFeatureFlagManager_GetObjectValue(T *testing.T) { def := map[string]any{"k": "v"} result, err := ffm.GetObjectValue(ctx, "some-flag", def, evalCtx("user123")) - assert.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - assert.Equal(t, def, result) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.Eq[any](t, def, result) }) } @@ -497,6 +497,6 @@ func TestFeatureFlagManager_Close(T *testing.T) { ffm := buildTestManager(t, cbnoop.NewCircuitBreaker()) err := ffm.Close() - assert.NoError(t, err) + test.NoError(t, err) }) } diff --git a/featureflags/noop/noop_test.go b/featureflags/noop/noop_test.go index 3b2db0d..dd822f0 100644 --- a/featureflags/noop/noop_test.go +++ b/featureflags/noop/noop_test.go @@ -6,8 +6,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/featureflags" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func evalCtx() featureflags.EvaluationContext { @@ -21,7 +21,7 @@ func TestNewFeatureFlagManager(T *testing.T) { t.Parallel() mgr := NewFeatureFlagManager() - require.NotNil(t, mgr) + must.NotNil(t, mgr) }) } @@ -32,8 +32,8 @@ func TestFeatureFlagManager_CanUseFeature(T *testing.T) { t.Parallel() result, err := NewFeatureFlagManager().CanUseFeature(context.Background(), "some-feature", evalCtx()) - assert.NoError(t, err) - assert.False(t, result) + test.NoError(t, err) + test.False(t, result) }) } @@ -44,8 +44,8 @@ func TestFeatureFlagManager_GetStringValue(T *testing.T) { t.Parallel() result, err := NewFeatureFlagManager().GetStringValue(context.Background(), "some-feature", "fallback", evalCtx()) - assert.NoError(t, err) - assert.Equal(t, "fallback", result) + test.NoError(t, err) + test.EqOp(t, "fallback", result) }) } @@ -56,8 +56,8 @@ func TestFeatureFlagManager_GetInt64Value(T *testing.T) { t.Parallel() result, err := NewFeatureFlagManager().GetInt64Value(context.Background(), "some-feature", int64(42), evalCtx()) - assert.NoError(t, err) - assert.Equal(t, int64(42), result) + test.NoError(t, err) + test.EqOp(t, int64(42), result) }) } @@ -68,8 +68,8 @@ func TestFeatureFlagManager_GetFloat64Value(T *testing.T) { t.Parallel() result, err := NewFeatureFlagManager().GetFloat64Value(context.Background(), "some-feature", 3.14, evalCtx()) - assert.NoError(t, err) - assert.InDelta(t, 3.14, result, 1e-9) + test.NoError(t, err) + test.InDelta(t, 3.14, result, 1e-9) }) } @@ -81,8 +81,8 @@ func TestFeatureFlagManager_GetObjectValue(T *testing.T) { def := map[string]any{"k": "v"} result, err := NewFeatureFlagManager().GetObjectValue(context.Background(), "some-feature", def, evalCtx()) - assert.NoError(t, err) - assert.Equal(t, def, result) + test.NoError(t, err) + test.Eq[any](t, def, result) }) } @@ -93,6 +93,6 @@ func TestFeatureFlagManager_Close(T *testing.T) { t.Parallel() err := NewFeatureFlagManager().Close() - assert.NoError(t, err) + test.NoError(t, err) }) } diff --git a/featureflags/posthog/feature_flag_manager_test.go b/featureflags/posthog/feature_flag_manager_test.go index b8884ca..0bc272c 100644 --- a/featureflags/posthog/feature_flag_manager_test.go +++ b/featureflags/posthog/feature_flag_manager_test.go @@ -19,8 +19,8 @@ import ( openfeatureposthog "github.com/dhaus67/openfeature-posthog-go" "github.com/open-feature/go-sdk/openfeature" "github.com/posthog/posthog-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) var testFlags = map[string]any{ @@ -82,8 +82,8 @@ func buildTestManager(t *testing.T, cb circuitbreaking.CircuitBreaker, configMod } ffm, err := NewFeatureFlagManager(cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cb, configModifiers...) - require.NoError(t, err) - require.NotNil(t, ffm) + must.NoError(t, err) + must.NotNil(t, ffm) return ffm.(*featureFlagManager) } @@ -99,7 +99,7 @@ func buildTestManagerWithHandler(t *testing.T, handler http.Handler) *featureFla } client, err := posthog.NewWithConfig(t.Name(), phConfig) - require.NoError(t, err) + must.NoError(t, err) t.Cleanup(func() { client.Close() @@ -110,17 +110,17 @@ func buildTestManagerWithHandler(t *testing.T, handler http.Handler) *featureFla domain := "test_" + strings.ReplaceAll(t.Name(), "/", "_") provider := openfeatureposthog.NewProvider(client) err = openfeature.SetNamedProviderAndWait(domain, provider) - require.NoError(t, err) + must.NoError(t, err) ofClient := openfeature.NewClient(domain) mp := metrics.EnsureMetricsProvider(nil) evalCounter, err := mp.NewInt64Counter(fmt.Sprintf("%s_evaluations", serviceName)) - require.NoError(t, err) + must.NoError(t, err) errorCounter, err := mp.NewInt64Counter(fmt.Sprintf("%s_errors", serviceName)) - require.NoError(t, err) + must.NoError(t, err) latencyHist, err := mp.NewFloat64Histogram(fmt.Sprintf("%s_latency_ms", serviceName)) - require.NoError(t, err) + must.NoError(t, err) return &featureFlagManager{ posthogClient: client, @@ -146,16 +146,16 @@ func TestNewFeatureFlagManager(T *testing.T) { } actual, err := NewFeatureFlagManager(cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cbnoop.NewCircuitBreaker()) - assert.NoError(t, err) - assert.NotNil(t, actual) + test.NoError(t, err) + test.NotNil(t, actual) }) T.Run("with nil config", func(t *testing.T) { t.Parallel() actual, err := NewFeatureFlagManager(nil, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cbnoop.NewCircuitBreaker()) - assert.Error(t, err) - assert.Nil(t, actual) + test.Error(t, err) + test.Nil(t, actual) }) T.Run("with missing project API key", func(t *testing.T) { @@ -164,8 +164,8 @@ func TestNewFeatureFlagManager(T *testing.T) { cfg := &Config{} actual, err := NewFeatureFlagManager(cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cbnoop.NewCircuitBreaker()) - assert.Error(t, err) - assert.Nil(t, actual) + test.Error(t, err) + test.Nil(t, actual) }) T.Run("with missing personal API key", func(t *testing.T) { @@ -176,8 +176,8 @@ func TestNewFeatureFlagManager(T *testing.T) { } actual, err := NewFeatureFlagManager(cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cbnoop.NewCircuitBreaker()) - assert.Error(t, err) - assert.Nil(t, actual) + test.Error(t, err) + test.Nil(t, actual) }) T.Run("with invalid config", func(t *testing.T) { @@ -191,8 +191,8 @@ func TestNewFeatureFlagManager(T *testing.T) { actual, err := NewFeatureFlagManager(cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cbnoop.NewCircuitBreaker(), func(config *posthog.Config) { config.Interval = -1 }) - assert.Error(t, err) - assert.Nil(t, actual) + test.Error(t, err) + test.Nil(t, actual) }) } @@ -209,9 +209,9 @@ func TestToOpenFeatureContext(T *testing.T) { result := toOpenFeatureContext(ec) - assert.Equal(t, "user123", result.TargetingKey()) - assert.Equal(t, "pro", result.Attribute("plan")) - assert.Equal(t, "us-east", result.Attribute("region")) + test.EqOp(t, "user123", result.TargetingKey()) + test.Eq(t, "pro", result.Attribute("plan")) + test.Eq(t, "us-east", result.Attribute("region")) }) T.Run("with nil attributes", func(t *testing.T) { @@ -223,7 +223,7 @@ func TestToOpenFeatureContext(T *testing.T) { result := toOpenFeatureContext(ec) - assert.Equal(t, "user456", result.TargetingKey()) + test.EqOp(t, "user456", result.TargetingKey()) }) } @@ -237,8 +237,8 @@ func TestFeatureFlagManager_CanUseFeature(T *testing.T) { ffm := buildTestManagerWithHandler(t, posthogFlagsHandler(testFlags)) actual, err := ffm.CanUseFeature(ctx, "bool-flag", evalCtx("user123")) - assert.NoError(t, err) - assert.True(t, actual) + test.NoError(t, err) + test.True(t, actual) }) T.Run("with error executing request", func(t *testing.T) { @@ -248,8 +248,8 @@ func TestFeatureFlagManager_CanUseFeature(T *testing.T) { ffm := buildTestManagerWithHandler(t, posthogErrorHandler()) actual, err := ffm.CanUseFeature(ctx, "bool-flag", evalCtx("user123")) - assert.Error(t, err) - assert.False(t, actual) + test.Error(t, err) + test.False(t, actual) }) T.Run("with broken circuit", func(t *testing.T) { @@ -263,8 +263,8 @@ func TestFeatureFlagManager_CanUseFeature(T *testing.T) { ffm := buildTestManager(t, cb) result, err := ffm.CanUseFeature(ctx, "some-flag", evalCtx("user123")) - assert.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - assert.False(t, result) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.False(t, result) }) } @@ -278,8 +278,8 @@ func TestFeatureFlagManager_GetStringValue(T *testing.T) { ffm := buildTestManagerWithHandler(t, posthogFlagsHandler(testFlags)) result, err := ffm.GetStringValue(ctx, "string-flag", "fallback", evalCtx("user123")) - assert.NoError(t, err) - assert.Equal(t, "hello-world", result) + test.NoError(t, err) + test.EqOp(t, "hello-world", result) }) T.Run("with error executing request", func(t *testing.T) { @@ -289,8 +289,8 @@ func TestFeatureFlagManager_GetStringValue(T *testing.T) { ffm := buildTestManagerWithHandler(t, posthogErrorHandler()) result, err := ffm.GetStringValue(ctx, "string-flag", "fallback", evalCtx("user123")) - assert.Error(t, err) - assert.Equal(t, "fallback", result) + test.Error(t, err) + test.EqOp(t, "fallback", result) }) T.Run("with broken circuit", func(t *testing.T) { @@ -304,8 +304,8 @@ func TestFeatureFlagManager_GetStringValue(T *testing.T) { ffm := buildTestManager(t, cb) result, err := ffm.GetStringValue(ctx, "some-flag", "fallback", evalCtx("user123")) - assert.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - assert.Equal(t, "fallback", result) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.EqOp(t, "fallback", result) }) } @@ -319,8 +319,8 @@ func TestFeatureFlagManager_GetInt64Value(T *testing.T) { ffm := buildTestManagerWithHandler(t, posthogFlagsHandler(testFlags)) result, err := ffm.GetInt64Value(ctx, "int-flag", int64(0), evalCtx("user123")) - assert.NoError(t, err) - assert.Equal(t, int64(42), result) + test.NoError(t, err) + test.EqOp(t, int64(42), result) }) T.Run("with error executing request", func(t *testing.T) { @@ -330,8 +330,8 @@ func TestFeatureFlagManager_GetInt64Value(T *testing.T) { ffm := buildTestManagerWithHandler(t, posthogErrorHandler()) result, err := ffm.GetInt64Value(ctx, "int-flag", int64(42), evalCtx("user123")) - assert.Error(t, err) - assert.Equal(t, int64(42), result) + test.Error(t, err) + test.EqOp(t, int64(42), result) }) T.Run("with broken circuit", func(t *testing.T) { @@ -345,8 +345,8 @@ func TestFeatureFlagManager_GetInt64Value(T *testing.T) { ffm := buildTestManager(t, cb) result, err := ffm.GetInt64Value(ctx, "some-flag", int64(42), evalCtx("user123")) - assert.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - assert.Equal(t, int64(42), result) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.EqOp(t, int64(42), result) }) } @@ -360,8 +360,8 @@ func TestFeatureFlagManager_GetFloat64Value(T *testing.T) { ffm := buildTestManagerWithHandler(t, posthogFlagsHandler(testFlags)) result, err := ffm.GetFloat64Value(ctx, "float-flag", 0.0, evalCtx("user123")) - assert.NoError(t, err) - assert.InDelta(t, 3.14, result, 1e-9) + test.NoError(t, err) + test.InDelta(t, 3.14, result, 1e-9) }) T.Run("with error executing request", func(t *testing.T) { @@ -371,8 +371,8 @@ func TestFeatureFlagManager_GetFloat64Value(T *testing.T) { ffm := buildTestManagerWithHandler(t, posthogErrorHandler()) result, err := ffm.GetFloat64Value(ctx, "float-flag", 3.14, evalCtx("user123")) - assert.Error(t, err) - assert.InDelta(t, 3.14, result, 1e-9) + test.Error(t, err) + test.InDelta(t, 3.14, result, 1e-9) }) T.Run("with broken circuit", func(t *testing.T) { @@ -386,8 +386,8 @@ func TestFeatureFlagManager_GetFloat64Value(T *testing.T) { ffm := buildTestManager(t, cb) result, err := ffm.GetFloat64Value(ctx, "some-flag", 3.14, evalCtx("user123")) - assert.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - assert.InDelta(t, 3.14, result, 1e-9) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.InDelta(t, 3.14, result, 1e-9) }) } @@ -402,8 +402,8 @@ func TestFeatureFlagManager_GetObjectValue(T *testing.T) { def := map[string]any{"default": true} result, err := ffm.GetObjectValue(ctx, "object-flag", def, evalCtx("user123")) - assert.NoError(t, err) - assert.Equal(t, map[string]any{"key": "value"}, result) + test.NoError(t, err) + test.Eq[any](t, map[string]any{"key": "value"}, result) }) T.Run("with error executing request", func(t *testing.T) { @@ -414,8 +414,8 @@ func TestFeatureFlagManager_GetObjectValue(T *testing.T) { def := map[string]any{"k": "v"} result, err := ffm.GetObjectValue(ctx, "object-flag", def, evalCtx("user123")) - assert.Error(t, err) - assert.Equal(t, def, result) + test.Error(t, err) + test.Eq[any](t, def, result) }) T.Run("with broken circuit", func(t *testing.T) { @@ -430,8 +430,8 @@ func TestFeatureFlagManager_GetObjectValue(T *testing.T) { def := map[string]any{"k": "v"} result, err := ffm.GetObjectValue(ctx, "some-flag", def, evalCtx("user123")) - assert.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - assert.Equal(t, def, result) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.Eq[any](t, def, result) }) } @@ -444,6 +444,6 @@ func TestFeatureFlagManager_Close(T *testing.T) { ffm := buildTestManager(t, cbnoop.NewCircuitBreaker()) err := ffm.Close() - assert.NoError(t, err) + test.NoError(t, err) }) } diff --git a/go.mod b/go.mod index a9f6108..2798d7c 100644 --- a/go.mod +++ b/go.mod @@ -64,7 +64,6 @@ require ( github.com/sendgrid/sendgrid-go v3.16.1+incompatible github.com/shoenig/test v1.12.2 github.com/sideshow/apns2 v0.25.0 - github.com/stretchr/testify v1.11.1 github.com/stripe/stripe-go/v75 v75.11.0 github.com/testcontainers/testcontainers-go v0.41.0 github.com/testcontainers/testcontainers-go/modules/elasticsearch v0.41.0 @@ -119,6 +118,7 @@ require ( github.com/ncruces/go-strftime v1.0.0 // indirect github.com/pierrec/lz4/v4 v4.1.18 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/stretchr/testify v1.11.1 // indirect github.com/x448/float16 v0.8.4 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect golang.org/x/term v0.41.0 // indirect diff --git a/healthcheck/checkers_test.go b/healthcheck/checkers_test.go index 1674bb7..d28c07e 100644 --- a/healthcheck/checkers_test.go +++ b/healthcheck/checkers_test.go @@ -2,12 +2,15 @@ package healthcheck import ( "context" + "errors" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) +var errStub = errors.New("stub error") + func TestNewDatabaseChecker(T *testing.T) { T.Parallel() @@ -18,9 +21,9 @@ func TestNewDatabaseChecker(T *testing.T) { checker := NewDatabaseChecker("postgres", client) ctx := context.Background() - assert.Equal(t, "postgres", checker.Name()) + test.EqOp(t, "postgres", checker.Name()) err := checker.Check(ctx) - require.NoError(t, err) + must.NoError(t, err) }) T.Run("not ready", func(t *testing.T) { @@ -31,7 +34,7 @@ func TestNewDatabaseChecker(T *testing.T) { ctx := context.Background() err := checker.Check(ctx) - require.Error(t, err) + must.Error(t, err) }) T.Run("nil client", func(t *testing.T) { @@ -41,8 +44,8 @@ func TestNewDatabaseChecker(T *testing.T) { ctx := context.Background() err := checker.Check(ctx) - require.Error(t, err) - assert.Contains(t, err.Error(), "nil") + must.Error(t, err) + test.StrContains(t, err.Error(), "nil") }) } @@ -64,20 +67,20 @@ func TestNewCacheChecker(T *testing.T) { checker := NewCacheChecker("redis", client) ctx := context.Background() - assert.Equal(t, "redis", checker.Name()) + test.EqOp(t, "redis", checker.Name()) err := checker.Check(ctx) - require.NoError(t, err) + must.NoError(t, err) }) T.Run("not ready", func(t *testing.T) { t.Parallel() - client := &mockCacheClient{err: assert.AnError} + client := &mockCacheClient{err: errStub} checker := NewCacheChecker("redis", client) ctx := context.Background() err := checker.Check(ctx) - require.Error(t, err) + must.Error(t, err) }) T.Run("nil client", func(t *testing.T) { @@ -87,8 +90,8 @@ func TestNewCacheChecker(T *testing.T) { ctx := context.Background() err := checker.Check(ctx) - require.Error(t, err) - assert.Contains(t, err.Error(), "nil") + must.Error(t, err) + test.StrContains(t, err.Error(), "nil") }) } @@ -110,20 +113,20 @@ func TestNewMessageQueueChecker(T *testing.T) { checker := NewMessageQueueChecker("redis", client) ctx := context.Background() - assert.Equal(t, "redis", checker.Name()) + test.EqOp(t, "redis", checker.Name()) err := checker.Check(ctx) - require.NoError(t, err) + must.NoError(t, err) }) T.Run("not ready", func(t *testing.T) { t.Parallel() - client := &mockMQClient{err: assert.AnError} + client := &mockMQClient{err: errStub} checker := NewMessageQueueChecker("redis", client) ctx := context.Background() err := checker.Check(ctx) - require.Error(t, err) + must.Error(t, err) }) T.Run("nil client", func(t *testing.T) { @@ -133,8 +136,8 @@ func TestNewMessageQueueChecker(T *testing.T) { ctx := context.Background() err := checker.Check(ctx) - require.Error(t, err) - assert.Contains(t, err.Error(), "nil") + must.Error(t, err) + test.StrContains(t, err.Error(), "nil") }) } diff --git a/healthcheck/healthcheck_test.go b/healthcheck/healthcheck_test.go index 0e41955..90801da 100644 --- a/healthcheck/healthcheck_test.go +++ b/healthcheck/healthcheck_test.go @@ -5,8 +5,8 @@ import ( "errors" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) type mockChecker struct { @@ -36,9 +36,9 @@ func TestRegistry_CheckAll(T *testing.T) { result := reg.CheckAll(ctx) - require.NotNil(t, result) - assert.Equal(t, StatusUp, result.Status) - assert.Empty(t, result.Components) + must.NotNil(t, result) + test.EqOp(t, StatusUp, result.Status) + test.MapEmpty(t, result.Components) }) T.Run("all checkers up", func(t *testing.T) { @@ -51,11 +51,11 @@ func TestRegistry_CheckAll(T *testing.T) { result := reg.CheckAll(ctx) - require.NotNil(t, result) - assert.Equal(t, StatusUp, result.Status) - assert.Len(t, result.Components, 2) - assert.Equal(t, ComponentResult{Status: StatusUp}, result.Components["a"]) - assert.Equal(t, ComponentResult{Status: StatusUp}, result.Components["b"]) + must.NotNil(t, result) + test.EqOp(t, StatusUp, result.Status) + test.MapLen(t, 2, result.Components) + test.EqOp(t, ComponentResult{Status: StatusUp}, result.Components["a"]) + test.EqOp(t, ComponentResult{Status: StatusUp}, result.Components["b"]) }) T.Run("one checker down", func(t *testing.T) { @@ -73,11 +73,11 @@ func TestRegistry_CheckAll(T *testing.T) { result := reg.CheckAll(ctx) - require.NotNil(t, result) - assert.Equal(t, StatusDown, result.Status) - assert.Len(t, result.Components, 2) - assert.Equal(t, ComponentResult{Status: StatusUp}, result.Components["up"]) - assert.Equal(t, ComponentResult{Status: StatusDown, Message: "connection refused"}, result.Components["down"]) + must.NotNil(t, result) + test.EqOp(t, StatusDown, result.Status) + test.MapLen(t, 2, result.Components) + test.EqOp(t, ComponentResult{Status: StatusUp}, result.Components["up"]) + test.EqOp(t, ComponentResult{Status: StatusDown, Message: "connection refused"}, result.Components["down"]) }) T.Run("ignores nil checker", func(t *testing.T) { @@ -90,8 +90,8 @@ func TestRegistry_CheckAll(T *testing.T) { result := reg.CheckAll(ctx) - require.NotNil(t, result) - assert.Equal(t, StatusUp, result.Status) - assert.Len(t, result.Components, 1) + must.NotNil(t, result) + test.EqOp(t, StatusUp, result.Status) + test.MapLen(t, 1, result.Components) }) } diff --git a/httpclient/client_test.go b/httpclient/client_test.go index 9bd3204..2deeacb 100644 --- a/httpclient/client_test.go +++ b/httpclient/client_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestConfig_BuildClient(T *testing.T) { @@ -22,9 +22,9 @@ func TestConfig_BuildClient(T *testing.T) { cfg.EnsureDefaults() client := cfg.BuildClient() - require.NotNil(t, client) - assert.Equal(t, 2*time.Second, client.Timeout) - assert.NotNil(t, client.Transport) + must.NotNil(t, client) + test.EqOp(t, 2*time.Second, client.Timeout) + test.NotNil(t, client.Transport) }) T.Run("with tracing disabled", func(t *testing.T) { @@ -37,9 +37,9 @@ func TestConfig_BuildClient(T *testing.T) { cfg.EnsureDefaults() client := cfg.BuildClient() - require.NotNil(t, client) - assert.Equal(t, 3*time.Second, client.Timeout) - assert.NotNil(t, client.Transport) + must.NotNil(t, client) + test.EqOp(t, 3*time.Second, client.Timeout) + test.NotNil(t, client.Transport) }) T.Run("applies MaxIdleConns and MaxIdleConnsPerHost", func(t *testing.T) { @@ -54,12 +54,12 @@ func TestConfig_BuildClient(T *testing.T) { cfg.EnsureDefaults() client := cfg.BuildClient() - require.NotNil(t, client) + must.NotNil(t, client) transport, ok := client.Transport.(*http.Transport) - require.True(t, ok) - assert.Equal(t, 42, transport.MaxIdleConns) - assert.Equal(t, 21, transport.MaxIdleConnsPerHost) + must.True(t, ok) + test.EqOp(t, 42, transport.MaxIdleConns) + test.EqOp(t, 21, transport.MaxIdleConnsPerHost) }) } @@ -70,8 +70,8 @@ func TestProvideHTTPClient(T *testing.T) { t.Parallel() client := ProvideHTTPClient(nil) - require.NotNil(t, client) - assert.Equal(t, defaultTimeout, client.Timeout) + must.NotNil(t, client) + test.EqOp(t, defaultTimeout, client.Timeout) }) T.Run("with config uses config values", func(t *testing.T) { @@ -81,7 +81,7 @@ func TestProvideHTTPClient(T *testing.T) { Timeout: 7 * time.Second, } client := ProvideHTTPClient(cfg) - require.NotNil(t, client) - assert.Equal(t, 7*time.Second, client.Timeout) + must.NotNil(t, client) + test.EqOp(t, 7*time.Second, client.Timeout) }) } diff --git a/httpclient/config_test.go b/httpclient/config_test.go index 63b9125..4454e51 100644 --- a/httpclient/config_test.go +++ b/httpclient/config_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestConfig_EnsureDefaults(T *testing.T) { @@ -18,9 +18,9 @@ func TestConfig_EnsureDefaults(T *testing.T) { cfg := &Config{} cfg.EnsureDefaults() - assert.Equal(t, defaultTimeout, cfg.Timeout) - assert.Equal(t, defaultMaxIdleConns, cfg.MaxIdleConns) - assert.Equal(t, defaultMaxIdleConnsPerHost, cfg.MaxIdleConnsPerHost) + test.EqOp(t, defaultTimeout, cfg.Timeout) + test.EqOp(t, defaultMaxIdleConns, cfg.MaxIdleConns) + test.EqOp(t, defaultMaxIdleConnsPerHost, cfg.MaxIdleConnsPerHost) }) T.Run("preserves non-zero values", func(t *testing.T) { @@ -33,9 +33,9 @@ func TestConfig_EnsureDefaults(T *testing.T) { } cfg.EnsureDefaults() - assert.Equal(t, 5*time.Second, cfg.Timeout) - assert.Equal(t, 50, cfg.MaxIdleConns) - assert.Equal(t, 25, cfg.MaxIdleConnsPerHost) + test.EqOp(t, 5*time.Second, cfg.Timeout) + test.EqOp(t, 50, cfg.MaxIdleConns) + test.EqOp(t, 25, cfg.MaxIdleConnsPerHost) }) } @@ -53,7 +53,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { } err := cfg.ValidateWithContext(ctx) - require.NoError(t, err) + must.NoError(t, err) }) T.Run("invalid timeout", func(t *testing.T) { @@ -67,7 +67,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { } err := cfg.ValidateWithContext(ctx) - require.Error(t, err) + must.Error(t, err) }) T.Run("invalid MaxIdleConns", func(t *testing.T) { @@ -81,7 +81,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { } err := cfg.ValidateWithContext(ctx) - require.Error(t, err) + must.Error(t, err) }) T.Run("invalid MaxIdleConnsPerHost", func(t *testing.T) { @@ -95,6 +95,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { } err := cfg.ValidateWithContext(ctx) - require.Error(t, err) + must.Error(t, err) }) } diff --git a/identifiers/identifiers_test.go b/identifiers/identifiers_test.go index 535d6f5..69c45c8 100644 --- a/identifiers/identifiers_test.go +++ b/identifiers/identifiers_test.go @@ -4,7 +4,7 @@ import ( "testing" "github.com/rs/xid" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestNew(T *testing.T) { @@ -14,7 +14,7 @@ func TestNew(T *testing.T) { t.Parallel() actual := New() - assert.NotEmpty(t, actual) + test.NotEq(t, "", actual) }) } @@ -25,6 +25,6 @@ func TestValidate(T *testing.T) { t.Parallel() actual := Validate(xid.New().String()) - assert.NoError(t, actual) + test.NoError(t, actual) }) } diff --git a/llm/anthropic/anthropic_test.go b/llm/anthropic/anthropic_test.go index 3b8b363..25feafb 100644 --- a/llm/anthropic/anthropic_test.go +++ b/llm/anthropic/anthropic_test.go @@ -13,7 +13,7 @@ import ( mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/shoenig/test" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -40,16 +40,16 @@ func TestNewProvider(T *testing.T) { t.Parallel() provider, err := NewProvider(nil, nil, nil, nil) - require.Error(t, err) - require.Nil(t, provider) + must.Error(t, err) + must.Nil(t, provider) }) T.Run("standard", func(t *testing.T) { t.Parallel() provider, err := NewProvider(&Config{APIKey: "test-key"}, nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, provider) + must.NoError(t, err) + must.NotNil(t, provider) }) T.Run("with base URL", func(t *testing.T) { @@ -60,8 +60,8 @@ func TestNewProvider(T *testing.T) { BaseURL: "https://custom.example.com", DefaultModel: "claude-sonnet-4", }, nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, provider) + must.NoError(t, err) + must.NotNil(t, provider) }) T.Run("with timeout", func(t *testing.T) { @@ -71,8 +71,8 @@ func TestNewProvider(T *testing.T) { APIKey: "test-key", Timeout: 5 * time.Second, }, nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, provider) + must.NoError(t, err) + must.NotNil(t, provider) }) T.Run("with error creating request counter", func(t *testing.T) { @@ -86,8 +86,8 @@ func TestNewProvider(T *testing.T) { } provider, err := NewProvider(&Config{APIKey: "test-key"}, nil, nil, mp) - require.Error(t, err) - require.Nil(t, provider) + must.Error(t, err) + must.Nil(t, provider) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -109,8 +109,8 @@ func TestNewProvider(T *testing.T) { } provider, err := NewProvider(&Config{APIKey: "test-key"}, nil, nil, mp) - require.Error(t, err) - require.Nil(t, provider) + must.Error(t, err) + must.Nil(t, provider) test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) @@ -120,7 +120,7 @@ func TestNewProvider(T *testing.T) { noopMP := metrics.NewNoopMetricsProvider() h, histErr := noopMP.NewFloat64Histogram("test") - require.NoError(t, histErr) + must.NoError(t, histErr) mp := &mockmetrics.ProviderMock{ NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { @@ -133,8 +133,8 @@ func TestNewProvider(T *testing.T) { } provider, err := NewProvider(&Config{APIKey: "test-key"}, nil, nil, mp) - require.Error(t, err) - require.Nil(t, provider) + must.Error(t, err) + must.Nil(t, provider) test.SliceLen(t, 2, mp.NewInt64CounterCalls()) test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) @@ -148,10 +148,10 @@ func TestAnthropicProvider_Completion(T *testing.T) { t.Parallel() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, "/v1/messages", r.URL.Path) - require.Equal(t, http.MethodPost, r.Method) + must.EqOp(t, "/v1/messages", r.URL.Path) + must.EqOp(t, http.MethodPost, r.Method) w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(anthropicMessageResponse("Hello from Claude mock!"))) + must.NoError(t, json.NewEncoder(w).Encode(anthropicMessageResponse("Hello from Claude mock!"))) })) t.Cleanup(ts.Close) @@ -159,8 +159,8 @@ func TestAnthropicProvider_Completion(T *testing.T) { APIKey: "test-key", BaseURL: ts.URL, }, nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, provider) + must.NoError(t, err) + must.NotNil(t, provider) ctx := t.Context() result, err := provider.Completion(ctx, llm.CompletionParams{ @@ -169,9 +169,9 @@ func TestAnthropicProvider_Completion(T *testing.T) { {Role: "user", Content: "Hello"}, }, }) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "Hello from Claude mock!", result.Content) + must.NoError(t, err) + must.NotNil(t, result) + must.EqOp(t, "Hello from Claude mock!", result.Content) }) T.Run("uses default model when not specified", func(t *testing.T) { @@ -179,7 +179,7 @@ func TestAnthropicProvider_Completion(T *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(anthropicMessageResponse("Hi there!"))) + must.NoError(t, json.NewEncoder(w).Encode(anthropicMessageResponse("Hi there!"))) })) t.Cleanup(ts.Close) @@ -188,14 +188,14 @@ func TestAnthropicProvider_Completion(T *testing.T) { BaseURL: ts.URL, DefaultModel: "claude-sonnet-4", }, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := provider.Completion(ctx, llm.CompletionParams{ Messages: []llm.Message{{Role: "user", Content: "Hi"}}, }) - require.NoError(t, err) - require.Equal(t, "Hi there!", result.Content) + must.NoError(t, err) + must.EqOp(t, "Hi there!", result.Content) }) T.Run("with API error", func(t *testing.T) { @@ -211,14 +211,14 @@ func TestAnthropicProvider_Completion(T *testing.T) { APIKey: "test-key", BaseURL: ts.URL, }, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := provider.Completion(ctx, llm.CompletionParams{ Model: "claude-sonnet-4-20250514", Messages: []llm.Message{{Role: "user", Content: "Hi"}}, }) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) } diff --git a/llm/anthropic/config_test.go b/llm/anthropic/config_test.go index a5f85cf..d0ecc3f 100644 --- a/llm/anthropic/config_test.go +++ b/llm/anthropic/config_test.go @@ -3,7 +3,7 @@ package anthropic import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -16,7 +16,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { APIKey: "test-key", } - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("missing API key", func(t *testing.T) { @@ -24,6 +24,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { cfg := &Config{} - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) } diff --git a/llm/config/config_test.go b/llm/config/config_test.go index 82e886d..5a1baef 100644 --- a/llm/config/config_test.go +++ b/llm/config/config_test.go @@ -12,8 +12,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -31,7 +30,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { }, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("anthropic provider", func(t *testing.T) { @@ -45,7 +44,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { }, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("empty provider is valid", func(t *testing.T) { @@ -54,7 +53,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { ctx := t.Context() cfg := &Config{} - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("unknown provider is invalid", func(t *testing.T) { @@ -65,7 +64,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: "nonsense", } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("openai provider missing config", func(t *testing.T) { @@ -76,7 +75,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderOpenAI, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("anthropic provider missing config", func(t *testing.T) { @@ -87,7 +86,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderAnthropic, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } @@ -101,8 +100,8 @@ func TestConfig_ProvideLLMProvider(T *testing.T) { cfg := &Config{Provider: ""} provider, err := cfg.ProvideLLMProvider(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil) - require.NoError(t, err) - require.NotNil(t, provider) + must.NoError(t, err) + must.NotNil(t, provider) }) T.Run("unknown provider falls back to noop", func(t *testing.T) { @@ -112,8 +111,8 @@ func TestConfig_ProvideLLMProvider(T *testing.T) { cfg := &Config{Provider: "unknown"} provider, err := cfg.ProvideLLMProvider(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil) - require.NoError(t, err) - require.NotNil(t, provider) + must.NoError(t, err) + must.NotNil(t, provider) }) T.Run("openai provider", func(t *testing.T) { @@ -128,8 +127,8 @@ func TestConfig_ProvideLLMProvider(T *testing.T) { } provider, err := cfg.ProvideLLMProvider(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil) - require.NoError(t, err) - require.NotNil(t, provider) + must.NoError(t, err) + must.NotNil(t, provider) }) T.Run("anthropic provider", func(t *testing.T) { @@ -144,8 +143,8 @@ func TestConfig_ProvideLLMProvider(T *testing.T) { } provider, err := cfg.ProvideLLMProvider(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil) - require.NoError(t, err) - require.NotNil(t, provider) + must.NoError(t, err) + must.NotNil(t, provider) }) T.Run("openai provider with metrics error", func(t *testing.T) { @@ -166,8 +165,8 @@ func TestConfig_ProvideLLMProvider(T *testing.T) { } provider, err := cfg.ProvideLLMProvider(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp) - assert.Nil(t, provider) - assert.Error(t, err) + test.Nil(t, provider) + test.Error(t, err) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -190,8 +189,8 @@ func TestConfig_ProvideLLMProvider(T *testing.T) { } provider, err := cfg.ProvideLLMProvider(ctx, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp) - assert.Nil(t, provider) - assert.Error(t, err) + test.Nil(t, provider) + test.Error(t, err) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -206,7 +205,7 @@ func TestProvideLLMProvider(T *testing.T) { cfg := &Config{} provider, err := ProvideLLMProvider(cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil) - require.NoError(t, err) - assert.NotNil(t, provider) + must.NoError(t, err) + test.NotNil(t, provider) }) } diff --git a/llm/config/do_test.go b/llm/config/do_test.go index 3b0d363..45b0d93 100644 --- a/llm/config/do_test.go +++ b/llm/config/do_test.go @@ -9,8 +9,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterLLMProvider(T *testing.T) { @@ -28,7 +28,7 @@ func TestRegisterLLMProvider(T *testing.T) { RegisterLLMProvider(i) provider, err := do.Invoke[llm.Provider](i) - require.NoError(t, err) - assert.NotNil(t, provider) + must.NoError(t, err) + test.NotNil(t, provider) }) } diff --git a/llm/llm_test.go b/llm/llm_test.go index 32bd9e7..96ff68a 100644 --- a/llm/llm_test.go +++ b/llm/llm_test.go @@ -3,7 +3,7 @@ package llm import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestNoopProvider_Completion(T *testing.T) { @@ -22,7 +22,7 @@ func TestNoopProvider_Completion(T *testing.T) { }, }) - assert.NoError(t, err) - assert.NotNil(t, result) + test.NoError(t, err) + test.NotNil(t, result) }) } diff --git a/llm/openai/config_test.go b/llm/openai/config_test.go index af71048..ca91e4b 100644 --- a/llm/openai/config_test.go +++ b/llm/openai/config_test.go @@ -3,7 +3,7 @@ package openai import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -16,7 +16,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { APIKey: "test-key", } - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("missing API key", func(t *testing.T) { @@ -24,6 +24,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { cfg := &Config{} - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) } diff --git a/llm/openai/openai_test.go b/llm/openai/openai_test.go index 1050768..3037185 100644 --- a/llm/openai/openai_test.go +++ b/llm/openai/openai_test.go @@ -13,7 +13,7 @@ import ( mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/shoenig/test" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -24,16 +24,16 @@ func TestNewProvider(T *testing.T) { t.Parallel() provider, err := NewProvider(nil, nil, nil, nil) - require.Error(t, err) - require.Nil(t, provider) + must.Error(t, err) + must.Nil(t, provider) }) T.Run("standard", func(t *testing.T) { t.Parallel() provider, err := NewProvider(&Config{APIKey: "test-key"}, nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, provider) + must.NoError(t, err) + must.NotNil(t, provider) }) T.Run("with base URL and timeout", func(t *testing.T) { @@ -44,8 +44,8 @@ func TestNewProvider(T *testing.T) { BaseURL: "https://custom.example.com/v1", DefaultModel: "gpt-4o", }, nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, provider) + must.NoError(t, err) + must.NotNil(t, provider) }) T.Run("with timeout", func(t *testing.T) { @@ -55,8 +55,8 @@ func TestNewProvider(T *testing.T) { APIKey: "test-key", Timeout: 5 * time.Second, }, nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, provider) + must.NoError(t, err) + must.NotNil(t, provider) }) T.Run("with error creating request counter", func(t *testing.T) { @@ -70,8 +70,8 @@ func TestNewProvider(T *testing.T) { } provider, err := NewProvider(&Config{APIKey: "test-key"}, nil, nil, mp) - require.Error(t, err) - require.Nil(t, provider) + must.Error(t, err) + must.Nil(t, provider) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -93,8 +93,8 @@ func TestNewProvider(T *testing.T) { } provider, err := NewProvider(&Config{APIKey: "test-key"}, nil, nil, mp) - require.Error(t, err) - require.Nil(t, provider) + must.Error(t, err) + must.Nil(t, provider) test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) @@ -104,7 +104,7 @@ func TestNewProvider(T *testing.T) { noopMP := metrics.NewNoopMetricsProvider() h, histErr := noopMP.NewFloat64Histogram("test") - require.NoError(t, histErr) + must.NoError(t, histErr) mp := &mockmetrics.ProviderMock{ NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { @@ -117,8 +117,8 @@ func TestNewProvider(T *testing.T) { } provider, err := NewProvider(&Config{APIKey: "test-key"}, nil, nil, mp) - require.Error(t, err) - require.Nil(t, provider) + must.Error(t, err) + must.Nil(t, provider) test.SliceLen(t, 2, mp.NewInt64CounterCalls()) test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) @@ -154,10 +154,10 @@ func TestOpenAIProvider_Completion(T *testing.T) { t.Parallel() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, "/v1/chat/completions", r.URL.Path) - require.Equal(t, http.MethodPost, r.Method) + must.EqOp(t, "/v1/chat/completions", r.URL.Path) + must.EqOp(t, http.MethodPost, r.Method) w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(openAIChatCompletion)) + must.NoError(t, json.NewEncoder(w).Encode(openAIChatCompletion)) })) t.Cleanup(ts.Close) @@ -165,8 +165,8 @@ func TestOpenAIProvider_Completion(T *testing.T) { APIKey: "test-key", BaseURL: ts.URL + "/v1", }, nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, provider) + must.NoError(t, err) + must.NotNil(t, provider) ctx := t.Context() result, err := provider.Completion(ctx, llm.CompletionParams{ @@ -175,9 +175,9 @@ func TestOpenAIProvider_Completion(T *testing.T) { {Role: "user", Content: "Hello"}, }, }) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "Hello from mock!", result.Content) + must.NoError(t, err) + must.NotNil(t, result) + must.EqOp(t, "Hello from mock!", result.Content) }) T.Run("uses default model when not specified", func(t *testing.T) { @@ -185,7 +185,7 @@ func TestOpenAIProvider_Completion(T *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(openAIChatCompletion)) + must.NoError(t, json.NewEncoder(w).Encode(openAIChatCompletion)) })) t.Cleanup(ts.Close) @@ -194,14 +194,14 @@ func TestOpenAIProvider_Completion(T *testing.T) { BaseURL: ts.URL + "/v1", DefaultModel: "gpt-4o", }, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := provider.Completion(ctx, llm.CompletionParams{ Messages: []llm.Message{{Role: "user", Content: "Hi"}}, }) - require.NoError(t, err) - require.Equal(t, "Hello from mock!", result.Content) + must.NoError(t, err) + must.EqOp(t, "Hello from mock!", result.Content) }) T.Run("with API error", func(t *testing.T) { @@ -217,14 +217,14 @@ func TestOpenAIProvider_Completion(T *testing.T) { APIKey: "test-key", BaseURL: ts.URL + "/v1", }, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) ctx := t.Context() result, err := provider.Completion(ctx, llm.CompletionParams{ Model: "gpt-4o-mini", Messages: []llm.Message{{Role: "user", Content: "Hi"}}, }) - require.Error(t, err) - require.Nil(t, result) + must.Error(t, err) + must.Nil(t, result) }) } diff --git a/messagequeue/config/config_test.go b/messagequeue/config/config_test.go index 0fed8a3..d10d358 100644 --- a/messagequeue/config/config_test.go +++ b/messagequeue/config/config_test.go @@ -9,8 +9,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func Test_cleanString(T *testing.T) { @@ -19,7 +19,7 @@ func Test_cleanString(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotEmpty(t, cleanString(t.Name())) + test.NotEq(t, "", cleanString(t.Name())) }) } @@ -38,7 +38,7 @@ func TestQueuesConfig_ValidateWithContext(T *testing.T) { WebhookExecutionRequestsTopicName: "webhook-execution-requests", } - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("missing fields", func(t *testing.T) { @@ -46,7 +46,7 @@ func TestQueuesConfig_ValidateWithContext(T *testing.T) { cfg := &QueuesConfig{} - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) } @@ -57,8 +57,8 @@ func TestProvideConsumerProvider(T *testing.T) { t.Parallel() p, err := ProvideConsumerProvider(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, nil) - assert.Nil(t, p) - assert.ErrorIs(t, err, ErrNilConfig) + test.Nil(t, p) + test.ErrorIs(t, err, ErrNilConfig) }) T.Run("with redis provider", func(t *testing.T) { @@ -71,8 +71,8 @@ func TestProvideConsumerProvider(T *testing.T) { } p, err := ProvideConsumerProvider(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cfg) - assert.NoError(t, err) - assert.NotNil(t, p) + test.NoError(t, err) + test.NotNil(t, p) }) T.Run("with SQS provider", func(t *testing.T) { @@ -86,8 +86,8 @@ func TestProvideConsumerProvider(T *testing.T) { } p, err := ProvideConsumerProvider(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cfg) - assert.NoError(t, err) - assert.NotNil(t, p) + test.NoError(t, err) + test.NotNil(t, p) }) T.Run("with kafka provider", func(t *testing.T) { @@ -101,8 +101,8 @@ func TestProvideConsumerProvider(T *testing.T) { } p, err := ProvideConsumerProvider(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cfg) - assert.NoError(t, err) - assert.NotNil(t, p) + test.NoError(t, err) + test.NotNil(t, p) }) T.Run("with pubsub provider and empty project ID", func(t *testing.T) { @@ -116,16 +116,16 @@ func TestProvideConsumerProvider(T *testing.T) { } p, err := ProvideConsumerProvider(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cfg) - assert.Nil(t, p) - assert.Error(t, err) + test.Nil(t, p) + test.Error(t, err) }) T.Run("with unknown provider falls back to noop", func(t *testing.T) { t.Parallel() p, err := ProvideConsumerProvider(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, &Config{}) - assert.NoError(t, err) - assert.NotNil(t, p) + test.NoError(t, err) + test.NotNil(t, p) }) } @@ -142,8 +142,8 @@ func TestProvideConsumerProvider_PubSubEmulator(t *testing.T) { } p, err := ProvideConsumerProvider(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cfg) - require.NoError(t, err) - assert.NotNil(t, p) + must.NoError(t, err) + test.NotNil(t, p) } func TestProvidePublisherProvider(T *testing.T) { @@ -153,8 +153,8 @@ func TestProvidePublisherProvider(T *testing.T) { t.Parallel() p, err := ProvidePublisherProvider(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, nil) - assert.Nil(t, p) - assert.ErrorIs(t, err, ErrNilConfig) + test.Nil(t, p) + test.ErrorIs(t, err, ErrNilConfig) }) T.Run("with redis provider", func(t *testing.T) { @@ -167,8 +167,8 @@ func TestProvidePublisherProvider(T *testing.T) { } p, err := ProvidePublisherProvider(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cfg) - assert.NoError(t, err) - assert.NotNil(t, p) + test.NoError(t, err) + test.NotNil(t, p) }) T.Run("with SQS provider", func(t *testing.T) { @@ -182,8 +182,8 @@ func TestProvidePublisherProvider(T *testing.T) { } p, err := ProvidePublisherProvider(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cfg) - assert.NoError(t, err) - assert.NotNil(t, p) + test.NoError(t, err) + test.NotNil(t, p) }) T.Run("with kafka provider", func(t *testing.T) { @@ -197,8 +197,8 @@ func TestProvidePublisherProvider(T *testing.T) { } p, err := ProvidePublisherProvider(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cfg) - assert.NoError(t, err) - assert.NotNil(t, p) + test.NoError(t, err) + test.NotNil(t, p) }) T.Run("with pubsub provider and empty project ID", func(t *testing.T) { @@ -212,16 +212,16 @@ func TestProvidePublisherProvider(T *testing.T) { } p, err := ProvidePublisherProvider(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cfg) - assert.Nil(t, p) - assert.Error(t, err) + test.Nil(t, p) + test.Error(t, err) }) T.Run("with unknown provider falls back to noop", func(t *testing.T) { t.Parallel() p, err := ProvidePublisherProvider(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, &Config{}) - assert.NoError(t, err) - assert.NotNil(t, p) + test.NoError(t, err) + test.NotNil(t, p) }) } @@ -238,6 +238,6 @@ func TestProvidePublisherProvider_PubSubEmulator(t *testing.T) { } p, err := ProvidePublisherProvider(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, cfg) - require.NoError(t, err) - assert.NotNil(t, p) + must.NoError(t, err) + test.NotNil(t, p) } diff --git a/messagequeue/config/do_test.go b/messagequeue/config/do_test.go index c8894f5..a4ca7c0 100644 --- a/messagequeue/config/do_test.go +++ b/messagequeue/config/do_test.go @@ -9,8 +9,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterMessageQueue(T *testing.T) { @@ -29,11 +29,11 @@ func TestRegisterMessageQueue(T *testing.T) { RegisterMessageQueue(i) consumer, err := do.Invoke[messagequeue.ConsumerProvider](i) - require.NoError(t, err) - assert.NotNil(t, consumer) + must.NoError(t, err) + test.NotNil(t, consumer) publisher, err := do.Invoke[messagequeue.PublisherProvider](i) - require.NoError(t, err) - assert.NotNil(t, publisher) + must.NoError(t, err) + test.NotNil(t, publisher) }) } diff --git a/messagequeue/kafka/config_test.go b/messagequeue/kafka/config_test.go index b0bf069..1a06af4 100644 --- a/messagequeue/kafka/config_test.go +++ b/messagequeue/kafka/config_test.go @@ -3,7 +3,7 @@ package kafka import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -19,7 +19,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { GroupID: "test-group", } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with empty brokers", func(t *testing.T) { @@ -32,7 +32,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { GroupID: "test-group", } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with nil brokers", func(t *testing.T) { @@ -44,6 +44,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { GroupID: "test-group", } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } diff --git a/messagequeue/kafka/consumer_test.go b/messagequeue/kafka/consumer_test.go index 4294182..a01e6d8 100644 --- a/messagequeue/kafka/consumer_test.go +++ b/messagequeue/kafka/consumer_test.go @@ -12,8 +12,7 @@ import ( "github.com/segmentio/kafka-go" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -163,7 +162,7 @@ func Test_kafkaConsumer_Consume(T *testing.T) { select { case receivedErr := <-errs: - assert.Equal(t, fetchErr, receivedErr) + test.ErrorIs(t, receivedErr, fetchErr) default: t.Error("expected an error on the errors channel") } @@ -215,8 +214,8 @@ func Test_kafkaConsumer_Consume(T *testing.T) { return kafka.Message{}, context.Canceled }, commitMessagesFunc: func(_ context.Context, msgs ...kafka.Message) error { - require.Len(t, msgs, 1) - assert.Equal(t, msg, msgs[0]) + must.SliceLen(t, 1, msgs) + test.Eq(t, msg, msgs[0]) return nil }, } @@ -229,7 +228,7 @@ func Test_kafkaConsumer_Consume(T *testing.T) { consumedCounter: metrics.Int64CounterForTest(t, t.Name()), handlerFunc: func(_ context.Context, data []byte) error { handlerCalled = true - assert.Equal(t, []byte("test-message"), data) + test.Eq(t, []byte("test-message"), data) cancel() return nil }, @@ -239,8 +238,8 @@ func Test_kafkaConsumer_Consume(T *testing.T) { errs := make(chan error, 10) c.Consume(ctx, stopChan, errs) - assert.True(t, handlerCalled) - assert.Equal(t, 1, reader.commitCalls) + test.True(t, handlerCalled) + test.EqOp(t, 1, reader.commitCalls) }) T.Run("with handler error", func(t *testing.T) { @@ -279,8 +278,8 @@ func Test_kafkaConsumer_Consume(T *testing.T) { c.Consume(ctx, stopChan, errs) receivedErr := <-errs - assert.Error(t, receivedErr) - assert.Equal(t, handlerErr, receivedErr) + test.Error(t, receivedErr) + test.ErrorIs(t, receivedErr, handlerErr) }) T.Run("with handler error and nil errors channel", func(t *testing.T) { @@ -353,7 +352,7 @@ func Test_kafkaConsumer_Consume(T *testing.T) { errs := make(chan error, 10) c.Consume(ctx, stopChan, errs) - assert.Equal(t, 1, reader.commitCalls) + test.EqOp(t, 1, reader.commitCalls) }) } @@ -374,7 +373,7 @@ func TestProvideKafkaConsumerProvider(T *testing.T) { nil, cfg, ) - assert.NotNil(t, actual) + test.NotNil(t, actual) }) } @@ -397,13 +396,13 @@ func Test_consumerProvider_ProvideConsumer(T *testing.T) { nil, cfg, ) - require.NotNil(t, provider) + must.NotNil(t, provider) hf := func(context.Context, []byte) error { return nil } actual, err := provider.ProvideConsumer(ctx, t.Name(), hf) - assert.NoError(t, err) - assert.NotNil(t, actual) + test.NoError(t, err) + test.NotNil(t, actual) }) T.Run("with empty topic", func(t *testing.T) { @@ -422,12 +421,12 @@ func Test_consumerProvider_ProvideConsumer(T *testing.T) { nil, cfg, ) - require.NotNil(t, provider) + must.NotNil(t, provider) actual, err := provider.ProvideConsumer(ctx, "", nil) - assert.Error(t, err) - assert.ErrorIs(t, err, ErrEmptyInputProvided) - assert.Nil(t, actual) + test.Error(t, err) + test.ErrorIs(t, err, ErrEmptyInputProvided) + test.Nil(t, actual) }) T.Run("with error creating consumed counter", func(t *testing.T) { @@ -452,13 +451,13 @@ func Test_consumerProvider_ProvideConsumer(T *testing.T) { mp, cfg, ) - require.NotNil(t, provider) + must.NotNil(t, provider) hf := func(context.Context, []byte) error { return nil } actual, err := provider.ProvideConsumer(ctx, t.Name(), hf) - assert.Error(t, err) - assert.Nil(t, actual) + test.Error(t, err) + test.Nil(t, actual) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -479,18 +478,18 @@ func Test_consumerProvider_ProvideConsumer(T *testing.T) { nil, cfg, ) - require.NotNil(t, provider) + must.NotNil(t, provider) hf := func(context.Context, []byte) error { return nil } first, err := provider.ProvideConsumer(ctx, t.Name(), hf) - assert.NoError(t, err) - assert.NotNil(t, first) + test.NoError(t, err) + test.NotNil(t, first) second, err := provider.ProvideConsumer(ctx, t.Name(), hf) - assert.NoError(t, err) - assert.NotNil(t, second) + test.NoError(t, err) + test.NotNil(t, second) - assert.Equal(t, first, second) + test.True(t, first == second) }) } diff --git a/messagequeue/kafka/publisher_test.go b/messagequeue/kafka/publisher_test.go index 0af4214..ca44c52 100644 --- a/messagequeue/kafka/publisher_test.go +++ b/messagequeue/kafka/publisher_test.go @@ -17,8 +17,7 @@ import ( "github.com/segmentio/kafka-go" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -53,13 +52,13 @@ func buildTestPublisher(t *testing.T) (*kafkaPublisher, *mockKafkaWriter) { mp := metrics.NewNoopMetricsProvider() publishedCounter, err := mp.NewInt64Counter("test_published") - require.NoError(t, err) + must.NoError(t, err) publishErrCounter, err := mp.NewInt64Counter("test_publish_errors") - require.NoError(t, err) + must.NoError(t, err) latencyHist, err := mp.NewFloat64Histogram("test_publish_latency_ms") - require.NoError(t, err) + must.NoError(t, err) pub := &kafkaPublisher{ writer: writer, @@ -85,7 +84,7 @@ func Test_kafkaPublisher_Stop(T *testing.T) { pub.Stop() - assert.Equal(t, 1, writer.closeCalls) + test.EqOp(t, 1, writer.closeCalls) }) T.Run("with close error", func(t *testing.T) { @@ -96,7 +95,7 @@ func Test_kafkaPublisher_Stop(T *testing.T) { pub.Stop() - assert.Equal(t, 1, writer.closeCalls) + test.EqOp(t, 1, writer.closeCalls) }) } @@ -118,9 +117,9 @@ func Test_kafkaPublisher_Publish(T *testing.T) { writer.writeMessagesFunc = func(_ context.Context, _ ...kafka.Message) error { return nil } err := pub.Publish(ctx, inputData) - assert.NoError(t, err) + test.NoError(t, err) - assert.Equal(t, 1, writer.writeCalls) + test.EqOp(t, 1, writer.writeCalls) }) T.Run("with encoding error", func(t *testing.T) { @@ -136,7 +135,7 @@ func Test_kafkaPublisher_Publish(T *testing.T) { } err := pub.Publish(ctx, inputData) - assert.Error(t, err) + test.Error(t, err) }) T.Run("with write error", func(t *testing.T) { @@ -154,9 +153,9 @@ func Test_kafkaPublisher_Publish(T *testing.T) { writer.writeMessagesFunc = func(_ context.Context, _ ...kafka.Message) error { return errors.New("write failed") } err := pub.Publish(ctx, inputData) - assert.Error(t, err) + test.Error(t, err) - assert.Equal(t, 1, writer.writeCalls) + test.EqOp(t, 1, writer.writeCalls) }) T.Run("with mock encoder error", func(t *testing.T) { @@ -173,7 +172,7 @@ func Test_kafkaPublisher_Publish(T *testing.T) { pub.encoder = enc err := pub.Publish(ctx, "something") - assert.Error(t, err) + test.Error(t, err) test.SliceLen(t, 1, enc.EncodeCalls()) }) @@ -198,7 +197,7 @@ func Test_kafkaPublisher_PublishAsync(T *testing.T) { pub.PublishAsync(ctx, inputData) - assert.Equal(t, 1, writer.writeCalls) + test.EqOp(t, 1, writer.writeCalls) }) T.Run("with encoding error", func(t *testing.T) { @@ -232,7 +231,7 @@ func Test_kafkaPublisher_PublishAsync(T *testing.T) { pub.PublishAsync(ctx, inputData) - assert.Equal(t, 1, writer.writeCalls) + test.EqOp(t, 1, writer.writeCalls) }) } @@ -253,7 +252,7 @@ func TestProvideKafkaPublisherProvider(T *testing.T) { nil, cfg, ) - assert.NotNil(t, actual) + test.NotNil(t, actual) }) } @@ -276,11 +275,11 @@ func Test_publisherProvider_ProvidePublisher(T *testing.T) { nil, cfg, ) - require.NotNil(t, provider) + must.NotNil(t, provider) actual, err := provider.ProvidePublisher(ctx, t.Name()) - assert.NoError(t, err) - assert.NotNil(t, actual) + test.NoError(t, err) + test.NotNil(t, actual) }) T.Run("with empty topic", func(t *testing.T) { @@ -299,12 +298,12 @@ func Test_publisherProvider_ProvidePublisher(T *testing.T) { nil, cfg, ) - require.NotNil(t, provider) + must.NotNil(t, provider) actual, err := provider.ProvidePublisher(ctx, "") - assert.Error(t, err) - assert.ErrorIs(t, err, messagequeue.ErrEmptyTopicName) - assert.Nil(t, actual) + test.Error(t, err) + test.ErrorIs(t, err, messagequeue.ErrEmptyTopicName) + test.Nil(t, actual) }) T.Run("with cache hit", func(t *testing.T) { @@ -323,17 +322,17 @@ func Test_publisherProvider_ProvidePublisher(T *testing.T) { nil, cfg, ) - require.NotNil(t, provider) + must.NotNil(t, provider) first, err := provider.ProvidePublisher(ctx, t.Name()) - assert.NoError(t, err) - assert.NotNil(t, first) + test.NoError(t, err) + test.NotNil(t, first) second, err := provider.ProvidePublisher(ctx, t.Name()) - assert.NoError(t, err) - assert.NotNil(t, second) + test.NoError(t, err) + test.NotNil(t, second) - assert.Equal(t, first, second) + test.True(t, first == second) }) T.Run("with error creating published counter", func(t *testing.T) { @@ -358,11 +357,11 @@ func Test_publisherProvider_ProvidePublisher(T *testing.T) { mp, cfg, ) - require.NotNil(t, provider) + must.NotNil(t, provider) actual, err := provider.ProvidePublisher(ctx, t.Name()) - assert.Error(t, err) - assert.Nil(t, actual) + test.Error(t, err) + test.Nil(t, actual) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -391,11 +390,11 @@ func Test_publisherProvider_ProvidePublisher(T *testing.T) { mp, cfg, ) - require.NotNil(t, provider) + must.NotNil(t, provider) actual, err := provider.ProvidePublisher(ctx, t.Name()) - assert.Error(t, err) - assert.Nil(t, actual) + test.Error(t, err) + test.Nil(t, actual) test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) @@ -425,11 +424,11 @@ func Test_publisherProvider_ProvidePublisher(T *testing.T) { mp, cfg, ) - require.NotNil(t, provider) + must.NotNil(t, provider) actual, err := provider.ProvidePublisher(ctx, t.Name()) - assert.Error(t, err) - assert.Nil(t, actual) + test.Error(t, err) + test.Nil(t, actual) test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) }) @@ -454,10 +453,10 @@ func Test_publisherProvider_Ping(T *testing.T) { nil, cfg, ) - require.NotNil(t, provider) + must.NotNil(t, provider) err := provider.Ping(ctx) - assert.Error(t, err) + test.Error(t, err) }) } @@ -480,13 +479,13 @@ func Test_publisherProvider_Close(T *testing.T) { nil, cfg, ) - require.NotNil(t, provider) + must.NotNil(t, provider) _, err := provider.ProvidePublisher(ctx, t.Name()) - require.NoError(t, err) + must.NoError(t, err) pp, ok := provider.(*publisherProvider) - require.True(t, ok) + must.True(t, ok) // Replace cached publisher with one using a mock writer so Close doesn't hit real Kafka mw := &mockKafkaWriter{ @@ -509,7 +508,7 @@ func Test_publisherProvider_Close(T *testing.T) { provider.Close() - assert.Equal(t, 1, mw.closeCalls) + test.EqOp(t, 1, mw.closeCalls) }) T.Run("with empty cache", func(t *testing.T) { @@ -526,7 +525,7 @@ func Test_publisherProvider_Close(T *testing.T) { nil, cfg, ) - require.NotNil(t, provider) + must.NotNil(t, provider) provider.Close() }) diff --git a/messagequeue/noop/noop_test.go b/messagequeue/noop/noop_test.go index a2238b3..81a915b 100644 --- a/messagequeue/noop/noop_test.go +++ b/messagequeue/noop/noop_test.go @@ -4,8 +4,8 @@ import ( "context" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestPublisherProvider_ProvidePublisher(T *testing.T) { @@ -16,8 +16,8 @@ func TestPublisherProvider_ProvidePublisher(T *testing.T) { p := NewPublisherProvider() pub, err := p.ProvidePublisher(context.Background(), "topic") - require.NoError(t, err) - assert.NotNil(t, pub) + must.NoError(t, err) + test.NotNil(t, pub) }) } @@ -40,7 +40,7 @@ func TestPublisherProvider_Ping(T *testing.T) { p := NewPublisherProvider() err := p.Ping(context.Background()) - assert.NoError(t, err) + test.NoError(t, err) }) } @@ -52,7 +52,7 @@ func TestPublisher_Publish(T *testing.T) { p := NewPublisher() err := p.Publish(context.Background(), "data") - assert.NoError(t, err) + test.NoError(t, err) }) } @@ -86,8 +86,8 @@ func TestConsumerProvider_ProvideConsumer(T *testing.T) { p := NewConsumerProvider() c, err := p.ProvideConsumer(context.Background(), "topic", func(_ context.Context, _ []byte) error { return nil }) - require.NoError(t, err) - assert.NotNil(t, c) + must.NoError(t, err) + test.NotNil(t, c) }) } diff --git a/messagequeue/publishers_test.go b/messagequeue/publishers_test.go index fd6e956..6a30d65 100644 --- a/messagequeue/publishers_test.go +++ b/messagequeue/publishers_test.go @@ -3,7 +3,7 @@ package messagequeue import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestErrEmptyTopicName(T *testing.T) { @@ -12,7 +12,7 @@ func TestErrEmptyTopicName(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotNil(t, ErrEmptyTopicName) - assert.Error(t, ErrEmptyTopicName) + test.NotNil(t, ErrEmptyTopicName) + test.Error(t, ErrEmptyTopicName) }) } diff --git a/messagequeue/pubsub/config_test.go b/messagequeue/pubsub/config_test.go index 9b7da3b..380fe59 100644 --- a/messagequeue/pubsub/config_test.go +++ b/messagequeue/pubsub/config_test.go @@ -3,7 +3,7 @@ package pubsub import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -15,6 +15,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { ctx := t.Context() cfg := &Config{} - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) } diff --git a/messagequeue/pubsub/consumer_test.go b/messagequeue/pubsub/consumer_test.go index 81357db..c2ad974 100644 --- a/messagequeue/pubsub/consumer_test.go +++ b/messagequeue/pubsub/consumer_test.go @@ -20,8 +20,8 @@ import ( "cloud.google.com/go/pubsub/v2" "cloud.google.com/go/pubsub/v2/apiv1/pubsubpb" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" tcpubsub "github.com/testcontainers/testcontainers-go/modules/gcloud/pubsub" "go.opentelemetry.io/otel/metric" metricnoop "go.opentelemetry.io/otel/metric/noop" @@ -48,7 +48,7 @@ func buildPubSubTestInfra(t *testing.T) *pubsubTestInfra { ctx := t.Context() randomID, err := random.GenerateHexEncodedString(ctx, 8) - require.NoError(t, err) + must.NoError(t, err) projectID := "project-" + randomID pubsubContainer, err := tcpubsub.Run( @@ -56,16 +56,16 @@ func buildPubSubTestInfra(t *testing.T) *pubsubTestInfra { "gcr.io/google.com/cloudsdktool/cloud-sdk:emulators", tcpubsub.WithProjectID(projectID), ) - require.NoError(t, err) - require.NotNil(t, pubsubContainer) + must.NoError(t, err) + must.NotNil(t, pubsubContainer) conn, err := grpc.NewClient(pubsubContainer.URI(), grpc.WithTransportCredentials(insecure.NewCredentials())) - require.NoError(t, err) - require.NotNil(t, conn) + must.NoError(t, err) + must.NotNil(t, conn) client, err := pubsub.NewClient(ctx, projectID, option.WithGRPCConn(conn)) - require.NoError(t, err) - require.NotNil(t, client) + must.NoError(t, err) + must.NotNil(t, client) return &pubsubTestInfra{ client: client, @@ -86,15 +86,15 @@ func (i *pubsubTestInfra) newTopic(t *testing.T) string { topicName := fmt.Sprintf("projects/%s/topics/topic-%s", i.projectID, identifiers.New()) pubSubTopic, err := i.client.TopicAdminClient.CreateTopic(ctx, &pubsubpb.Topic{Name: topicName}) - require.NoError(t, err) - require.NotNil(t, pubSubTopic) + must.NoError(t, err) + must.NotNil(t, pubSubTopic) subscription, err := i.client.SubscriptionAdminClient.CreateSubscription(ctx, &pubsubpb.Subscription{ Name: subscriptionNameForTopic(pubSubTopic.GetName()), Topic: pubSubTopic.GetName(), }) - require.NoError(t, err) - require.NotNil(t, subscription) + must.NoError(t, err) + must.NotNil(t, subscription) return pubSubTopic.GetName() } @@ -106,14 +106,14 @@ func TestSubscriptionNameForTopic(T *testing.T) { t.Parallel() result := subscriptionNameForTopic("projects/my-project/topics/my-topic") - assert.Equal(t, "projects/my-project/subscriptions/my-topic", result) + test.EqOp(t, "projects/my-project/subscriptions/my-topic", result) }) T.Run("no match", func(t *testing.T) { t.Parallel() result := subscriptionNameForTopic("some-other-string") - assert.Equal(t, "some-other-string", result) + test.EqOp(t, "some-other-string", result) }) } @@ -127,7 +127,7 @@ func TestBuildPubSubConsumer(T *testing.T) { handler := func(_ context.Context, _ []byte) error { return nil } consumer := buildPubSubConsumer(logger, tracing.NewNoopTracerProvider(), nil, nil, "test-topic", handler) - require.NotNil(t, consumer) + must.NotNil(t, consumer) }) T.Run("panics when NewInt64Counter fails", func(t *testing.T) { @@ -139,7 +139,7 @@ func TestBuildPubSubConsumer(T *testing.T) { }, } - assert.Panics(t, func() { + test.Panic(t, func() { buildPubSubConsumer(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, nil, "t", nil) }) }) @@ -153,7 +153,7 @@ func TestProvidePubSubConsumerProvider(T *testing.T) { logger := logging.NewNoopLogger() provider := ProvidePubSubConsumerProvider(logger, tracing.NewNoopTracerProvider(), nil, nil) - require.NotNil(t, provider) + must.NotNil(t, provider) }) } @@ -167,8 +167,8 @@ func TestPubSubConsumerProvider_ProvideConsumer(T *testing.T) { provider := ProvidePubSubConsumerProvider(logger, tracing.NewNoopTracerProvider(), nil, nil) consumer, err := provider.ProvideConsumer(t.Context(), "", func(_ context.Context, _ []byte) error { return nil }) - assert.Nil(t, consumer) - assert.ErrorIs(t, err, messagequeue.ErrEmptyTopicName) + test.Nil(t, consumer) + test.ErrorIs(t, err, messagequeue.ErrEmptyTopicName) }) } @@ -194,11 +194,11 @@ func TestPubSub_Container(T *testing.T) { logger := logging.NewNoopLogger() provider := ProvidePubSubPublisherProvider(logger, tracing.NewNoopTracerProvider(), nil, infra.client, infra.projectID) - require.NotNil(t, provider) + must.NotNil(t, provider) publisher, err := provider.ProvidePublisher(ctx, topicName) - require.NoError(t, err) - require.NotNil(t, publisher) + must.NoError(t, err) + must.NotNil(t, publisher) inputData := &struct { Name string `json:"name"` @@ -206,7 +206,7 @@ func TestPubSub_Container(T *testing.T) { Name: t.Name(), } - assert.NoError(t, publisher.Publish(ctx, inputData)) + test.NoError(t, publisher.Publish(ctx, inputData)) }) T.Run("consumer provider caches consumers for same topic", func(t *testing.T) { @@ -221,12 +221,12 @@ func TestPubSub_Container(T *testing.T) { handler := func(_ context.Context, _ []byte) error { return nil } c1, err := provider.ProvideConsumer(ctx, topicName, handler) - require.NoError(t, err) - require.NotNil(t, c1) + must.NoError(t, err) + must.NotNil(t, c1) c2, err := provider.ProvideConsumer(ctx, topicName, handler) - require.NoError(t, err) - assert.Equal(t, c1, c2) + must.NoError(t, err) + test.True(t, c1 == c2) }) T.Run("consumer receives published message", func(t *testing.T) { @@ -244,7 +244,7 @@ func TestPubSub_Container(T *testing.T) { logger := logging.NewNoopLogger() provider := ProvidePubSubConsumerProvider(logger, tracing.NewNoopTracerProvider(), nil, infra.client) consumer, err := provider.ProvideConsumer(ctx, topicName, handler) - require.NoError(t, err) + must.NoError(t, err) stopChan := make(chan bool, 1) errChan := make(chan error, 1) @@ -255,10 +255,14 @@ func TestPubSub_Container(T *testing.T) { result := publisher.Publish(ctx, &pubsub.Message{Data: []byte(`{"name":"test"}`)}) <-result.Ready() _, err = result.Get(ctx) - require.NoError(t, err) + must.NoError(t, err) // Wait for handler to be called. - assert.Eventually(t, called.Load, 10*time.Second, 100*time.Millisecond) + deadline := time.Now().Add(10 * time.Second) + for !called.Load() && time.Now().Before(deadline) { + time.Sleep(100 * time.Millisecond) + } + test.True(t, called.Load()) stopChan <- true @@ -283,7 +287,7 @@ func TestPubSub_Container(T *testing.T) { logger := logging.NewNoopLogger() provider := ProvidePubSubConsumerProvider(logger, tracing.NewNoopTracerProvider(), nil, infra.client) consumer, err := provider.ProvideConsumer(ctx, topicName, handler) - require.NoError(t, err) + must.NoError(t, err) stopChan := make(chan bool, 1) errChan := make(chan error, 1) @@ -294,12 +298,12 @@ func TestPubSub_Container(T *testing.T) { result := publisher.Publish(ctx, &pubsub.Message{Data: []byte(`{"name":"test"}`)}) <-result.Ready() _, err = result.Get(ctx) - require.NoError(t, err) + must.NoError(t, err) // Wait for the error to appear. select { case receivedErr := <-errChan: - assert.Equal(t, expectedErr, receivedErr) + test.ErrorIs(t, receivedErr, expectedErr) case <-time.After(10 * time.Second): t.Fatal("timed out waiting for handler error") } @@ -318,7 +322,7 @@ func TestPubSub_Container(T *testing.T) { logger := logging.NewNoopLogger() provider := ProvidePubSubConsumerProvider(logger, tracing.NewNoopTracerProvider(), nil, infra.client) consumer, err := provider.ProvideConsumer(ctx, topicName, handler) - require.NoError(t, err) + must.NoError(t, err) stopChan := make(chan bool, 1) errChan := make(chan error, 1) @@ -354,7 +358,7 @@ func TestPubSub_Container(T *testing.T) { logger := logging.NewNoopLogger() provider := ProvidePubSubConsumerProvider(logger, tracing.NewNoopTracerProvider(), nil, infra.client) consumer, err := provider.ProvideConsumer(ctx, topicName, handler) - require.NoError(t, err) + must.NoError(t, err) errChan := make(chan error, 1) @@ -370,8 +374,12 @@ func TestPubSub_Container(T *testing.T) { result := publisher.Publish(ctx, &pubsub.Message{Data: []byte(`{"name":"test"}`)}) <-result.Ready() _, err = result.Get(ctx) - require.NoError(t, err) + must.NoError(t, err) - assert.Eventually(t, called.Load, 10*time.Second, 100*time.Millisecond) + deadline := time.Now().Add(10 * time.Second) + for !called.Load() && time.Now().Before(deadline) { + time.Sleep(100 * time.Millisecond) + } + test.True(t, called.Load()) }) } diff --git a/messagequeue/pubsub/publisher_test.go b/messagequeue/pubsub/publisher_test.go index 1331421..1dabcef 100644 --- a/messagequeue/pubsub/publisher_test.go +++ b/messagequeue/pubsub/publisher_test.go @@ -10,8 +10,8 @@ import ( mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" metricnoop "go.opentelemetry.io/otel/metric/noop" ) @@ -23,7 +23,7 @@ func TestBuildPubSubPublisher(T *testing.T) { t.Parallel() publisher := buildPubSubPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), nil, "test-topic") - require.NotNil(t, publisher) + must.NotNil(t, publisher) }) T.Run("panics when first NewInt64Counter fails", func(t *testing.T) { @@ -39,7 +39,7 @@ func TestBuildPubSubPublisher(T *testing.T) { }, } - assert.Panics(t, func() { + test.Panic(t, func() { buildPubSubPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), mp, "t") }) }) @@ -60,7 +60,7 @@ func TestBuildPubSubPublisher(T *testing.T) { }, } - assert.Panics(t, func() { + test.Panic(t, func() { buildPubSubPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), mp, "t") }) }) @@ -77,7 +77,7 @@ func TestBuildPubSubPublisher(T *testing.T) { }, } - assert.Panics(t, func() { + test.Panic(t, func() { buildPubSubPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), mp, "t") }) }) @@ -90,7 +90,7 @@ func TestProvidePubSubPublisherProvider(T *testing.T) { t.Parallel() provider := ProvidePubSubPublisherProvider(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, nil, "test-project") - require.NotNil(t, provider) + must.NotNil(t, provider) }) } @@ -101,7 +101,7 @@ func TestPublisherProvider_Ping(T *testing.T) { t.Parallel() p := &publisherProvider{} - assert.NoError(t, p.Ping(t.Context())) + test.NoError(t, p.Ping(t.Context())) }) } @@ -113,7 +113,7 @@ func TestPublisherProvider_qualifyTopicName(T *testing.T) { p := &publisherProvider{projectID: "my-project"} result := p.qualifyTopicName("projects/my-project/topics/my-topic") - assert.Equal(t, "projects/my-project/topics/my-topic", result) + test.EqOp(t, "projects/my-project/topics/my-topic", result) }) T.Run("unqualified", func(t *testing.T) { @@ -121,7 +121,7 @@ func TestPublisherProvider_qualifyTopicName(T *testing.T) { p := &publisherProvider{projectID: "my-project"} result := p.qualifyTopicName("my-topic") - assert.Equal(t, "projects/my-project/topics/my-topic", result) + test.EqOp(t, "projects/my-project/topics/my-topic", result) }) } @@ -134,7 +134,7 @@ func TestPublisherProvider_ProvidePublisher(T *testing.T) { provider := ProvidePubSubPublisherProvider(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, nil, "test-project") pub, err := provider.ProvidePublisher(t.Context(), "") - assert.Nil(t, pub) - assert.ErrorIs(t, err, messagequeue.ErrEmptyTopicName) + test.Nil(t, pub) + test.ErrorIs(t, err, messagequeue.ErrEmptyTopicName) }) } diff --git a/messagequeue/redis/config_test.go b/messagequeue/redis/config_test.go index c2683ed..91ee4c4 100644 --- a/messagequeue/redis/config_test.go +++ b/messagequeue/redis/config_test.go @@ -3,7 +3,7 @@ package redis import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -19,6 +19,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { QueueAddresses: []string{t.Name()}, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) } diff --git a/messagequeue/redis/consumer_test.go b/messagequeue/redis/consumer_test.go index c25567b..7c51021 100644 --- a/messagequeue/redis/consumer_test.go +++ b/messagequeue/redis/consumer_test.go @@ -12,8 +12,8 @@ import ( mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" metricnoop "go.opentelemetry.io/otel/metric/noop" ) @@ -30,7 +30,7 @@ func buildRedisBackedConsumer(t *testing.T, cfg *Config, topic string, handlerFu ) consumer, err := provider.ProvideConsumer(t.Context(), topic, handlerFunc) - require.NoError(t, err) + must.NoError(t, err) return consumer } @@ -49,7 +49,7 @@ func Test_redisConsumer_Consume(T *testing.T) { } defer func() { if containerShutdown != nil { - assert.NoError(t, containerShutdown(ctx)) + test.NoError(t, containerShutdown(ctx)) } }() @@ -58,14 +58,14 @@ func Test_redisConsumer_Consume(T *testing.T) { } consumer := buildRedisBackedConsumer(t, cfg, t.Name(), hf) - require.NotNil(t, consumer) + must.NotNil(t, consumer) stopChan := make(chan bool) errorsChan := make(chan error) go consumer.Consume(ctx, stopChan, errorsChan) publisher := buildRedisBackedPublisher(t, cfg, t.Name()) - require.NoError(t, publisher.Publish(ctx, []byte("blah"))) + must.NoError(t, publisher.Publish(ctx, []byte("blah"))) <-time.After(time.Second) stopChan <- true @@ -82,7 +82,7 @@ func Test_redisConsumer_Consume(T *testing.T) { } defer func() { if containerShutdown != nil { - assert.NoError(t, containerShutdown(ctx)) + test.NoError(t, containerShutdown(ctx)) } }() @@ -92,18 +92,18 @@ func Test_redisConsumer_Consume(T *testing.T) { } consumer := buildRedisBackedConsumer(t, cfg, t.Name(), hf) - require.NotNil(t, consumer) + must.NotNil(t, consumer) stopChan := make(chan bool) errorsChan := make(chan error) go consumer.Consume(ctx, stopChan, errorsChan) publisher := buildRedisBackedPublisher(t, cfg, t.Name()) - require.NoError(t, publisher.Publish(ctx, []byte("blah"))) + must.NoError(t, publisher.Publish(ctx, []byte("blah"))) receivedErr := <-errorsChan - assert.Error(t, receivedErr) - assert.Equal(t, anticipatedError, receivedErr) + test.Error(t, receivedErr) + test.ErrorIs(t, receivedErr, anticipatedError) stopChan <- true }) @@ -121,13 +121,13 @@ func Test_consumerProvider_ProvideConsumer(T *testing.T) { } conPro := ProvideRedisConsumerProvider(logger, tracing.NewNoopTracerProvider(), nil, cfg) - require.NotNil(t, conPro) + must.NotNil(t, conPro) ctx := t.Context() actual, err := conPro.ProvideConsumer(ctx, t.Name(), nil) - assert.NoError(t, err) - assert.NotNil(t, actual) + test.NoError(t, err) + test.NotNil(t, actual) }) T.Run("hitting cache", func(t *testing.T) { @@ -139,17 +139,17 @@ func Test_consumerProvider_ProvideConsumer(T *testing.T) { } conPro := ProvideRedisConsumerProvider(logger, tracing.NewNoopTracerProvider(), nil, cfg) - require.NotNil(t, conPro) + must.NotNil(t, conPro) ctx := t.Context() actual, err := conPro.ProvideConsumer(ctx, t.Name(), nil) - assert.NoError(t, err) - assert.NotNil(t, actual) + test.NoError(t, err) + test.NotNil(t, actual) actual, err = conPro.ProvideConsumer(ctx, t.Name(), nil) - assert.NoError(t, err) - assert.NotNil(t, actual) + test.NoError(t, err) + test.NotNil(t, actual) }) T.Run("with empty topic", func(t *testing.T) { @@ -161,11 +161,11 @@ func Test_consumerProvider_ProvideConsumer(T *testing.T) { } conPro := ProvideRedisConsumerProvider(logger, tracing.NewNoopTracerProvider(), nil, cfg) - require.NotNil(t, conPro) + must.NotNil(t, conPro) actual, err := conPro.ProvideConsumer(t.Context(), "", nil) - assert.Nil(t, actual) - assert.ErrorIs(t, err, ErrEmptyInputProvided) + test.Nil(t, actual) + test.ErrorIs(t, err, ErrEmptyInputProvided) }) } @@ -181,7 +181,7 @@ func Test_provideRedisConsumer(T *testing.T) { }, } - assert.Panics(t, func() { + test.Panic(t, func() { provideRedisConsumer(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, nil, "t", nil) }) }) diff --git a/messagequeue/redis/publisher_test.go b/messagequeue/redis/publisher_test.go index a02443d..5608323 100644 --- a/messagequeue/redis/publisher_test.go +++ b/messagequeue/redis/publisher_test.go @@ -14,8 +14,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/go-redis/redis/v8" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" metricnoop "go.opentelemetry.io/otel/metric/noop" ) @@ -58,7 +58,7 @@ func buildRedisBackedPublisher(t *testing.T, cfg *Config, topic string) messageq ) publisher, err := provider.ProvidePublisher(ctx, topic) - require.NoError(t, err) + must.NoError(t, err) return publisher } @@ -76,14 +76,14 @@ func Test_redisPublisher_Publish(T *testing.T) { QueueAddresses: []string{t.Name()}, } provider := ProvideRedisPublisherProvider(logger, tracing.NewNoopTracerProvider(), nil, cfg) - require.NotNil(t, provider) + must.NotNil(t, provider) a, err := provider.ProvidePublisher(ctx, t.Name()) - assert.NotNil(t, a) - assert.NoError(t, err) + test.NotNil(t, a) + test.NoError(t, err) actual, ok := a.(*redisPublisher) - require.True(t, ok) + must.True(t, ok) inputData := &struct { Name string `json:"name"` @@ -98,11 +98,11 @@ func Test_redisPublisher_Publish(T *testing.T) { actual.publisher = mmp err = actual.Publish(ctx, inputData) - assert.NoError(t, err) + test.NoError(t, err) - require.Len(t, mmp.publishArgs, 1) - assert.Equal(t, actual.topic, mmp.publishArgs[0].channel) - assert.Equal(t, fmt.Appendf(nil, `{"name":%q}%s`, t.Name(), string(byte(10))), mmp.publishArgs[0].message) + must.SliceLen(t, 1, mmp.publishArgs) + test.EqOp(t, actual.topic, mmp.publishArgs[0].channel) + test.Eq(t, any(fmt.Appendf(nil, `{"name":%q}%s`, t.Name(), string(byte(10)))), mmp.publishArgs[0].message) }) T.Run("with error encoding value", func(t *testing.T) { @@ -115,14 +115,14 @@ func Test_redisPublisher_Publish(T *testing.T) { QueueAddresses: []string{t.Name()}, } provider := ProvideRedisPublisherProvider(logger, tracing.NewNoopTracerProvider(), nil, cfg) - require.NotNil(t, provider) + must.NotNil(t, provider) a, err := provider.ProvidePublisher(ctx, t.Name()) - assert.NotNil(t, a) - assert.NoError(t, err) + test.NotNil(t, a) + test.NoError(t, err) actual, ok := a.(*redisPublisher) - require.True(t, ok) + must.True(t, ok) inputData := &struct { Name json.Number `json:"name"` @@ -131,7 +131,7 @@ func Test_redisPublisher_Publish(T *testing.T) { } err = actual.Publish(ctx, inputData) - assert.Error(t, err) + test.Error(t, err) }) } @@ -148,14 +148,14 @@ func Test_redisPublisher_PublishAsync(T *testing.T) { QueueAddresses: []string{t.Name()}, } provider := ProvideRedisPublisherProvider(logger, tracing.NewNoopTracerProvider(), nil, cfg) - require.NotNil(t, provider) + must.NotNil(t, provider) a, err := provider.ProvidePublisher(ctx, t.Name()) - assert.NotNil(t, a) - assert.NoError(t, err) + test.NotNil(t, a) + test.NoError(t, err) actual, ok := a.(*redisPublisher) - require.True(t, ok) + must.True(t, ok) inputData := &struct { Name string `json:"name"` @@ -171,9 +171,9 @@ func Test_redisPublisher_PublishAsync(T *testing.T) { actual.PublishAsync(ctx, inputData) - require.Len(t, mmp.publishArgs, 1) - assert.Equal(t, actual.topic, mmp.publishArgs[0].channel) - assert.Equal(t, fmt.Appendf(nil, `{"name":%q}%s`, t.Name(), string(byte(10))), mmp.publishArgs[0].message) + must.SliceLen(t, 1, mmp.publishArgs) + test.EqOp(t, actual.topic, mmp.publishArgs[0].channel) + test.Eq(t, any(fmt.Appendf(nil, `{"name":%q}%s`, t.Name(), string(byte(10)))), mmp.publishArgs[0].message) }) T.Run("with error encoding value", func(t *testing.T) { @@ -186,14 +186,14 @@ func Test_redisPublisher_PublishAsync(T *testing.T) { QueueAddresses: []string{t.Name()}, } provider := ProvideRedisPublisherProvider(logger, tracing.NewNoopTracerProvider(), nil, cfg) - require.NotNil(t, provider) + must.NotNil(t, provider) a, err := provider.ProvidePublisher(ctx, t.Name()) - assert.NotNil(t, a) - assert.NoError(t, err) + test.NotNil(t, a) + test.NoError(t, err) actual, ok := a.(*redisPublisher) - require.True(t, ok) + must.True(t, ok) inputData := &struct { Name json.Number `json:"name"` @@ -217,7 +217,7 @@ func TestProvideRedisPublisherProvider(T *testing.T) { QueueAddresses: []string{t.Name()}, } actual := ProvideRedisPublisherProvider(logger, tracing.NewNoopTracerProvider(), nil, cfg) - assert.NotNil(t, actual) + test.NotNil(t, actual) }) } @@ -234,11 +234,11 @@ func Test_publisherProvider_ProvidePublisher(T *testing.T) { QueueAddresses: []string{t.Name()}, } provider := ProvideRedisPublisherProvider(logger, tracing.NewNoopTracerProvider(), nil, cfg) - require.NotNil(t, provider) + must.NotNil(t, provider) actual, err := provider.ProvidePublisher(ctx, t.Name()) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with cache hit", func(t *testing.T) { @@ -251,15 +251,15 @@ func Test_publisherProvider_ProvidePublisher(T *testing.T) { QueueAddresses: []string{t.Name()}, } provider := ProvideRedisPublisherProvider(logger, tracing.NewNoopTracerProvider(), nil, cfg) - require.NotNil(t, provider) + must.NotNil(t, provider) actual, err := provider.ProvidePublisher(ctx, t.Name()) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) actual, err = provider.ProvidePublisher(ctx, t.Name()) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with empty topic", func(t *testing.T) { @@ -272,11 +272,11 @@ func Test_publisherProvider_ProvidePublisher(T *testing.T) { QueueAddresses: []string{t.Name()}, } provider := ProvideRedisPublisherProvider(logger, tracing.NewNoopTracerProvider(), nil, cfg) - require.NotNil(t, provider) + must.NotNil(t, provider) actual, err := provider.ProvidePublisher(ctx, "") - assert.Nil(t, actual) - assert.ErrorIs(t, err, messagequeue.ErrEmptyTopicName) + test.Nil(t, actual) + test.ErrorIs(t, err, messagequeue.ErrEmptyTopicName) }) } @@ -287,7 +287,7 @@ func Test_provideRedisPublisher(T *testing.T) { t.Parallel() publisher := provideRedisPublisher(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, nil, "test-topic") - require.NotNil(t, publisher) + must.NotNil(t, publisher) }) T.Run("panics when first NewInt64Counter fails", func(t *testing.T) { @@ -303,7 +303,7 @@ func Test_provideRedisPublisher(T *testing.T) { }, } - assert.Panics(t, func() { + test.Panic(t, func() { provideRedisPublisher(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, nil, "t") }) }) @@ -324,7 +324,7 @@ func Test_provideRedisPublisher(T *testing.T) { }, } - assert.Panics(t, func() { + test.Panic(t, func() { provideRedisPublisher(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, nil, "t") }) }) @@ -341,7 +341,7 @@ func Test_provideRedisPublisher(T *testing.T) { }, } - assert.Panics(t, func() { + test.Panic(t, func() { provideRedisPublisher(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, nil, "t") }) }) diff --git a/messagequeue/sqs/consumer_test.go b/messagequeue/sqs/consumer_test.go index f8bf957..9cd2522 100644 --- a/messagequeue/sqs/consumer_test.go +++ b/messagequeue/sqs/consumer_test.go @@ -14,8 +14,8 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/aws/aws-sdk-go-v2/service/sqs/types" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" metricnoop "go.opentelemetry.io/otel/metric/noop" ) @@ -49,9 +49,9 @@ func Test_sqsConsumer_Consume(T *testing.T) { receiveMessageFunc: func(_ context.Context, in *sqs.ReceiveMessageInput, _ ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { receiveCalls++ if receiveCalls == 1 { - assert.Equal(t, queueURL, aws.ToString(in.QueueUrl)) - assert.Equal(t, int32(maxNumberOfMessages), in.MaxNumberOfMessages) - assert.Equal(t, int32(longPollWaitSeconds), in.WaitTimeSeconds) + test.EqOp(t, queueURL, aws.ToString(in.QueueUrl)) + test.EqOp(t, int32(maxNumberOfMessages), in.MaxNumberOfMessages) + test.EqOp(t, int32(longPollWaitSeconds), in.WaitTimeSeconds) return &sqs.ReceiveMessageOutput{ Messages: []types.Message{ { @@ -64,8 +64,8 @@ func Test_sqsConsumer_Consume(T *testing.T) { return &sqs.ReceiveMessageOutput{Messages: []types.Message{}}, nil }, deleteMessageFunc: func(_ context.Context, in *sqs.DeleteMessageInput, _ ...func(*sqs.Options)) (*sqs.DeleteMessageOutput, error) { - assert.Equal(t, queueURL, aws.ToString(in.QueueUrl)) - assert.Equal(t, "receipt-handle-123", aws.ToString(in.ReceiptHandle)) + test.EqOp(t, queueURL, aws.ToString(in.QueueUrl)) + test.EqOp(t, "receipt-handle-123", aws.ToString(in.ReceiptHandle)) deleteCalled <- struct{}{} return &sqs.DeleteMessageOutput{}, nil }, @@ -87,7 +87,7 @@ func Test_sqsConsumer_Consume(T *testing.T) { <-deleteCalled // wait for DeleteMessage before stopping stopChan <- true - assert.Equal(t, []byte("test-payload"), receivedBody) + test.Eq(t, []byte("test-payload"), receivedBody) }) T.Run("handler error does not delete message", func(t *testing.T) { @@ -127,12 +127,12 @@ func Test_sqsConsumer_Consume(T *testing.T) { go consumer.Consume(t.Context(), stopChan, errs) receivedErr := <-errs - assert.Error(t, receivedErr) - assert.Equal(t, anticipatedErr, receivedErr) + test.Error(t, receivedErr) + test.ErrorIs(t, receivedErr, anticipatedErr) stopChan <- true - assert.Zero(t, mmr.deleteMessageCalls) + test.EqOp(t, 0, mmr.deleteMessageCalls) }) } @@ -147,8 +147,8 @@ func TestProvideSQSConsumerProvider(T *testing.T) { cfg := Config{} actual, err := ProvideSQSConsumerProvider(ctx, logger, tracing.NewNoopTracerProvider(), nil, cfg) - assert.NoError(t, err) - assert.NotNil(t, actual) + test.NoError(t, err) + test.NotNil(t, actual) }) } @@ -163,12 +163,12 @@ func Test_consumerProvider_ProvideConsumer(T *testing.T) { cfg := Config{} provider, err := ProvideSQSConsumerProvider(ctx, logger, tracing.NewNoopTracerProvider(), nil, cfg) - require.NoError(t, err) - require.NotNil(t, provider) + must.NoError(t, err) + must.NotNil(t, provider) actual, err := provider.ProvideConsumer(ctx, "https://sqs.us-east-1.amazonaws.com/123/test", nil) - assert.NoError(t, err) - assert.NotNil(t, actual) + test.NoError(t, err) + test.NotNil(t, actual) }) T.Run("with cache hit", func(t *testing.T) { @@ -180,17 +180,17 @@ func Test_consumerProvider_ProvideConsumer(T *testing.T) { topic := "https://sqs.us-east-1.amazonaws.com/123/cached-queue" provider, err := ProvideSQSConsumerProvider(ctx, logger, tracing.NewNoopTracerProvider(), nil, cfg) - require.NoError(t, err) - require.NotNil(t, provider) + must.NoError(t, err) + must.NotNil(t, provider) actual, err := provider.ProvideConsumer(ctx, topic, nil) - assert.NoError(t, err) - assert.NotNil(t, actual) + test.NoError(t, err) + test.NotNil(t, actual) actual2, err := provider.ProvideConsumer(ctx, topic, nil) - assert.NoError(t, err) - assert.NotNil(t, actual2) - assert.Same(t, actual, actual2) + test.NoError(t, err) + test.NotNil(t, actual2) + test.EqOp(t, actual, actual2) }) T.Run("with empty topic returns error", func(t *testing.T) { @@ -201,13 +201,13 @@ func Test_consumerProvider_ProvideConsumer(T *testing.T) { cfg := Config{} provider, err := ProvideSQSConsumerProvider(ctx, logger, tracing.NewNoopTracerProvider(), nil, cfg) - require.NoError(t, err) - require.NotNil(t, provider) + must.NoError(t, err) + must.NotNil(t, provider) actual, err := provider.ProvideConsumer(ctx, "", nil) - assert.Error(t, err) - assert.Nil(t, actual) - assert.ErrorIs(t, err, messagequeue.ErrEmptyTopicName) + test.Error(t, err) + test.Nil(t, actual) + test.ErrorIs(t, err, messagequeue.ErrEmptyTopicName) }) } @@ -218,7 +218,7 @@ func Test_provideSQSConsumer(T *testing.T) { t.Parallel() consumer := provideSQSConsumer(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, nil, "https://sqs.us-east-1.amazonaws.com/123/test", nil) - require.NotNil(t, consumer) + must.NotNil(t, consumer) }) T.Run("panics when NewInt64Counter fails", func(t *testing.T) { @@ -230,7 +230,7 @@ func Test_provideSQSConsumer(T *testing.T) { }, } - assert.Panics(t, func() { + test.Panic(t, func() { provideSQSConsumer(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, nil, "t", nil) }) }) diff --git a/messagequeue/sqs/publisher_test.go b/messagequeue/sqs/publisher_test.go index 5426c8f..169975a 100644 --- a/messagequeue/sqs/publisher_test.go +++ b/messagequeue/sqs/publisher_test.go @@ -13,8 +13,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/aws/aws-sdk-go-v2/service/sqs" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" metricnoop "go.opentelemetry.io/otel/metric/noop" ) @@ -39,14 +39,14 @@ func Test_sqsPublisher_Publish(T *testing.T) { logger := logging.NewNoopLogger() provider := ProvideSQSPublisherProvider(ctx, logger, tracing.NewNoopTracerProvider(), nil) - require.NotNil(t, provider) + must.NotNil(t, provider) a, err := provider.ProvidePublisher(ctx, t.Name()) - assert.NotNil(t, a) - assert.NoError(t, err) + test.NotNil(t, a) + test.NoError(t, err) actual, ok := a.(*sqsPublisher) - require.True(t, ok) + must.True(t, ok) inputData := &struct { Name string `json:"name"` @@ -63,8 +63,8 @@ func Test_sqsPublisher_Publish(T *testing.T) { actual.publisher = mmp err = actual.Publish(ctx, inputData) - assert.NoError(t, err) - assert.Equal(t, 1, mmp.sendMessageCalls) + test.NoError(t, err) + test.EqOp(t, 1, mmp.sendMessageCalls) }) T.Run("with error encoding value", func(t *testing.T) { @@ -74,14 +74,14 @@ func Test_sqsPublisher_Publish(T *testing.T) { logger := logging.NewNoopLogger() provider := ProvideSQSPublisherProvider(ctx, logger, tracing.NewNoopTracerProvider(), nil) - require.NotNil(t, provider) + must.NotNil(t, provider) a, err := provider.ProvidePublisher(ctx, t.Name()) - assert.NotNil(t, a) - assert.NoError(t, err) + test.NotNil(t, a) + test.NoError(t, err) actual, ok := a.(*sqsPublisher) - require.True(t, ok) + must.True(t, ok) inputData := &struct { Name json.Number `json:"name"` @@ -90,7 +90,7 @@ func Test_sqsPublisher_Publish(T *testing.T) { } err = actual.Publish(ctx, inputData) - assert.Error(t, err) + test.Error(t, err) }) } @@ -104,14 +104,14 @@ func Test_sqsPublisher_PublishAsync(T *testing.T) { logger := logging.NewNoopLogger() provider := ProvideSQSPublisherProvider(ctx, logger, tracing.NewNoopTracerProvider(), nil) - require.NotNil(t, provider) + must.NotNil(t, provider) a, err := provider.ProvidePublisher(ctx, t.Name()) - assert.NotNil(t, a) - assert.NoError(t, err) + test.NotNil(t, a) + test.NoError(t, err) actual, ok := a.(*sqsPublisher) - require.True(t, ok) + must.True(t, ok) inputData := &struct { Name string `json:"name"` @@ -128,7 +128,7 @@ func Test_sqsPublisher_PublishAsync(T *testing.T) { actual.publisher = mmp actual.PublishAsync(ctx, inputData) - assert.Equal(t, 1, mmp.sendMessageCalls) + test.EqOp(t, 1, mmp.sendMessageCalls) }) T.Run("with error encoding value", func(t *testing.T) { @@ -138,14 +138,14 @@ func Test_sqsPublisher_PublishAsync(T *testing.T) { logger := logging.NewNoopLogger() provider := ProvideSQSPublisherProvider(ctx, logger, tracing.NewNoopTracerProvider(), nil) - require.NotNil(t, provider) + must.NotNil(t, provider) a, err := provider.ProvidePublisher(ctx, t.Name()) - assert.NotNil(t, a) - assert.NoError(t, err) + test.NotNil(t, a) + test.NoError(t, err) actual, ok := a.(*sqsPublisher) - require.True(t, ok) + must.True(t, ok) inputData := &struct { Name json.Number `json:"name"` @@ -163,14 +163,14 @@ func Test_sqsPublisher_PublishAsync(T *testing.T) { logger := logging.NewNoopLogger() provider := ProvideSQSPublisherProvider(ctx, logger, tracing.NewNoopTracerProvider(), nil) - require.NotNil(t, provider) + must.NotNil(t, provider) a, err := provider.ProvidePublisher(ctx, t.Name()) - assert.NotNil(t, a) - assert.NoError(t, err) + test.NotNil(t, a) + test.NoError(t, err) actual, ok := a.(*sqsPublisher) - require.True(t, ok) + must.True(t, ok) inputData := &struct { Name string `json:"name"` @@ -187,7 +187,7 @@ func Test_sqsPublisher_PublishAsync(T *testing.T) { actual.publisher = mmp actual.PublishAsync(ctx, inputData) - assert.Equal(t, 1, mmp.sendMessageCalls) + test.EqOp(t, 1, mmp.sendMessageCalls) }) } @@ -201,7 +201,7 @@ func TestProvideSQSPublisherProvider(T *testing.T) { logger := logging.NewNoopLogger() actual := ProvideSQSPublisherProvider(ctx, logger, tracing.NewNoopTracerProvider(), nil) - assert.NotNil(t, actual) + test.NotNil(t, actual) }) } @@ -215,11 +215,11 @@ func Test_publisherProvider_ProvidePublisher(T *testing.T) { logger := logging.NewNoopLogger() provider := ProvideSQSPublisherProvider(ctx, logger, tracing.NewNoopTracerProvider(), nil) - require.NotNil(t, provider) + must.NotNil(t, provider) actual, err := provider.ProvidePublisher(ctx, t.Name()) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with cache hit", func(t *testing.T) { @@ -229,15 +229,15 @@ func Test_publisherProvider_ProvidePublisher(T *testing.T) { logger := logging.NewNoopLogger() provider := ProvideSQSPublisherProvider(ctx, logger, tracing.NewNoopTracerProvider(), nil) - require.NotNil(t, provider) + must.NotNil(t, provider) actual, err := provider.ProvidePublisher(ctx, t.Name()) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) actual, err = provider.ProvidePublisher(ctx, t.Name()) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with empty topic", func(t *testing.T) { @@ -247,11 +247,11 @@ func Test_publisherProvider_ProvidePublisher(T *testing.T) { logger := logging.NewNoopLogger() provider := ProvideSQSPublisherProvider(ctx, logger, tracing.NewNoopTracerProvider(), nil) - require.NotNil(t, provider) + must.NotNil(t, provider) actual, err := provider.ProvidePublisher(ctx, "") - assert.Nil(t, actual) - assert.ErrorIs(t, err, messagequeue.ErrEmptyTopicName) + test.Nil(t, actual) + test.ErrorIs(t, err, messagequeue.ErrEmptyTopicName) }) } @@ -262,7 +262,7 @@ func Test_provideSQSPublisher(T *testing.T) { t.Parallel() publisher := provideSQSPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), nil, "test-topic") - require.NotNil(t, publisher) + must.NotNil(t, publisher) }) T.Run("panics when first NewInt64Counter fails", func(t *testing.T) { @@ -278,7 +278,7 @@ func Test_provideSQSPublisher(T *testing.T) { }, } - assert.Panics(t, func() { + test.Panic(t, func() { provideSQSPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), mp, "t") }) }) @@ -299,7 +299,7 @@ func Test_provideSQSPublisher(T *testing.T) { }, } - assert.Panics(t, func() { + test.Panic(t, func() { provideSQSPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), mp, "t") }) }) @@ -316,7 +316,7 @@ func Test_provideSQSPublisher(T *testing.T) { }, } - assert.Panics(t, func() { + test.Panic(t, func() { provideSQSPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), mp, "t") }) }) diff --git a/notifications/async/ably/ably_test.go b/notifications/async/ably/ably_test.go index e9ea5a6..7ce401e 100644 --- a/notifications/async/ably/ably_test.go +++ b/notifications/async/ably/ably_test.go @@ -11,8 +11,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) type mockChannelPublisher struct { @@ -32,16 +32,16 @@ func TestNewNotifier(T *testing.T) { n, err := NewNotifier(&Config{ APIKey: "appid.keyid:keysecret", }, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil) - require.NoError(t, err) - require.NotNil(t, n) + must.NoError(t, err) + must.NotNil(t, n) }) T.Run("nil config", func(t *testing.T) { t.Parallel() n, err := NewNotifier(nil, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil) - assert.Error(t, err) - assert.Nil(t, n) + test.Error(t, err) + test.Nil(t, n) }) } @@ -74,9 +74,9 @@ func TestNotifier_Publish(T *testing.T) { Type: "greeting", Data: json.RawMessage(`{"hello":"world"}`), }) - assert.NoError(t, err) - assert.Equal(t, "my-channel", capturedChannel) - assert.Equal(t, "greeting", capturedName) + test.NoError(t, err) + test.EqOp(t, "my-channel", capturedChannel) + test.EqOp(t, "greeting", capturedName) }) T.Run("publish error", func(t *testing.T) { @@ -101,7 +101,7 @@ func TestNotifier_Publish(T *testing.T) { err := n.Publish(context.Background(), "my-channel", &async.Event{ Type: "test", }) - assert.Error(t, err) + test.Error(t, err) }) } @@ -116,6 +116,6 @@ func TestNotifier_Close(T *testing.T) { tracer: tracing.NewTracerForTest("test"), } - assert.NoError(t, n.Close()) + test.NoError(t, n.Close()) }) } diff --git a/notifications/async/ably/config_test.go b/notifications/async/ably/config_test.go index 0b0915f..193435c 100644 --- a/notifications/async/ably/config_test.go +++ b/notifications/async/ably/config_test.go @@ -3,7 +3,7 @@ package ably import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -16,7 +16,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { APIKey: "test.key:secret", } - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("missing api key", func(t *testing.T) { @@ -24,6 +24,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { cfg := &Config{} - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) } diff --git a/notifications/async/config/config_test.go b/notifications/async/config/config_test.go index 325e87c..4bd37c1 100644 --- a/notifications/async/config/config_test.go +++ b/notifications/async/config/config_test.go @@ -10,8 +10,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -24,7 +24,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderNoop, } - require.NoError(t, cfg.ValidateWithContext(t.Context())) + must.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("with invalid provider", func(t *testing.T) { @@ -34,7 +34,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: "invalid", } - require.Error(t, cfg.ValidateWithContext(t.Context())) + must.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("pusher requires config", func(t *testing.T) { @@ -44,7 +44,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderPusher, } - require.Error(t, cfg.ValidateWithContext(t.Context())) + must.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("ably requires config", func(t *testing.T) { @@ -54,7 +54,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderAbly, } - require.Error(t, cfg.ValidateWithContext(t.Context())) + must.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("websocket requires config", func(t *testing.T) { @@ -64,7 +64,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderWebSocket, } - require.Error(t, cfg.ValidateWithContext(t.Context())) + must.Error(t, cfg.ValidateWithContext(t.Context())) }) } @@ -80,8 +80,8 @@ func TestConfig_ProvideAsyncNotifier(T *testing.T) { } actual, err := cfg.ProvideAsyncNotifier(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with sse", func(t *testing.T) { @@ -92,8 +92,8 @@ func TestConfig_ProvideAsyncNotifier(T *testing.T) { } actual, err := cfg.ProvideAsyncNotifier(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with pusher", func(t *testing.T) { @@ -110,8 +110,8 @@ func TestConfig_ProvideAsyncNotifier(T *testing.T) { } actual, err := cfg.ProvideAsyncNotifier(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) T.Run("with ably", func(t *testing.T) { @@ -125,8 +125,8 @@ func TestConfig_ProvideAsyncNotifier(T *testing.T) { } actual, err := cfg.ProvideAsyncNotifier(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) noopProviders := []string{"", ProviderNoop} @@ -139,8 +139,8 @@ func TestConfig_ProvideAsyncNotifier(T *testing.T) { } actual, err := cfg.ProvideAsyncNotifier(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) } @@ -152,8 +152,8 @@ func TestConfig_ProvideAsyncNotifier(T *testing.T) { } actual, err := cfg.ProvideAsyncNotifier(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil) - assert.Nil(t, actual) - assert.Error(t, err) + test.Nil(t, actual) + test.Error(t, err) }) } @@ -168,8 +168,8 @@ func TestProvideAsyncNotifierFromConfig(T *testing.T) { } actual, err := ProvideAsyncNotifierFromConfig(cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil) - assert.NoError(t, err) - assert.NotNil(t, actual) + test.NoError(t, err) + test.NotNil(t, actual) }) T.Run("with unknown provider", func(t *testing.T) { @@ -180,7 +180,7 @@ func TestProvideAsyncNotifierFromConfig(T *testing.T) { } actual, err := ProvideAsyncNotifierFromConfig(cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil) - assert.Nil(t, actual) - assert.Error(t, err) + test.Nil(t, actual) + test.Error(t, err) }) } diff --git a/notifications/async/config/do_test.go b/notifications/async/config/do_test.go index b2c59e5..b23491e 100644 --- a/notifications/async/config/do_test.go +++ b/notifications/async/config/do_test.go @@ -9,8 +9,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterAsyncNotifier(T *testing.T) { @@ -28,7 +28,7 @@ func TestRegisterAsyncNotifier(T *testing.T) { RegisterAsyncNotifier(i) notifier, err := do.Invoke[async.AsyncNotifier](i) - require.NoError(t, err) - assert.NotNil(t, notifier) + must.NoError(t, err) + test.NotNil(t, notifier) }) } diff --git a/notifications/async/noop/noop_test.go b/notifications/async/noop/noop_test.go index fe93879..be314bb 100644 --- a/notifications/async/noop/noop_test.go +++ b/notifications/async/noop/noop_test.go @@ -7,8 +7,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/notifications/async" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestAsyncNotifier_Publish(T *testing.T) { @@ -18,14 +18,14 @@ func TestAsyncNotifier_Publish(T *testing.T) { t.Parallel() n, err := NewAsyncNotifier() - require.NoError(t, err) - require.NotNil(t, n) + must.NoError(t, err) + must.NotNil(t, n) err = n.Publish(context.Background(), "test-channel", &async.Event{ Type: "test", Data: json.RawMessage(`{"key":"value"}`), }) - assert.NoError(t, err) + test.NoError(t, err) }) } @@ -36,9 +36,9 @@ func TestAsyncNotifier_Close(T *testing.T) { t.Parallel() n, err := NewAsyncNotifier() - require.NoError(t, err) - require.NotNil(t, n) + must.NoError(t, err) + must.NotNil(t, n) - assert.NoError(t, n.Close()) + test.NoError(t, n.Close()) }) } diff --git a/notifications/async/pusher/config_test.go b/notifications/async/pusher/config_test.go index ec4f2e3..23869db 100644 --- a/notifications/async/pusher/config_test.go +++ b/notifications/async/pusher/config_test.go @@ -3,7 +3,7 @@ package pusher import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -19,7 +19,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Cluster: "us2", } - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("missing required fields", func(t *testing.T) { @@ -27,6 +27,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { cfg := &Config{} - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) } diff --git a/notifications/async/pusher/pusher_test.go b/notifications/async/pusher/pusher_test.go index 357983a..a24ebd8 100644 --- a/notifications/async/pusher/pusher_test.go +++ b/notifications/async/pusher/pusher_test.go @@ -11,8 +11,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) type mockPusherClient struct { @@ -35,16 +35,16 @@ func TestNewNotifier(T *testing.T) { Secret: "secret", Cluster: "us2", }, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil) - require.NoError(t, err) - require.NotNil(t, n) + must.NoError(t, err) + must.NotNil(t, n) }) T.Run("nil config", func(t *testing.T) { t.Parallel() n, err := NewNotifier(nil, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil) - assert.Error(t, err) - assert.Nil(t, n) + test.Error(t, err) + test.Nil(t, n) }) } @@ -77,9 +77,9 @@ func TestNotifier_Publish(T *testing.T) { Type: "greeting", Data: json.RawMessage(`{"hello":"world"}`), }) - assert.NoError(t, err) - assert.Equal(t, "my-channel", capturedChannel) - assert.Equal(t, "greeting", capturedEvent) + test.NoError(t, err) + test.EqOp(t, "my-channel", capturedChannel) + test.EqOp(t, "greeting", capturedEvent) }) T.Run("trigger error", func(t *testing.T) { @@ -104,7 +104,7 @@ func TestNotifier_Publish(T *testing.T) { err := n.Publish(context.Background(), "my-channel", &async.Event{ Type: "test", }) - assert.Error(t, err) + test.Error(t, err) }) } @@ -119,6 +119,6 @@ func TestNotifier_Close(T *testing.T) { tracer: tracing.NewTracerForTest("test"), } - assert.NoError(t, n.Close()) + test.NoError(t, n.Close()) }) } diff --git a/notifications/async/sse/config_test.go b/notifications/async/sse/config_test.go index 8bc9d87..7fa3f79 100644 --- a/notifications/async/sse/config_test.go +++ b/notifications/async/sse/config_test.go @@ -3,7 +3,7 @@ package sse import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -14,6 +14,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { cfg := &Config{} - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) } diff --git a/notifications/async/sse/sse_test.go b/notifications/async/sse/sse_test.go index addd571..e9708bb 100644 --- a/notifications/async/sse/sse_test.go +++ b/notifications/async/sse/sse_test.go @@ -14,8 +14,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestNewNotifier(T *testing.T) { @@ -25,16 +25,16 @@ func TestNewNotifier(T *testing.T) { t.Parallel() n, err := NewNotifier(&Config{}, logging.NewNoopLogger(), tracing.NewNoopTracerProvider()) - require.NoError(t, err) - require.NotNil(t, n) + must.NoError(t, err) + must.NotNil(t, n) }) T.Run("nil config", func(t *testing.T) { t.Parallel() n, err := NewNotifier(nil, logging.NewNoopLogger(), tracing.NewNoopTracerProvider()) - require.NoError(t, err) - require.NotNil(t, n) + must.NoError(t, err) + must.NotNil(t, n) }) } @@ -45,34 +45,34 @@ func TestNotifier_Publish(T *testing.T) { t.Parallel() n, err := NewNotifier(&Config{}, logging.NewNoopLogger(), tracing.NewNoopTracerProvider()) - require.NoError(t, err) + must.NoError(t, err) err = n.Publish(context.Background(), "test-channel", &async.Event{ Type: "test", Data: json.RawMessage(`{"key":"value"}`), }) - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("with connected client", func(t *testing.T) { t.Parallel() n, err := NewNotifier(&Config{}, logging.NewNoopLogger(), tracing.NewNoopTracerProvider()) - require.NoError(t, err) + must.NoError(t, err) ready := make(chan struct{}) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { close(ready) acceptErr := n.AcceptConnection(w, r, "test-channel", "member-1") - assert.NoError(t, acceptErr) + test.NoError(t, acceptErr) })) defer server.Close() req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, server.URL, http.NoBody) - require.NoError(t, err) + must.NoError(t, err) resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + must.NoError(t, err) defer resp.Body.Close() <-ready @@ -83,7 +83,7 @@ func TestNotifier_Publish(T *testing.T) { Type: "greeting", Data: json.RawMessage(`{"hello":"world"}`), }) - assert.NoError(t, err) + test.NoError(t, err) scanner := bufio.NewScanner(resp.Body) var eventLine, dataLine string @@ -98,8 +98,8 @@ func TestNotifier_Publish(T *testing.T) { } } - assert.Contains(t, eventLine, "greeting") - assert.Contains(t, dataLine, `{"hello":"world"}`) + test.StrContains(t, eventLine, "greeting") + test.StrContains(t, dataLine, `{"hello":"world"}`) }) } @@ -110,8 +110,8 @@ func TestNotifier_Close(T *testing.T) { t.Parallel() n, err := NewNotifier(&Config{}, logging.NewNoopLogger(), tracing.NewNoopTracerProvider()) - require.NoError(t, err) + must.NoError(t, err) - assert.NoError(t, n.Close()) + test.NoError(t, n.Close()) }) } diff --git a/notifications/async/websocket/config_test.go b/notifications/async/websocket/config_test.go index 64803bc..33d427c 100644 --- a/notifications/async/websocket/config_test.go +++ b/notifications/async/websocket/config_test.go @@ -3,7 +3,7 @@ package websocket import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -14,6 +14,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { cfg := &Config{} - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) } diff --git a/notifications/async/websocket/websocket_test.go b/notifications/async/websocket/websocket_test.go index 1bfc749..d397725 100644 --- a/notifications/async/websocket/websocket_test.go +++ b/notifications/async/websocket/websocket_test.go @@ -14,8 +14,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" gorillawebsocket "github.com/gorilla/websocket" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestNewNotifier(T *testing.T) { @@ -25,16 +25,16 @@ func TestNewNotifier(T *testing.T) { t.Parallel() n, err := NewNotifier(&Config{}, logging.NewNoopLogger(), tracing.NewNoopTracerProvider()) - require.NoError(t, err) - require.NotNil(t, n) + must.NoError(t, err) + must.NotNil(t, n) }) T.Run("nil config", func(t *testing.T) { t.Parallel() n, err := NewNotifier(nil, logging.NewNoopLogger(), tracing.NewNoopTracerProvider()) - assert.Error(t, err) - assert.Nil(t, n) + test.Error(t, err) + test.Nil(t, n) }) } @@ -45,24 +45,24 @@ func TestNotifier_Publish(T *testing.T) { t.Parallel() n, err := NewNotifier(&Config{}, logging.NewNoopLogger(), tracing.NewNoopTracerProvider()) - require.NoError(t, err) + must.NoError(t, err) err = n.Publish(context.Background(), "test-channel", &async.Event{ Type: "test", Data: json.RawMessage(`{"key":"value"}`), }) - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("with connected client", func(t *testing.T) { t.Parallel() n, err := NewNotifier(&Config{}, logging.NewNoopLogger(), tracing.NewNoopTracerProvider()) - require.NoError(t, err) + must.NoError(t, err) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { acceptErr := n.AcceptConnection(w, r, "test-channel", "member-1") - assert.NoError(t, acceptErr) + test.NoError(t, acceptErr) // keep the handler alive so the websocket stays open <-r.Context().Done() })) @@ -70,7 +70,7 @@ func TestNotifier_Publish(T *testing.T) { wsURL := "ws" + strings.TrimPrefix(server.URL, "http") conn, _, err := gorillawebsocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) + must.NoError(t, err) defer conn.Close() // give the connection time to register @@ -80,12 +80,12 @@ func TestNotifier_Publish(T *testing.T) { Type: "greeting", Data: json.RawMessage(`{"hello":"world"}`), }) - assert.NoError(t, err) + test.NoError(t, err) var received map[string]json.RawMessage err = conn.ReadJSON(&received) - require.NoError(t, err) - assert.Equal(t, json.RawMessage(`"greeting"`), received["type"]) + must.NoError(t, err) + test.Eq(t, json.RawMessage(`"greeting"`), received["type"]) }) } @@ -96,18 +96,18 @@ func TestNotifier_AcceptConnection(T *testing.T) { t.Parallel() n, err := NewNotifier(&Config{}, logging.NewNoopLogger(), tracing.NewNoopTracerProvider()) - require.NoError(t, err) + must.NoError(t, err) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { acceptErr := n.AcceptConnection(w, r, "channel", "member") - assert.NoError(t, acceptErr) + test.NoError(t, acceptErr) <-r.Context().Done() })) defer server.Close() wsURL := "ws" + strings.TrimPrefix(server.URL, "http") conn, _, err := gorillawebsocket.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) + must.NoError(t, err) defer conn.Close() }) } @@ -119,8 +119,8 @@ func TestNotifier_Close(T *testing.T) { t.Parallel() n, err := NewNotifier(&Config{}, logging.NewNoopLogger(), tracing.NewNoopTracerProvider()) - require.NoError(t, err) + must.NoError(t, err) - assert.NoError(t, n.Close()) + test.NoError(t, n.Close()) }) } diff --git a/notifications/mobile/apns/apns_sender_test.go b/notifications/mobile/apns/apns_sender_test.go index dc31670..f49e083 100644 --- a/notifications/mobile/apns/apns_sender_test.go +++ b/notifications/mobile/apns/apns_sender_test.go @@ -20,8 +20,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -44,7 +43,7 @@ func createTestSenderWithTransport(t *testing.T, fn roundTripFunc) *Sender { BundleID: "com.example.app", } sender, err := NewSender(cfg, tracing.NewNoopTracerProvider(), logging.NewNoopLogger(), nil) - require.NoError(t, err) + must.NoError(t, err) sender.client.HTTPClient = &http.Client{Transport: fn} sender.client.Host = "https://test.example.com" @@ -56,10 +55,10 @@ func createTestP8File(t *testing.T) string { t.Helper() key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) + must.NoError(t, err) keyBytes, err := x509.MarshalPKCS8PrivateKey(key) - require.NoError(t, err) + must.NoError(t, err) block := &pem.Block{ Type: "PRIVATE KEY", @@ -68,7 +67,7 @@ func createTestP8File(t *testing.T) string { dir := t.TempDir() path := filepath.Join(dir, "AuthKey.p8") - require.NoError(t, os.WriteFile(path, pem.EncodeToMemory(block), 0o600)) + must.NoError(t, os.WriteFile(path, pem.EncodeToMemory(block), 0o600)) return path } @@ -82,9 +81,9 @@ func TestNewSender(T *testing.T) { t.Parallel() sender, err := NewSender(nil, tracingProvider, logger, nil) - assert.Nil(t, sender) - assert.Error(t, err) - assert.Contains(t, err.Error(), "missing required config") + test.Nil(t, sender) + test.Error(t, err) + test.StrContains(t, err.Error(), "missing required config") }) T.Run("with empty auth key path", func(t *testing.T) { @@ -97,9 +96,9 @@ func TestNewSender(T *testing.T) { Production: false, } sender, err := NewSender(cfg, tracingProvider, logger, nil) - assert.Nil(t, sender) - assert.Error(t, err) - assert.Contains(t, err.Error(), "missing required config") + test.Nil(t, sender) + test.Error(t, err) + test.StrContains(t, err.Error(), "missing required config") }) T.Run("with empty key ID", func(t *testing.T) { @@ -113,9 +112,9 @@ func TestNewSender(T *testing.T) { Production: false, } sender, err := NewSender(cfg, tracingProvider, logger, nil) - assert.Nil(t, sender) - assert.Error(t, err) - assert.Contains(t, err.Error(), "missing required config") + test.Nil(t, sender) + test.Error(t, err) + test.StrContains(t, err.Error(), "missing required config") }) T.Run("with non-existent auth key file", func(t *testing.T) { @@ -129,9 +128,9 @@ func TestNewSender(T *testing.T) { Production: false, } sender, err := NewSender(cfg, tracingProvider, logger, nil) - assert.Nil(t, sender) - assert.Error(t, err) - assert.Contains(t, err.Error(), "loading auth key") + test.Nil(t, sender) + test.Error(t, err) + test.StrContains(t, err.Error(), "loading auth key") }) T.Run("with empty team ID", func(t *testing.T) { @@ -144,9 +143,9 @@ func TestNewSender(T *testing.T) { BundleID: "com.example.app", } sender, err := NewSender(cfg, tracingProvider, logger, nil) - assert.Nil(t, sender) - assert.Error(t, err) - assert.Contains(t, err.Error(), "missing required config") + test.Nil(t, sender) + test.Error(t, err) + test.StrContains(t, err.Error(), "missing required config") }) T.Run("with empty bundle ID", func(t *testing.T) { @@ -159,9 +158,9 @@ func TestNewSender(T *testing.T) { TeamID: "TEAM123", } sender, err := NewSender(cfg, tracingProvider, logger, nil) - assert.Nil(t, sender) - assert.Error(t, err) - assert.Contains(t, err.Error(), "missing required config") + test.Nil(t, sender) + test.Error(t, err) + test.StrContains(t, err.Error(), "missing required config") }) T.Run("with valid config", func(t *testing.T) { @@ -176,9 +175,9 @@ func TestNewSender(T *testing.T) { Production: false, } sender, err := NewSender(cfg, tracingProvider, logger, nil) - require.NoError(t, err) - require.NotNil(t, sender) - assert.Equal(t, "com.example.app", sender.topic) + must.NoError(t, err) + must.NotNil(t, sender) + test.EqOp(t, "com.example.app", sender.topic) }) T.Run("with production config", func(t *testing.T) { @@ -193,9 +192,9 @@ func TestNewSender(T *testing.T) { Production: true, } sender, err := NewSender(cfg, tracingProvider, logger, nil) - require.NoError(t, err) - require.NotNil(t, sender) - assert.Equal(t, "com.example.app", sender.topic) + must.NoError(t, err) + must.NotNil(t, sender) + test.EqOp(t, "com.example.app", sender.topic) }) T.Run("with send counter creation error", func(t *testing.T) { @@ -211,15 +210,15 @@ func TestNewSender(T *testing.T) { mp := &mockmetrics.ProviderMock{ NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { - assert.Equal(t, o11yName+"_sends", counterName) + test.EqOp(t, o11yName+"_sends", counterName) return (*metrics.Int64CounterImpl)(nil), errors.New("counter error") }, } sender, err := NewSender(cfg, tracingProvider, logger, mp) - assert.Nil(t, sender) - require.Error(t, err) - assert.Contains(t, err.Error(), "creating send counter") + test.Nil(t, sender) + must.Error(t, err) + test.StrContains(t, err.Error(), "creating send counter") test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -249,9 +248,9 @@ func TestNewSender(T *testing.T) { } sender, err := NewSender(cfg, tracingProvider, logger, mp) - assert.Nil(t, sender) - require.Error(t, err) - assert.Contains(t, err.Error(), "creating error counter") + test.Nil(t, sender) + must.Error(t, err) + test.StrContains(t, err.Error(), "creating error counter") test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) @@ -274,7 +273,7 @@ func TestSender_Send(T *testing.T) { }) err := sender.Send(ctx, validDeviceToken, "title", "body", nil) - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("successful push with badge count", func(t *testing.T) { @@ -290,7 +289,7 @@ func TestSender_Send(T *testing.T) { badge := 5 err := sender.Send(ctx, validDeviceToken, "title", "body", &badge) - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("push returns transport error", func(t *testing.T) { @@ -301,8 +300,8 @@ func TestSender_Send(T *testing.T) { }) err := sender.Send(ctx, validDeviceToken, "title", "body", nil) - require.Error(t, err) - assert.Contains(t, err.Error(), "push failed") + must.Error(t, err) + test.StrContains(t, err.Error(), "push failed") }) T.Run("push returns non-sent response", func(t *testing.T) { @@ -317,8 +316,8 @@ func TestSender_Send(T *testing.T) { }) err := sender.Send(ctx, validDeviceToken, "title", "body", nil) - require.Error(t, err) - assert.Contains(t, err.Error(), "BadDeviceToken") + must.Error(t, err) + test.StrContains(t, err.Error(), "BadDeviceToken") }) } @@ -334,7 +333,7 @@ func TestSender_Send_rejectsInvalidDeviceToken(T *testing.T) { Production: false, } sender, err := NewSender(cfg, tracing.NewNoopTracerProvider(), logging.NewNoopLogger(), nil) - require.NoError(T, err) + must.NoError(T, err) ctx := T.Context() @@ -343,22 +342,22 @@ func TestSender_Send_rejectsInvalidDeviceToken(T *testing.T) { // Simulates decrypted garbage (e.g. wrong key or corrupted data) invalidToken := "x\x89\xbf\x1f\xa0\x93\x12\xf5" sendErr := sender.Send(ctx, invalidToken, "title", "body", nil) - require.Error(t, sendErr) - assert.Contains(t, sendErr.Error(), "invalid device token format") + must.Error(t, sendErr) + test.StrContains(t, sendErr.Error(), "invalid device token format") }) T.Run("rejects token with control characters", func(t *testing.T) { t.Parallel() invalidToken := "a1b2c3d4e5f6789012345678901234567890abcdef1234567890abcdef12345\t" sendErr := sender.Send(ctx, invalidToken, "title", "body", nil) - require.Error(t, sendErr) - assert.Contains(t, sendErr.Error(), "invalid device token format") + must.Error(t, sendErr) + test.StrContains(t, sendErr.Error(), "invalid device token format") }) T.Run("rejects too short token", func(t *testing.T) { t.Parallel() sendErr := sender.Send(ctx, "abc123", "title", "body", nil) - require.Error(t, sendErr) - assert.Contains(t, sendErr.Error(), "invalid device token format") + must.Error(t, sendErr) + test.StrContains(t, sendErr.Error(), "invalid device token format") }) } diff --git a/notifications/mobile/config/config_test.go b/notifications/mobile/config/config_test.go index f830259..d288d17 100644 --- a/notifications/mobile/config/config_test.go +++ b/notifications/mobile/config/config_test.go @@ -14,18 +14,18 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func createTestP8File(t *testing.T) string { t.Helper() key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) + must.NoError(t, err) keyBytes, err := x509.MarshalPKCS8PrivateKey(key) - require.NoError(t, err) + must.NoError(t, err) block := &pem.Block{ Type: "PRIVATE KEY", @@ -34,7 +34,7 @@ func createTestP8File(t *testing.T) string { dir := t.TempDir() path := filepath.Join(dir, "AuthKey.p8") - require.NoError(t, os.WriteFile(path, pem.EncodeToMemory(block), 0o600)) + must.NoError(t, os.WriteFile(path, pem.EncodeToMemory(block), 0o600)) return path } @@ -48,7 +48,7 @@ func createTestFCMCredsFile(t *testing.T) string { dir := t.TempDir() path := filepath.Join(dir, "fcm-creds.json") - require.NoError(t, os.WriteFile(path, []byte(fakeServiceAccountJSON), 0o600)) + must.NoError(t, os.WriteFile(path, []byte(fakeServiceAccountJSON), 0o600)) return path } @@ -61,14 +61,14 @@ func TestConfig_ValidateWithContext(T *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderNoop} - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with empty provider", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ""} - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with apns_fcm provider and both nil", func(t *testing.T) { @@ -79,7 +79,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { APNs: nil, FCM: nil, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with apns_fcm provider and nil APNs but FCM present", func(t *testing.T) { @@ -90,7 +90,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { APNs: nil, FCM: &FCMConfig{}, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with apns_fcm provider and nil FCM but APNs present", func(t *testing.T) { @@ -101,7 +101,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { APNs: &APNsConfig{AuthKeyPath: "x", KeyID: "x", TeamID: "x", BundleID: "x"}, FCM: nil, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with apns_fcm provider and both configs", func(t *testing.T) { @@ -113,7 +113,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { APNs: &APNsConfig{AuthKeyPath: p8Path, KeyID: "x", TeamID: "x", BundleID: "x"}, FCM: &FCMConfig{}, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) } @@ -129,10 +129,10 @@ func TestConfig_ProvidePushSender(T *testing.T) { cfg := Config{Provider: ""} sender, err := cfg.ProvidePushSender(ctx, logger, tracer, nil) - require.NoError(t, err) - require.NotNil(t, sender) + must.NoError(t, err) + must.NotNil(t, sender) // Noop returns nil on SendPush - assert.NoError(t, sender.SendPush(ctx, "ios", "token", mobile.PushMessage{Title: "title", Body: "body"})) + test.NoError(t, sender.SendPush(ctx, "ios", "token", mobile.PushMessage{Title: "title", Body: "body"})) }) T.Run("with noop provider returns noop", func(t *testing.T) { @@ -140,9 +140,9 @@ func TestConfig_ProvidePushSender(T *testing.T) { cfg := Config{Provider: ProviderNoop} sender, err := cfg.ProvidePushSender(ctx, logger, tracer, nil) - require.NoError(t, err) - require.NotNil(t, sender) - assert.NoError(t, sender.SendPush(ctx, "android", "token", mobile.PushMessage{Title: "title", Body: "body"})) + must.NoError(t, err) + must.NotNil(t, sender) + test.NoError(t, sender.SendPush(ctx, "android", "token", mobile.PushMessage{Title: "title", Body: "body"})) }) T.Run("with apns_fcm provider and nil APNs returns noop or FCM-only sender", func(t *testing.T) { @@ -154,8 +154,8 @@ func TestConfig_ProvidePushSender(T *testing.T) { FCM: &FCMConfig{}, } sender, err := cfg.ProvidePushSender(ctx, logger, tracer, nil) - require.NoError(t, err) - require.NotNil(t, sender) + must.NoError(t, err) + must.NotNil(t, sender) // FCM init typically fails in unit tests (no ADC), so we get noop; if it succeeds, iOS would error _ = sender.SendPush(ctx, "ios", "token", mobile.PushMessage{Title: "title", Body: "body"}) }) @@ -170,12 +170,12 @@ func TestConfig_ProvidePushSender(T *testing.T) { FCM: nil, } sender, err := cfg.ProvidePushSender(ctx, logger, tracer, nil) - require.NoError(t, err) - require.NotNil(t, sender) + must.NoError(t, err) + must.NotNil(t, sender) // Android not configured, should return ErrPlatformNotSupported err = sender.SendPush(ctx, "android", "token", mobile.PushMessage{Title: "title", Body: "body"}) - assert.Error(t, err) - assert.ErrorIs(t, err, mobile.ErrPlatformNotSupported) + test.Error(t, err) + test.ErrorIs(t, err, mobile.ErrPlatformNotSupported) }) T.Run("with apns_fcm provider and invalid APNs path returns noop or FCM-only sender", func(t *testing.T) { @@ -187,8 +187,8 @@ func TestConfig_ProvidePushSender(T *testing.T) { FCM: &FCMConfig{}, } sender, err := cfg.ProvidePushSender(ctx, logger, tracer, nil) - require.NoError(t, err) - require.NotNil(t, sender) + must.NoError(t, err) + must.NotNil(t, sender) // APNs init fails; FCM init typically fails in unit tests, so we get noop _ = sender.SendPush(ctx, "ios", "token", mobile.PushMessage{Title: "title", Body: "body"}) }) @@ -203,12 +203,12 @@ func TestConfig_ProvidePushSender(T *testing.T) { FCM: &FCMConfig{CredentialsPath: filepath.Join(t.TempDir(), "nonexistent.json")}, } sender, err := cfg.ProvidePushSender(ctx, logger, tracer, nil) - require.NoError(t, err) - require.NotNil(t, sender) + must.NoError(t, err) + must.NotNil(t, sender) // FCM init fails, so Android not configured; should return ErrPlatformNotSupported err = sender.SendPush(ctx, "android", "token", mobile.PushMessage{Title: "title", Body: "body"}) - assert.Error(t, err) - assert.ErrorIs(t, err, mobile.ErrPlatformNotSupported) + test.Error(t, err) + test.ErrorIs(t, err, mobile.ErrPlatformNotSupported) }) T.Run("with unknown provider returns noop", func(t *testing.T) { @@ -216,9 +216,9 @@ func TestConfig_ProvidePushSender(T *testing.T) { cfg := Config{Provider: "unknown"} sender, err := cfg.ProvidePushSender(ctx, logger, tracer, nil) - require.NoError(t, err) - require.NotNil(t, sender) - assert.NoError(t, sender.SendPush(ctx, "ios", "token", mobile.PushMessage{Title: "title", Body: "body"})) + must.NoError(t, err) + must.NotNil(t, sender) + test.NoError(t, sender.SendPush(ctx, "ios", "token", mobile.PushMessage{Title: "title", Body: "body"})) }) T.Run("with apns_fcm provider and valid FCM creds returns multi-platform sender", func(t *testing.T) { @@ -231,8 +231,8 @@ func TestConfig_ProvidePushSender(T *testing.T) { FCM: &FCMConfig{CredentialsPath: credsPath}, } sender, err := cfg.ProvidePushSender(ctx, logger, tracer, nil) - require.NoError(t, err) - require.NotNil(t, sender) + must.NoError(t, err) + must.NotNil(t, sender) }) T.Run("with apns_fcm provider and both valid configs returns multi-platform sender", func(t *testing.T) { @@ -246,7 +246,7 @@ func TestConfig_ProvidePushSender(T *testing.T) { FCM: &FCMConfig{CredentialsPath: credsPath}, } sender, err := cfg.ProvidePushSender(ctx, logger, tracer, nil) - require.NoError(t, err) - require.NotNil(t, sender) + must.NoError(t, err) + must.NotNil(t, sender) }) } diff --git a/notifications/mobile/config/do_test.go b/notifications/mobile/config/do_test.go index 9f5bdc6..8d91daf 100644 --- a/notifications/mobile/config/do_test.go +++ b/notifications/mobile/config/do_test.go @@ -10,8 +10,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterPushSender(T *testing.T) { @@ -30,8 +30,8 @@ func TestRegisterPushSender(T *testing.T) { RegisterPushSender(i) sender, err := do.Invoke[mobile.PushNotificationSender](i) - require.NoError(t, err) - assert.NotNil(t, sender) + must.NoError(t, err) + test.NotNil(t, sender) }) } @@ -48,7 +48,7 @@ func TestProvidePushSender(T *testing.T) { tracing.NewNoopTracerProvider(), nil, ) - require.NoError(t, err) - assert.NotNil(t, sender) + must.NoError(t, err) + test.NotNil(t, sender) }) } diff --git a/notifications/mobile/fcm/fcm_sender_test.go b/notifications/mobile/fcm/fcm_sender_test.go index b7659aa..70eddc3 100644 --- a/notifications/mobile/fcm/fcm_sender_test.go +++ b/notifications/mobile/fcm/fcm_sender_test.go @@ -16,8 +16,7 @@ import ( firebase "firebase.google.com/go/v4" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" "google.golang.org/api/option" ) @@ -38,16 +37,16 @@ func createTestFCMSenderWithTransport(t *testing.T, fn roundTripFunc) *Sender { httpClient := &http.Client{Transport: fn} app, err := firebase.NewApp(ctx, &firebase.Config{ProjectID: "test-project"}, option.WithHTTPClient(httpClient)) - require.NoError(t, err) + must.NoError(t, err) client, err := app.Messaging(ctx) - require.NoError(t, err) + must.NoError(t, err) mp := metrics.EnsureMetricsProvider(nil) sendCounter, err := mp.NewInt64Counter(o11yName + "_sends") - require.NoError(t, err) + must.NoError(t, err) errorCounter, err := mp.NewInt64Counter(o11yName + "_errors") - require.NoError(t, err) + must.NoError(t, err) return &Sender{ client: client, @@ -69,9 +68,9 @@ func TestNewSender(T *testing.T) { t.Parallel() sender, err := NewSender(ctx, nil, tracingProvider, logger, nil) - assert.Nil(t, sender) - assert.Error(t, err) - assert.Contains(t, err.Error(), "config is required") + test.Nil(t, sender) + test.Error(t, err) + test.StrContains(t, err.Error(), "config is required") }) T.Run("with non-existent credentials path", func(t *testing.T) { @@ -81,9 +80,9 @@ func TestNewSender(T *testing.T) { CredentialsPath: filepath.Join(t.TempDir(), "nonexistent-firebase-credentials.json"), } sender, err := NewSender(ctx, cfg, tracingProvider, logger, nil) - assert.Nil(t, sender) - assert.Error(t, err) - assert.Contains(t, err.Error(), "credentials file not found") + test.Nil(t, sender) + test.Error(t, err) + test.StrContains(t, err.Error(), "credentials file not found") }) T.Run("with empty credentials path uses ADC", func(t *testing.T) { @@ -93,12 +92,12 @@ func TestNewSender(T *testing.T) { sender, err := NewSender(ctx, cfg, tracingProvider, logger, nil) // ADC typically fails without GCP credentials in test env if err != nil { - assert.Nil(t, sender) - assert.Error(t, err) - assert.Contains(t, err.Error(), "fcm:") + test.Nil(t, sender) + test.Error(t, err) + test.StrContains(t, err.Error(), "fcm:") return } - require.NotNil(t, sender) + must.NotNil(t, sender) }) T.Run("with invalid JSON credentials file", func(t *testing.T) { @@ -106,13 +105,13 @@ func TestNewSender(T *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "creds.json") - require.NoError(t, os.WriteFile(path, []byte("not valid json"), 0o600)) + must.NoError(t, os.WriteFile(path, []byte("not valid json"), 0o600)) cfg := &Config{CredentialsPath: path} sender, err := NewSender(ctx, cfg, tracingProvider, logger, nil) - assert.Nil(t, sender) - assert.Error(t, err) - assert.Contains(t, err.Error(), "fcm:") + test.Nil(t, sender) + test.Error(t, err) + test.StrContains(t, err.Error(), "fcm:") }) T.Run("with valid credentials file", func(t *testing.T) { @@ -120,12 +119,12 @@ func TestNewSender(T *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "creds.json") - require.NoError(t, os.WriteFile(path, []byte(fakeServiceAccountJSON), 0o600)) + must.NoError(t, os.WriteFile(path, []byte(fakeServiceAccountJSON), 0o600)) cfg := &Config{CredentialsPath: path} sender, err := NewSender(ctx, cfg, tracingProvider, logger, nil) - require.NoError(t, err) - require.NotNil(t, sender) + must.NoError(t, err) + must.NotNil(t, sender) }) T.Run("with send counter creation error", func(t *testing.T) { @@ -133,20 +132,20 @@ func TestNewSender(T *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "creds.json") - require.NoError(t, os.WriteFile(path, []byte(fakeServiceAccountJSON), 0o600)) + must.NoError(t, os.WriteFile(path, []byte(fakeServiceAccountJSON), 0o600)) mp := &mockmetrics.ProviderMock{ NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { - assert.Equal(t, o11yName+"_sends", counterName) + test.EqOp(t, o11yName+"_sends", counterName) return (*metrics.Int64CounterImpl)(nil), errors.New("counter error") }, } cfg := &Config{CredentialsPath: path} sender, err := NewSender(ctx, cfg, tracingProvider, logger, mp) - assert.Nil(t, sender) - require.Error(t, err) - assert.Contains(t, err.Error(), "creating send counter") + test.Nil(t, sender) + must.Error(t, err) + test.StrContains(t, err.Error(), "creating send counter") test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -156,7 +155,7 @@ func TestNewSender(T *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "creds.json") - require.NoError(t, os.WriteFile(path, []byte(fakeServiceAccountJSON), 0o600)) + must.NoError(t, os.WriteFile(path, []byte(fakeServiceAccountJSON), 0o600)) mp := &mockmetrics.ProviderMock{ NewInt64CounterFunc: func(counterName string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { @@ -173,9 +172,9 @@ func TestNewSender(T *testing.T) { cfg := &Config{CredentialsPath: path} sender, err := NewSender(ctx, cfg, tracingProvider, logger, mp) - assert.Nil(t, sender) - require.Error(t, err) - assert.Contains(t, err.Error(), "creating error counter") + test.Nil(t, sender) + must.Error(t, err) + test.StrContains(t, err.Error(), "creating error counter") test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) @@ -198,7 +197,7 @@ func TestSender_Send(T *testing.T) { }) err := sender.Send(ctx, "device-token-abc", "Test Title", "Test Body") - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("send returns error", func(t *testing.T) { @@ -213,6 +212,6 @@ func TestSender_Send(T *testing.T) { }) err := sender.Send(ctx, "device-token-abc", "Test Title", "Test Body") - require.Error(t, err) + must.Error(t, err) }) } diff --git a/notifications/mobile/multi_platform_push_sender_test.go b/notifications/mobile/multi_platform_push_sender_test.go index 954c16e..58789af 100644 --- a/notifications/mobile/multi_platform_push_sender_test.go +++ b/notifications/mobile/multi_platform_push_sender_test.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "crypto/x509" "encoding/pem" + "errors" "os" "path/filepath" "testing" @@ -15,8 +16,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) // fakeServiceAccountJSON is a syntactically-valid Firebase service-account JSON. @@ -26,10 +27,10 @@ func createTestP8File(t *testing.T) string { t.Helper() key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) + must.NoError(t, err) keyBytes, err := x509.MarshalPKCS8PrivateKey(key) - require.NoError(t, err) + must.NoError(t, err) block := &pem.Block{ Type: "PRIVATE KEY", @@ -38,7 +39,7 @@ func createTestP8File(t *testing.T) string { dir := t.TempDir() path := filepath.Join(dir, "AuthKey.p8") - require.NoError(t, os.WriteFile(path, pem.EncodeToMemory(block), 0o600)) + must.NoError(t, os.WriteFile(path, pem.EncodeToMemory(block), 0o600)) return path } @@ -47,7 +48,7 @@ func createTestFCMCredsFile(t *testing.T) string { dir := t.TempDir() path := filepath.Join(dir, "fcm-creds.json") - require.NoError(t, os.WriteFile(path, []byte(fakeServiceAccountJSON), 0o600)) + must.NoError(t, os.WriteFile(path, []byte(fakeServiceAccountJSON), 0o600)) return path } @@ -63,8 +64,8 @@ func TestMultiPlatformPushSender_SendPush(T *testing.T) { sender := NewMultiPlatformPushSender(nil, nil, logger, tracer) err := sender.SendPush(ctx, "ios", "token", PushMessage{Title: "title", Body: "body"}) - assert.Error(t, err) - assert.ErrorIs(t, err, ErrPlatformNotSupported) + test.Error(t, err) + test.ErrorIs(t, err, ErrPlatformNotSupported) }) T.Run("android returns ErrPlatformNotSupported when fcmSender nil", func(t *testing.T) { @@ -72,8 +73,8 @@ func TestMultiPlatformPushSender_SendPush(T *testing.T) { sender := NewMultiPlatformPushSender(nil, nil, logger, tracer) err := sender.SendPush(ctx, "android", "token", PushMessage{Title: "title", Body: "body"}) - assert.Error(t, err) - assert.ErrorIs(t, err, ErrPlatformNotSupported) + test.Error(t, err) + test.ErrorIs(t, err, ErrPlatformNotSupported) }) T.Run("unknown platform returns error", func(t *testing.T) { @@ -81,8 +82,8 @@ func TestMultiPlatformPushSender_SendPush(T *testing.T) { sender := NewMultiPlatformPushSender(nil, nil, logger, tracer) err := sender.SendPush(ctx, "unknown", "token", PushMessage{Title: "title", Body: "body"}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "unknown platform") + test.Error(t, err) + test.StrContains(t, err.Error(), "unknown platform") }) T.Run("ios delegates to apns sender", func(t *testing.T) { @@ -96,13 +97,13 @@ func TestMultiPlatformPushSender_SendPush(T *testing.T) { BundleID: "com.example.app", } apnsSender, err := apns.NewSender(apnsCfg, tracer, logger, nil) - require.NoError(t, err) + must.NoError(t, err) sender := NewMultiPlatformPushSender(apnsSender, nil, logger, tracer) err = sender.SendPush(ctx, "ios", "not-a-valid-token", PushMessage{Title: "title", Body: "body"}) // The apns sender will reject the token format, but the delegation code path is covered. - assert.Error(t, err) - assert.NotErrorIs(t, err, ErrPlatformNotSupported) + test.Error(t, err) + test.True(t, !errors.Is(err, ErrPlatformNotSupported)) }) T.Run("android delegates to fcm sender", func(t *testing.T) { @@ -111,12 +112,12 @@ func TestMultiPlatformPushSender_SendPush(T *testing.T) { credsPath := createTestFCMCredsFile(t) fcmCfg := &fcm.Config{CredentialsPath: credsPath} fcmSender, err := fcm.NewSender(ctx, fcmCfg, tracer, logger, nil) - require.NoError(t, err) + must.NoError(t, err) sender := NewMultiPlatformPushSender(nil, fcmSender, logger, tracer) err = sender.SendPush(ctx, "android", "device-token-abc", PushMessage{Title: "title", Body: "body"}) // The fcm sender will fail at the HTTP level, but the delegation code path is covered. - assert.Error(t, err) - assert.NotErrorIs(t, err, ErrPlatformNotSupported) + test.Error(t, err) + test.True(t, !errors.Is(err, ErrPlatformNotSupported)) }) } diff --git a/numbers/numbers_test.go b/numbers/numbers_test.go index 303ac04..e63c3a4 100644 --- a/numbers/numbers_test.go +++ b/numbers/numbers_test.go @@ -3,7 +3,7 @@ package numbers import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestRoundToDecimalPlaces(T *testing.T) { @@ -15,7 +15,7 @@ func TestRoundToDecimalPlaces(T *testing.T) { result := RoundToDecimalPlaces(3.14159, 2) expected := float32(3.14) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("round positive number to 0 decimals", func(t *testing.T) { @@ -24,7 +24,7 @@ func TestRoundToDecimalPlaces(T *testing.T) { result := RoundToDecimalPlaces(3.7, 0) expected := float32(4.0) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("round positive number to 4 decimals", func(t *testing.T) { @@ -33,7 +33,7 @@ func TestRoundToDecimalPlaces(T *testing.T) { result := RoundToDecimalPlaces(3.141592653, 4) expected := float32(3.1416) - assert.InDelta(t, expected, result, 0.0001) + test.InDelta(t, expected, result, float32(0.0001)) }) T.Run("round negative number to 2 decimals", func(t *testing.T) { @@ -42,7 +42,7 @@ func TestRoundToDecimalPlaces(T *testing.T) { result := RoundToDecimalPlaces(-3.14159, 2) expected := float32(-3.14) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("round negative number to 0 decimals", func(t *testing.T) { @@ -51,7 +51,7 @@ func TestRoundToDecimalPlaces(T *testing.T) { result := RoundToDecimalPlaces(-3.7, 0) expected := float32(-4.0) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("round zero value", func(t *testing.T) { @@ -60,7 +60,7 @@ func TestRoundToDecimalPlaces(T *testing.T) { result := RoundToDecimalPlaces(0.0, 2) expected := float32(0.0) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("round number that needs rounding up", func(t *testing.T) { @@ -69,7 +69,7 @@ func TestRoundToDecimalPlaces(T *testing.T) { result := RoundToDecimalPlaces(2.555, 2) expected := float32(2.56) - assert.InDelta(t, expected, result, 0.01) + test.InDelta(t, expected, result, float32(0.01)) }) T.Run("round number that needs rounding down", func(t *testing.T) { @@ -78,7 +78,7 @@ func TestRoundToDecimalPlaces(T *testing.T) { result := RoundToDecimalPlaces(2.554, 2) expected := float32(2.55) - assert.InDelta(t, expected, result, 0.01) + test.InDelta(t, expected, result, float32(0.01)) }) T.Run("round large number", func(t *testing.T) { @@ -87,7 +87,7 @@ func TestRoundToDecimalPlaces(T *testing.T) { result := RoundToDecimalPlaces(12345.6789, 2) expected := float32(12345.68) - assert.InDelta(t, expected, result, 0.01) + test.InDelta(t, expected, result, float32(0.01)) }) T.Run("round very small number", func(t *testing.T) { @@ -96,7 +96,7 @@ func TestRoundToDecimalPlaces(T *testing.T) { result := RoundToDecimalPlaces(0.001234, 3) expected := float32(0.001) - assert.InDelta(t, expected, result, 0.0001) + test.InDelta(t, expected, result, float32(0.0001)) }) T.Run("round number with high precision", func(t *testing.T) { @@ -105,7 +105,7 @@ func TestRoundToDecimalPlaces(T *testing.T) { result := RoundToDecimalPlaces(1.23456789, 5) expected := float32(1.23457) - assert.InDelta(t, expected, result, 0.00001) + test.InDelta(t, expected, result, float32(0.00001)) }) T.Run("round negative number that needs rounding up", func(t *testing.T) { @@ -114,7 +114,7 @@ func TestRoundToDecimalPlaces(T *testing.T) { result := RoundToDecimalPlaces(-2.555, 2) expected := float32(-2.56) - assert.InDelta(t, expected, result, 0.01) + test.InDelta(t, expected, result, float32(0.01)) }) } @@ -127,7 +127,7 @@ func TestScale(T *testing.T) { result := Scale(2.5, 2.0) expected := float32(5.0) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("halve a quantity", func(t *testing.T) { @@ -136,7 +136,7 @@ func TestScale(T *testing.T) { result := Scale(4.0, 0.5) expected := float32(2.0) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("scale with custom precision", func(t *testing.T) { @@ -145,7 +145,7 @@ func TestScale(T *testing.T) { result := Scale(3.333, 3.0, 3) expected := float32(9.999) - assert.InDelta(t, expected, result, 0.001) + test.InDelta(t, expected, result, float32(0.001)) }) T.Run("scale with zero precision", func(t *testing.T) { @@ -154,7 +154,7 @@ func TestScale(T *testing.T) { result := Scale(2.7, 2.0, 0) expected := float32(5.0) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("scale by factor of 1", func(t *testing.T) { @@ -163,7 +163,7 @@ func TestScale(T *testing.T) { result := Scale(5.5, 1.0) expected := float32(5.5) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("scale zero value", func(t *testing.T) { @@ -172,7 +172,7 @@ func TestScale(T *testing.T) { result := Scale(0.0, 5.0) expected := float32(0.0) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("scale by zero factor", func(t *testing.T) { @@ -181,7 +181,7 @@ func TestScale(T *testing.T) { result := Scale(10.0, 0.0) expected := float32(0.0) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("scale with fractional factor", func(t *testing.T) { @@ -190,7 +190,7 @@ func TestScale(T *testing.T) { result := Scale(2.5, 1.5) expected := float32(3.75) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("scale large value", func(t *testing.T) { @@ -199,7 +199,7 @@ func TestScale(T *testing.T) { result := Scale(1000.0, 2.5) expected := float32(2500.0) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("scale with high precision", func(t *testing.T) { @@ -208,7 +208,7 @@ func TestScale(T *testing.T) { result := Scale(1.23456, 2.0, 5) expected := float32(2.46912) - assert.InDelta(t, expected, result, 0.00001) + test.InDelta(t, expected, result, float32(0.00001)) }) T.Run("scale negative value", func(t *testing.T) { @@ -217,7 +217,7 @@ func TestScale(T *testing.T) { result := Scale(-5.0, 2.0) expected := float32(-10.0) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) } @@ -230,7 +230,7 @@ func TestScaleToYield(T *testing.T) { result := ScaleToYield(2.0, 4, 6) expected := float32(3.0) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("scale from 4 servings to 2 servings", func(t *testing.T) { @@ -239,7 +239,7 @@ func TestScaleToYield(T *testing.T) { result := ScaleToYield(4.0, 4, 2) expected := float32(2.0) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("scale from 2 servings to 8 servings", func(t *testing.T) { @@ -248,7 +248,7 @@ func TestScaleToYield(T *testing.T) { result := ScaleToYield(1.5, 2, 8) expected := float32(6.0) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("scale with same yield", func(t *testing.T) { @@ -257,7 +257,7 @@ func TestScaleToYield(T *testing.T) { result := ScaleToYield(3.5, 4, 4) expected := float32(3.5) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("scale with custom precision", func(t *testing.T) { @@ -266,7 +266,7 @@ func TestScaleToYield(T *testing.T) { result := ScaleToYield(1.0, 3, 7, 3) expected := float32(2.333) - assert.InDelta(t, expected, result, 0.001) + test.InDelta(t, expected, result, float32(0.001)) }) T.Run("scale with zero precision", func(t *testing.T) { @@ -275,7 +275,7 @@ func TestScaleToYield(T *testing.T) { result := ScaleToYield(2.7, 4, 6, 0) expected := float32(4.0) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("scale with zero original yield returns original value", func(t *testing.T) { @@ -284,7 +284,7 @@ func TestScaleToYield(T *testing.T) { result := ScaleToYield(5.0, 0, 10) expected := float32(5.0) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("scale with negative original yield returns original value", func(t *testing.T) { @@ -293,7 +293,7 @@ func TestScaleToYield(T *testing.T) { result := ScaleToYield(5.0, -2, 10) expected := float32(5.0) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("scale fractional value", func(t *testing.T) { @@ -302,7 +302,7 @@ func TestScaleToYield(T *testing.T) { result := ScaleToYield(0.5, 4, 8) expected := float32(1.0) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("scale from small to large batch", func(t *testing.T) { @@ -311,7 +311,7 @@ func TestScaleToYield(T *testing.T) { result := ScaleToYield(0.25, 2, 12) expected := float32(1.5) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("scale complex ratio", func(t *testing.T) { @@ -320,7 +320,7 @@ func TestScaleToYield(T *testing.T) { result := ScaleToYield(2.5, 6, 9) expected := float32(3.75) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) T.Run("scale with high precision for exact conversions", func(t *testing.T) { @@ -329,7 +329,7 @@ func TestScaleToYield(T *testing.T) { result := ScaleToYield(1.0, 3, 5, 4) expected := float32(1.6667) - assert.InDelta(t, expected, result, 0.0001) + test.InDelta(t, expected, result, float32(0.0001)) }) T.Run("scale zero value", func(t *testing.T) { @@ -338,6 +338,6 @@ func TestScaleToYield(T *testing.T) { result := ScaleToYield(0.0, 4, 8) expected := float32(0.0) - assert.Equal(t, expected, result) + test.EqOp(t, expected, result) }) } diff --git a/observability/config_test.go b/observability/config_test.go index 67a6cfd..4d4bbc9 100644 --- a/observability/config_test.go +++ b/observability/config_test.go @@ -6,8 +6,8 @@ import ( tracingcfg "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing/config" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing/oteltrace" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -28,7 +28,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { }, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) } @@ -42,11 +42,11 @@ func TestConfig_ProvidePillars(T *testing.T) { cfg := &Config{} pillars, err := cfg.ProvidePillars(ctx) - require.NoError(t, err) - require.NotNil(t, pillars) - assert.NotNil(t, pillars.Logger) - assert.NotNil(t, pillars.TracerProvider) - assert.NotNil(t, pillars.MetricsProvider) - assert.NotNil(t, pillars.Profiler) + must.NoError(t, err) + must.NotNil(t, pillars) + test.NotNil(t, pillars.Logger) + test.NotNil(t, pillars.TracerProvider) + test.NotNil(t, pillars.MetricsProvider) + test.NotNil(t, pillars.Profiler) }) } diff --git a/observability/do_test.go b/observability/do_test.go index f4fd01f..54b78f3 100644 --- a/observability/do_test.go +++ b/observability/do_test.go @@ -9,8 +9,8 @@ import ( tracingcfg "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing/config" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterO11yConfigs(T *testing.T) { @@ -27,19 +27,19 @@ func TestRegisterO11yConfigs(T *testing.T) { RegisterO11yConfigs(i) loggingConfig, err := do.Invoke[*loggingcfg.Config](i) - require.NoError(t, err) - assert.NotNil(t, loggingConfig) + must.NoError(t, err) + test.NotNil(t, loggingConfig) metricsConfig, err := do.Invoke[*metricscfg.Config](i) - require.NoError(t, err) - assert.NotNil(t, metricsConfig) + must.NoError(t, err) + test.NotNil(t, metricsConfig) tracingConfig, err := do.Invoke[*tracingcfg.Config](i) - require.NoError(t, err) - assert.NotNil(t, tracingConfig) + must.NoError(t, err) + test.NotNil(t, tracingConfig) profilingConfig, err := do.Invoke[*profilingcfg.Config](i) - require.NoError(t, err) - assert.NotNil(t, profilingConfig) + must.NoError(t, err) + test.NotNil(t, profilingConfig) }) } diff --git a/observability/errors_test.go b/observability/errors_test.go index db422db..8132cb7 100644 --- a/observability/errors_test.go +++ b/observability/errors_test.go @@ -7,7 +7,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" "google.golang.org/grpc/codes" ) @@ -22,7 +22,7 @@ func TestPrepareAndLogError(T *testing.T) { logger := logging.NewNoopLogger() _, span := tracing.StartSpan(ctx) - assert.Error(t, PrepareAndLogError(err, logger, span, "things and %s", "stuff")) + test.Error(t, PrepareAndLogError(err, logger, span, "things and %s", "stuff")) }) T.Run("with nil error", func(t *testing.T) { @@ -32,7 +32,7 @@ func TestPrepareAndLogError(T *testing.T) { logger := logging.NewNoopLogger() _, span := tracing.StartSpan(ctx) - assert.NoError(t, PrepareAndLogError(nil, logger, span, "things and %s", "stuff")) + test.NoError(t, PrepareAndLogError(nil, logger, span, "things and %s", "stuff")) }) T.Run("with nil span", func(t *testing.T) { @@ -41,7 +41,7 @@ func TestPrepareAndLogError(T *testing.T) { err := errors.New("blah") logger := logging.NewNoopLogger() - assert.Error(t, PrepareAndLogError(err, logger, nil, "things and %s", "stuff")) + test.Error(t, PrepareAndLogError(err, logger, nil, "things and %s", "stuff")) }) T.Run("with nil logger", func(t *testing.T) { @@ -51,7 +51,7 @@ func TestPrepareAndLogError(T *testing.T) { err := errors.New("blah") _, span := tracing.StartSpan(ctx) - assert.Error(t, PrepareAndLogError(err, nil, span, "things and %s", "stuff")) + test.Error(t, PrepareAndLogError(err, nil, span, "things and %s", "stuff")) }) T.Run("with empty description", func(t *testing.T) { @@ -62,7 +62,7 @@ func TestPrepareAndLogError(T *testing.T) { logger := logging.NewNoopLogger() _, span := tracing.StartSpan(ctx) - assert.Error(t, PrepareAndLogError(err, logger, span, "")) + test.Error(t, PrepareAndLogError(err, logger, span, "")) }) } @@ -76,7 +76,7 @@ func TestPrepareError(T *testing.T) { err := errors.New("blah") _, span := tracing.StartSpan(ctx) - assert.Error(t, PrepareError(err, span, "things and %s", "stuff")) + test.Error(t, PrepareError(err, span, "things and %s", "stuff")) }) T.Run("with nil error", func(t *testing.T) { @@ -85,7 +85,7 @@ func TestPrepareError(T *testing.T) { ctx := t.Context() _, span := tracing.StartSpan(ctx) - assert.NoError(t, PrepareError(nil, span, "things and %s", "stuff")) + test.NoError(t, PrepareError(nil, span, "things and %s", "stuff")) }) T.Run("with nil span", func(t *testing.T) { @@ -93,7 +93,7 @@ func TestPrepareError(T *testing.T) { err := errors.New("blah") - assert.Error(t, PrepareError(err, nil, "things and %s", "stuff")) + test.Error(t, PrepareError(err, nil, "things and %s", "stuff")) }) T.Run("with empty description", func(t *testing.T) { @@ -104,8 +104,8 @@ func TestPrepareError(T *testing.T) { _, span := tracing.StartSpan(ctx) actual := PrepareError(err, span, "") - assert.Error(t, actual) - assert.Equal(t, err, actual) + test.Error(t, actual) + test.Eq(t, err, actual) }) } @@ -163,7 +163,7 @@ func TestPrepareAndLogGRPCStatus(T *testing.T) { logger := logging.NewNoopLogger() _, span := tracing.StartSpan(ctx) - assert.Error(t, PrepareAndLogGRPCStatus(err, logger, span, codes.Internal, "things and %s", "stuff")) + test.Error(t, PrepareAndLogGRPCStatus(err, logger, span, codes.Internal, "things and %s", "stuff")) }) T.Run("with nil error", func(t *testing.T) { @@ -173,7 +173,7 @@ func TestPrepareAndLogGRPCStatus(T *testing.T) { logger := logging.NewNoopLogger() _, span := tracing.StartSpan(ctx) - assert.NoError(t, PrepareAndLogGRPCStatus(nil, logger, span, codes.Internal, "things and %s", "stuff")) + test.NoError(t, PrepareAndLogGRPCStatus(nil, logger, span, codes.Internal, "things and %s", "stuff")) }) T.Run("with nil span", func(t *testing.T) { @@ -182,7 +182,7 @@ func TestPrepareAndLogGRPCStatus(T *testing.T) { err := errors.New("blah") logger := logging.NewNoopLogger() - assert.Error(t, PrepareAndLogGRPCStatus(err, logger, nil, codes.Internal, "things and %s", "stuff")) + test.Error(t, PrepareAndLogGRPCStatus(err, logger, nil, codes.Internal, "things and %s", "stuff")) }) T.Run("with nil logger", func(t *testing.T) { @@ -192,7 +192,7 @@ func TestPrepareAndLogGRPCStatus(T *testing.T) { err := errors.New("blah") _, span := tracing.StartSpan(ctx) - assert.Error(t, PrepareAndLogGRPCStatus(err, nil, span, codes.Internal, "things and %s", "stuff")) + test.Error(t, PrepareAndLogGRPCStatus(err, nil, span, codes.Internal, "things and %s", "stuff")) }) T.Run("with empty description", func(t *testing.T) { @@ -203,6 +203,6 @@ func TestPrepareAndLogGRPCStatus(T *testing.T) { logger := logging.NewNoopLogger() _, span := tracing.StartSpan(ctx) - assert.Error(t, PrepareAndLogGRPCStatus(err, logger, span, codes.Internal, "")) + test.Error(t, PrepareAndLogGRPCStatus(err, logger, span, codes.Internal, "")) }) } diff --git a/observability/helpers_test.go b/observability/helpers_test.go index bb91130..6e41747 100644 --- a/observability/helpers_test.go +++ b/observability/helpers_test.go @@ -6,7 +6,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestObserveValues(T *testing.T) { @@ -20,7 +20,7 @@ func TestObserveValues(T *testing.T) { _, span := tracing.StartSpan(ctx) result := ObserveValues(map[string]any{"key": "value", "other": 123}, span, logger) - assert.NotNil(t, result) + test.NotNil(t, result) }) T.Run("with nil span", func(t *testing.T) { @@ -29,7 +29,7 @@ func TestObserveValues(T *testing.T) { logger := logging.NewNoopLogger() result := ObserveValues(map[string]any{"key": "value"}, nil, logger) - assert.NotNil(t, result) + test.NotNil(t, result) }) T.Run("with nil logger", func(t *testing.T) { @@ -39,14 +39,14 @@ func TestObserveValues(T *testing.T) { _, span := tracing.StartSpan(ctx) result := ObserveValues(map[string]any{"key": "value"}, span, nil) - assert.Nil(t, result) + test.Nil(t, result) }) T.Run("with nil span and nil logger", func(t *testing.T) { t.Parallel() result := ObserveValues(map[string]any{"key": "value"}, nil, nil) - assert.Nil(t, result) + test.Nil(t, result) }) T.Run("with empty values", func(t *testing.T) { @@ -57,6 +57,6 @@ func TestObserveValues(T *testing.T) { _, span := tracing.StartSpan(ctx) result := ObserveValues(map[string]any{}, span, logger) - assert.NotNil(t, result) + test.NotNil(t, result) }) } diff --git a/observability/logging/config/config_test.go b/observability/logging/config/config_test.go index 2c756d1..d50b241 100644 --- a/observability/logging/config/config_test.go +++ b/observability/logging/config/config_test.go @@ -6,8 +6,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging/otelgrpc" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) // NOTE: ValidateWithContext calls validation.ValidateStructWithContext(ctx, &cfg, ...), @@ -28,7 +28,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderZerolog, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } @@ -44,8 +44,8 @@ func TestConfig_ProvideLogger(T *testing.T) { } l, err := cfg.ProvideLogger(ctx) - assert.NoError(t, err) - assert.NotNil(t, l) + test.NoError(t, err) + test.NotNil(t, l) }) T.Run("zap provider", func(t *testing.T) { @@ -57,8 +57,8 @@ func TestConfig_ProvideLogger(T *testing.T) { } l, err := cfg.ProvideLogger(ctx) - assert.NoError(t, err) - assert.NotNil(t, l) + test.NoError(t, err) + test.NotNil(t, l) }) T.Run("slog provider", func(t *testing.T) { @@ -70,8 +70,8 @@ func TestConfig_ProvideLogger(T *testing.T) { } l, err := cfg.ProvideLogger(ctx) - assert.NoError(t, err) - assert.NotNil(t, l) + test.NoError(t, err) + test.NotNil(t, l) }) T.Run("otelslog provider", func(t *testing.T) { @@ -85,8 +85,8 @@ func TestConfig_ProvideLogger(T *testing.T) { } l, err := cfg.ProvideLogger(ctx) - assert.NoError(t, err) - assert.NotNil(t, l) + test.NoError(t, err) + test.NotNil(t, l) }) T.Run("otelslog provider with nil otelslog config returns error", func(t *testing.T) { @@ -99,8 +99,8 @@ func TestConfig_ProvideLogger(T *testing.T) { } l, err := cfg.ProvideLogger(ctx) - assert.Error(t, err) - assert.Nil(t, l) + test.Error(t, err) + test.Nil(t, l) }) T.Run("no provider falls back to noop", func(t *testing.T) { @@ -110,8 +110,8 @@ func TestConfig_ProvideLogger(T *testing.T) { cfg := &Config{} l, err := cfg.ProvideLogger(ctx) - assert.NoError(t, err) - assert.NotNil(t, l) + test.NoError(t, err) + test.NotNil(t, l) }) } @@ -127,7 +127,7 @@ func TestProvideLogger(T *testing.T) { } l, err := ProvideLogger(ctx, cfg) - require.NoError(t, err) - assert.NotNil(t, l) + must.NoError(t, err) + test.NotNil(t, l) }) } diff --git a/observability/logging/config/do_test.go b/observability/logging/config/do_test.go index 5ad60e1..e0d2d16 100644 --- a/observability/logging/config/do_test.go +++ b/observability/logging/config/do_test.go @@ -6,8 +6,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterLogger(T *testing.T) { @@ -25,7 +25,7 @@ func TestRegisterLogger(T *testing.T) { RegisterLogger(i) l, err := do.Invoke[logging.Logger](i) - require.NoError(t, err) - assert.NotNil(t, l) + must.NoError(t, err) + test.NotNil(t, l) }) } diff --git a/observability/logging/logging_test.go b/observability/logging/logging_test.go index 29d0dec..5275bb6 100644 --- a/observability/logging/logging_test.go +++ b/observability/logging/logging_test.go @@ -5,8 +5,8 @@ import ( "net/http" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace/noop" ) @@ -18,11 +18,11 @@ func TestAllLevels(T *testing.T) { t.Parallel() levels := AllLevels() - assert.NotEmpty(t, levels) - assert.Contains(t, levels, InfoLevel) - assert.Contains(t, levels, DebugLevel) - assert.Contains(t, levels, ErrorLevel) - assert.Contains(t, levels, WarnLevel) + test.SliceNotEmpty(t, levels) + test.SliceContains(t, levels, InfoLevel) + test.SliceContains(t, levels, DebugLevel) + test.SliceContains(t, levels, ErrorLevel) + test.SliceContains(t, levels, WarnLevel) }) } @@ -32,13 +32,13 @@ func TestEnsureLogger(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotNil(t, EnsureLogger(NewNoopLogger())) + test.NotNil(t, EnsureLogger(NewNoopLogger())) }) T.Run("with nil", func(t *testing.T) { t.Parallel() - assert.NotNil(t, EnsureLogger(nil)) + test.NotNil(t, EnsureLogger(nil)) }) } @@ -48,13 +48,13 @@ func TestNewNamedLogger(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewNamedLogger(NewNoopLogger(), "test")) + test.NotNil(t, NewNamedLogger(NewNoopLogger(), "test")) }) T.Run("with nil logger", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewNamedLogger(nil, "test")) + test.NotNil(t, NewNamedLogger(nil, "test")) }) } @@ -65,7 +65,7 @@ func TestNoopLogger(T *testing.T) { t.Parallel() l := NewNoopLogger() - assert.NotNil(t, l) + test.NotNil(t, l) }) T.Run("Info", func(t *testing.T) { @@ -95,46 +95,46 @@ func TestNoopLogger(T *testing.T) { T.Run("WithName", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewNoopLogger().WithName("test")) + test.NotNil(t, NewNoopLogger().WithName("test")) }) T.Run("Clone", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewNoopLogger().Clone()) + test.NotNil(t, NewNoopLogger().Clone()) }) T.Run("WithValues", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewNoopLogger().WithValues(map[string]any{"key": "value"})) + test.NotNil(t, NewNoopLogger().WithValues(map[string]any{"key": "value"})) }) T.Run("WithValue", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewNoopLogger().WithValue("key", "value")) + test.NotNil(t, NewNoopLogger().WithValue("key", "value")) }) T.Run("WithRequest", func(t *testing.T) { t.Parallel() req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "http://example.com", http.NoBody) - require.NoError(t, err) + must.NoError(t, err) - assert.NotNil(t, NewNoopLogger().WithRequest(req)) + test.NotNil(t, NewNoopLogger().WithRequest(req)) }) T.Run("WithResponse", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewNoopLogger().WithResponse(&http.Response{})) + test.NotNil(t, NewNoopLogger().WithResponse(&http.Response{})) }) T.Run("WithError", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewNoopLogger().WithError(errors.New("blah"))) + test.NotNil(t, NewNoopLogger().WithError(errors.New("blah"))) }) T.Run("WithSpan", func(t *testing.T) { @@ -145,7 +145,7 @@ func TestNoopLogger(T *testing.T) { _, s := noop.NewTracerProvider().Tracer("test").Start(ctx, "test") _ = span - assert.NotNil(t, NewNoopLogger().WithSpan(s)) + test.NotNil(t, NewNoopLogger().WithSpan(s)) }) } @@ -159,8 +159,8 @@ func TestExtractSpanInfo(T *testing.T) { _, span := noop.NewTracerProvider().Tracer("test").Start(ctx, "test") info := ExtractSpanInfo(span) - assert.NotEmpty(t, info.SpanID) - assert.NotEmpty(t, info.TraceID) + test.NotEq(t, "", info.SpanID) + test.NotEq(t, "", info.TraceID) }) } @@ -171,23 +171,23 @@ func TestExtractRequestInfo(T *testing.T) { t.Parallel() req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "http://example.com/path?foo=bar", http.NoBody) - require.NoError(t, err) + must.NoError(t, err) info := ExtractRequestInfo(req, func(r *http.Request) string { return "req-123" }) - assert.Equal(t, http.MethodGet, info.Method) - assert.Equal(t, "/path", info.Path) - assert.Equal(t, "foo=bar", info.Query) - assert.Equal(t, "req-123", info.RequestID) + test.EqOp(t, http.MethodGet, info.Method) + test.EqOp(t, "/path", info.Path) + test.EqOp(t, "foo=bar", info.Query) + test.EqOp(t, "req-123", info.RequestID) }) T.Run("with nil request", func(t *testing.T) { t.Parallel() info := ExtractRequestInfo(nil, nil) - assert.Empty(t, info.Method) - assert.Empty(t, info.Path) - assert.Empty(t, info.Query) - assert.Empty(t, info.RequestID) + test.EqOp(t, "", info.Method) + test.EqOp(t, "", info.Path) + test.EqOp(t, "", info.Query) + test.EqOp(t, "", info.RequestID) }) T.Run("with nil URL", func(t *testing.T) { @@ -196,20 +196,20 @@ func TestExtractRequestInfo(T *testing.T) { req := &http.Request{Method: http.MethodPost} info := ExtractRequestInfo(req, nil) - assert.Equal(t, http.MethodPost, info.Method) - assert.Empty(t, info.Path) - assert.Empty(t, info.Query) + test.EqOp(t, http.MethodPost, info.Method) + test.EqOp(t, "", info.Path) + test.EqOp(t, "", info.Query) }) T.Run("with nil requestIDFunc", func(t *testing.T) { t.Parallel() req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "http://example.com/path", http.NoBody) - require.NoError(t, err) + must.NoError(t, err) info := ExtractRequestInfo(req, nil) - assert.Equal(t, http.MethodGet, info.Method) - assert.Empty(t, info.RequestID) + test.EqOp(t, http.MethodGet, info.Method) + test.EqOp(t, "", info.RequestID) }) } diff --git a/observability/logging/otelgrpc/slog_logger_test.go b/observability/logging/otelgrpc/slog_logger_test.go index a8032b2..e5b9d1e 100644 --- a/observability/logging/otelgrpc/slog_logger_test.go +++ b/observability/logging/otelgrpc/slog_logger_test.go @@ -8,8 +8,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/trace" ) @@ -21,8 +21,8 @@ func TestNewLogger(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), &Config{}) - assert.NotNil(t, l) - assert.NoError(t, err) + test.NotNil(t, l) + test.NoError(t, err) }) T.Run("with nil config", func(t *testing.T) { @@ -30,8 +30,8 @@ func TestNewLogger(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), nil) - assert.Nil(t, l) - assert.Error(t, err) + test.Nil(t, l) + test.Error(t, err) }) T.Run("with info level", func(t *testing.T) { @@ -39,8 +39,8 @@ func TestNewLogger(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.InfoLevel, t.Name(), &Config{}) - assert.NotNil(t, l) - assert.NoError(t, err) + test.NotNil(t, l) + test.NoError(t, err) }) T.Run("with warn level", func(t *testing.T) { @@ -48,8 +48,8 @@ func TestNewLogger(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.WarnLevel, t.Name(), &Config{}) - assert.NotNil(t, l) - assert.NoError(t, err) + test.NotNil(t, l) + test.NoError(t, err) }) T.Run("with error level", func(t *testing.T) { @@ -57,8 +57,8 @@ func TestNewLogger(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.ErrorLevel, t.Name(), &Config{}) - assert.NotNil(t, l) - assert.NoError(t, err) + test.NotNil(t, l) + test.NoError(t, err) }) T.Run("with collector endpoint", func(t *testing.T) { @@ -71,8 +71,8 @@ func TestNewLogger(T *testing.T) { } l, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), cfg) - assert.NotNil(t, l) - assert.NoError(t, err) + test.NotNil(t, l) + test.NoError(t, err) }) } @@ -84,9 +84,9 @@ func Test_zerologLogger_WithName(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), &Config{}) - require.NoError(t, err) + must.NoError(t, err) - assert.NotNil(t, l.WithName(t.Name())) + test.NotNil(t, l.WithName(t.Name())) }) } @@ -98,7 +98,7 @@ func Test_zerologLogger_SetRequestIDFunc(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), &Config{}) - require.NoError(t, err) + must.NoError(t, err) l.SetRequestIDFunc(func(*http.Request) string { return "" @@ -110,7 +110,7 @@ func Test_zerologLogger_SetRequestIDFunc(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), &Config{}) - require.NoError(t, err) + must.NoError(t, err) l.SetRequestIDFunc(nil) }) @@ -124,7 +124,7 @@ func Test_zerologLogger_Info(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), &Config{}) - require.NoError(t, err) + must.NoError(t, err) l.Info(t.Name()) }) @@ -138,7 +138,7 @@ func Test_zerologLogger_Debug(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), &Config{}) - require.NoError(t, err) + must.NoError(t, err) l.Debug(t.Name()) }) @@ -152,7 +152,7 @@ func Test_zerologLogger_Error(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), &Config{}) - require.NoError(t, err) + must.NoError(t, err) l.Error(t.Name(), errors.New("blah")) }) @@ -162,7 +162,7 @@ func Test_zerologLogger_Error(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), &Config{}) - require.NoError(t, err) + must.NoError(t, err) l.Error(t.Name(), nil) }) @@ -176,9 +176,9 @@ func Test_zerologLogger_Clone(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), &Config{}) - require.NoError(t, err) + must.NoError(t, err) - assert.NotNil(t, l.Clone()) + test.NotNil(t, l.Clone()) }) } @@ -190,9 +190,9 @@ func Test_zerologLogger_WithValue(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), &Config{}) - require.NoError(t, err) + must.NoError(t, err) - assert.NotNil(t, l.WithValue("name", t.Name())) + test.NotNil(t, l.WithValue("name", t.Name())) }) } @@ -204,9 +204,9 @@ func Test_zerologLogger_WithValues(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), &Config{}) - require.NoError(t, err) + must.NoError(t, err) - assert.NotNil(t, l.WithValues(map[string]any{"name": t.Name()})) + test.NotNil(t, l.WithValues(map[string]any{"name": t.Name()})) }) } @@ -218,9 +218,9 @@ func Test_zerologLogger_WithError(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), &Config{}) - require.NoError(t, err) + must.NoError(t, err) - assert.NotNil(t, l.WithError(errors.New("blah"))) + test.NotNil(t, l.WithError(errors.New("blah"))) }) } @@ -232,11 +232,11 @@ func Test_zerologLogger_WithSpan(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), &Config{}) - require.NoError(t, err) + must.NoError(t, err) span := trace.SpanFromContext(ctx) - assert.NotNil(t, l.WithSpan(span)) + test.NotNil(t, l.WithSpan(span)) }) } @@ -248,19 +248,19 @@ func Test_zerologLogger_WithRequest(T *testing.T) { ctx := t.Context() logger, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), &Config{}) - require.NoError(t, err) + must.NoError(t, err) l, ok := logger.(*otelSlogLogger) - require.True(t, ok) + must.True(t, ok) l.requestIDFunc = func(*http.Request) string { return t.Name() } u, err := url.ParseRequestURI("https://whatever.whocares.gov/path?things=stuff") - require.NoError(t, err) + must.NoError(t, err) - assert.NotNil(t, l.WithRequest(&http.Request{ + test.NotNil(t, l.WithRequest(&http.Request{ URL: u, })) }) @@ -270,9 +270,9 @@ func Test_zerologLogger_WithRequest(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), &Config{}) - require.NoError(t, err) + must.NoError(t, err) - assert.NotNil(t, l.WithRequest(nil)) + test.NotNil(t, l.WithRequest(nil)) }) } @@ -284,9 +284,9 @@ func Test_zerologLogger_WithResponse(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), &Config{}) - require.NoError(t, err) + must.NoError(t, err) - assert.NotNil(t, l.WithResponse(&http.Response{})) + test.NotNil(t, l.WithResponse(&http.Response{})) }) T.Run("with nil response", func(t *testing.T) { @@ -294,9 +294,9 @@ func Test_zerologLogger_WithResponse(T *testing.T) { ctx := t.Context() l, err := NewOtelSlogLogger(ctx, logging.DebugLevel, t.Name(), &Config{}) - require.NoError(t, err) + must.NoError(t, err) - assert.NotNil(t, l.WithResponse(nil)) + test.NotNil(t, l.WithResponse(nil)) }) } @@ -313,6 +313,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { // NOTE: ValidateWithContext uses &c (double pointer) which causes // ozzo-validation to reject it. This exercises the code path regardless. - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } diff --git a/observability/logging/slog/slog_logger_test.go b/observability/logging/slog/slog_logger_test.go index 7ec3420..981315d 100644 --- a/observability/logging/slog/slog_logger_test.go +++ b/observability/logging/slog/slog_logger_test.go @@ -8,8 +8,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/trace" ) @@ -19,25 +19,25 @@ func TestNewLogger(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewSlogLogger(logging.DebugLevel)) + test.NotNil(t, NewSlogLogger(logging.DebugLevel)) }) T.Run("with info level", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewSlogLogger(logging.InfoLevel)) + test.NotNil(t, NewSlogLogger(logging.InfoLevel)) }) T.Run("with warn level", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewSlogLogger(logging.WarnLevel)) + test.NotNil(t, NewSlogLogger(logging.WarnLevel)) }) T.Run("with error level", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewSlogLogger(logging.ErrorLevel)) + test.NotNil(t, NewSlogLogger(logging.ErrorLevel)) }) } @@ -49,7 +49,7 @@ func Test_zerologLogger_WithName(T *testing.T) { l := NewSlogLogger(logging.DebugLevel) - assert.NotNil(t, l.WithName(t.Name())) + test.NotNil(t, l.WithName(t.Name())) }) } @@ -127,7 +127,7 @@ func Test_zerologLogger_Clone(T *testing.T) { l := NewSlogLogger(logging.DebugLevel) - assert.NotNil(t, l.Clone()) + test.NotNil(t, l.Clone()) }) } @@ -139,7 +139,7 @@ func Test_zerologLogger_WithValue(T *testing.T) { l := NewSlogLogger(logging.DebugLevel) - assert.NotNil(t, l.WithValue("name", t.Name())) + test.NotNil(t, l.WithValue("name", t.Name())) }) } @@ -151,7 +151,7 @@ func Test_zerologLogger_WithValues(T *testing.T) { l := NewSlogLogger(logging.DebugLevel) - assert.NotNil(t, l.WithValues(map[string]any{"name": t.Name()})) + test.NotNil(t, l.WithValues(map[string]any{"name": t.Name()})) }) } @@ -163,7 +163,7 @@ func Test_zerologLogger_WithError(T *testing.T) { l := NewSlogLogger(logging.DebugLevel) - assert.NotNil(t, l.WithError(errors.New("blah"))) + test.NotNil(t, l.WithError(errors.New("blah"))) }) } @@ -178,7 +178,7 @@ func Test_zerologLogger_WithSpan(T *testing.T) { span := trace.SpanFromContext(ctx) - assert.NotNil(t, l.WithSpan(span)) + test.NotNil(t, l.WithSpan(span)) }) } @@ -189,16 +189,16 @@ func Test_zerologLogger_WithRequest(T *testing.T) { t.Parallel() l, ok := NewSlogLogger(logging.DebugLevel).(*slogLogger) - require.True(t, ok) + must.True(t, ok) l.requestIDFunc = func(*http.Request) string { return t.Name() } u, err := url.ParseRequestURI("https://whatever.whocares.gov/path?things=stuff") - require.NoError(t, err) + must.NoError(t, err) - assert.NotNil(t, l.WithRequest(&http.Request{ + test.NotNil(t, l.WithRequest(&http.Request{ URL: u, })) }) @@ -208,7 +208,7 @@ func Test_zerologLogger_WithRequest(T *testing.T) { l := NewSlogLogger(logging.DebugLevel) - assert.NotNil(t, l.WithRequest(nil)) + test.NotNil(t, l.WithRequest(nil)) }) } @@ -220,7 +220,7 @@ func Test_zerologLogger_WithResponse(T *testing.T) { l := NewSlogLogger(logging.DebugLevel) - assert.NotNil(t, l.WithResponse(&http.Response{})) + test.NotNil(t, l.WithResponse(&http.Response{})) }) T.Run("with nil response", func(t *testing.T) { @@ -228,6 +228,6 @@ func Test_zerologLogger_WithResponse(T *testing.T) { l := NewSlogLogger(logging.DebugLevel) - assert.NotNil(t, l.WithResponse(nil)) + test.NotNil(t, l.WithResponse(nil)) }) } diff --git a/observability/logging/zap/zap_logger_test.go b/observability/logging/zap/zap_logger_test.go index 76150f3..03f90ae 100644 --- a/observability/logging/zap/zap_logger_test.go +++ b/observability/logging/zap/zap_logger_test.go @@ -8,8 +8,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/trace" ) @@ -19,25 +19,25 @@ func TestNewLogger(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewZapLogger(logging.DebugLevel)) + test.NotNil(t, NewZapLogger(logging.DebugLevel)) }) T.Run("with info level", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewZapLogger(logging.InfoLevel)) + test.NotNil(t, NewZapLogger(logging.InfoLevel)) }) T.Run("with warn level", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewZapLogger(logging.WarnLevel)) + test.NotNil(t, NewZapLogger(logging.WarnLevel)) }) T.Run("with error level", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewZapLogger(logging.ErrorLevel)) + test.NotNil(t, NewZapLogger(logging.ErrorLevel)) }) } @@ -49,7 +49,7 @@ func Test_zapLogger_WithName(T *testing.T) { l := NewZapLogger(logging.DebugLevel) - assert.NotNil(t, l.WithName(t.Name())) + test.NotNil(t, l.WithName(t.Name())) }) } @@ -60,7 +60,7 @@ func Test_zapLogger_SetLevel(T *testing.T) { t.Parallel() l, ok := NewZapLogger(logging.DebugLevel).(*zapLogger) - require.True(t, ok) + must.True(t, ok) l.SetLevel(logging.InfoLevel) }) @@ -69,7 +69,7 @@ func Test_zapLogger_SetLevel(T *testing.T) { t.Parallel() l, ok := NewZapLogger(logging.DebugLevel).(*zapLogger) - require.True(t, ok) + must.True(t, ok) l.SetLevel(logging.DebugLevel) }) @@ -78,7 +78,7 @@ func Test_zapLogger_SetLevel(T *testing.T) { t.Parallel() l, ok := NewZapLogger(logging.DebugLevel).(*zapLogger) - require.True(t, ok) + must.True(t, ok) l.SetLevel(logging.WarnLevel) }) @@ -87,7 +87,7 @@ func Test_zapLogger_SetLevel(T *testing.T) { t.Parallel() l, ok := NewZapLogger(logging.DebugLevel).(*zapLogger) - require.True(t, ok) + must.True(t, ok) l.SetLevel(logging.ErrorLevel) }) @@ -96,7 +96,7 @@ func Test_zapLogger_SetLevel(T *testing.T) { t.Parallel() l, ok := NewZapLogger(logging.DebugLevel).(*zapLogger) - require.True(t, ok) + must.True(t, ok) l.SetLevel(nil) }) @@ -176,7 +176,7 @@ func Test_zapLogger_Clone(T *testing.T) { l := NewZapLogger(logging.DebugLevel) - assert.NotNil(t, l.Clone()) + test.NotNil(t, l.Clone()) }) } @@ -188,7 +188,7 @@ func Test_zapLogger_WithValue(T *testing.T) { l := NewZapLogger(logging.DebugLevel) - assert.NotNil(t, l.WithValue("name", t.Name())) + test.NotNil(t, l.WithValue("name", t.Name())) }) } @@ -200,7 +200,7 @@ func Test_zapLogger_WithValues(T *testing.T) { l := NewZapLogger(logging.DebugLevel) - assert.NotNil(t, l.WithValues(map[string]any{"name": t.Name()})) + test.NotNil(t, l.WithValues(map[string]any{"name": t.Name()})) }) } @@ -212,7 +212,7 @@ func Test_zapLogger_WithError(T *testing.T) { l := NewZapLogger(logging.DebugLevel) - assert.NotNil(t, l.WithError(errors.New("blah"))) + test.NotNil(t, l.WithError(errors.New("blah"))) }) } @@ -227,7 +227,7 @@ func Test_zapLogger_WithSpan(T *testing.T) { span := trace.SpanFromContext(ctx) - assert.NotNil(t, l.WithSpan(span)) + test.NotNil(t, l.WithSpan(span)) }) } @@ -238,16 +238,16 @@ func Test_zapLogger_WithRequest(T *testing.T) { t.Parallel() l, ok := NewZapLogger(logging.DebugLevel).(*zapLogger) - require.True(t, ok) + must.True(t, ok) l.requestIDFunc = func(*http.Request) string { return t.Name() } u, err := url.ParseRequestURI("https://whatever.whocares.gov/path?things=stuff") - require.NoError(t, err) + must.NoError(t, err) - assert.NotNil(t, l.WithRequest(&http.Request{ + test.NotNil(t, l.WithRequest(&http.Request{ URL: u, })) }) @@ -257,7 +257,7 @@ func Test_zapLogger_WithRequest(T *testing.T) { l := NewZapLogger(logging.DebugLevel) - assert.NotNil(t, l.WithRequest(nil)) + test.NotNil(t, l.WithRequest(nil)) }) } @@ -269,7 +269,7 @@ func Test_zapLogger_WithResponse(T *testing.T) { l := NewZapLogger(logging.DebugLevel) - assert.NotNil(t, l.WithResponse(&http.Response{})) + test.NotNil(t, l.WithResponse(&http.Response{})) }) T.Run("with nil response", func(t *testing.T) { @@ -277,6 +277,6 @@ func Test_zapLogger_WithResponse(T *testing.T) { l := NewZapLogger(logging.DebugLevel) - assert.NotNil(t, l.WithResponse(nil)) + test.NotNil(t, l.WithResponse(nil)) }) } diff --git a/observability/logging/zerolog/zerolog_logger_test.go b/observability/logging/zerolog/zerolog_logger_test.go index 039c3a9..d40e5ee 100644 --- a/observability/logging/zerolog/zerolog_logger_test.go +++ b/observability/logging/zerolog/zerolog_logger_test.go @@ -8,8 +8,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/trace" ) @@ -19,31 +19,31 @@ func Test_buildZerologger(T *testing.T) { T.Run("with debug level", func(t *testing.T) { t.Parallel() - assert.NotNil(t, buildZerologger(logging.DebugLevel)) + test.NotNil(t, buildZerologger(logging.DebugLevel)) }) T.Run("with info level", func(t *testing.T) { t.Parallel() - assert.NotNil(t, buildZerologger(logging.InfoLevel)) + test.NotNil(t, buildZerologger(logging.InfoLevel)) }) T.Run("with warn level", func(t *testing.T) { t.Parallel() - assert.NotNil(t, buildZerologger(logging.WarnLevel)) + test.NotNil(t, buildZerologger(logging.WarnLevel)) }) T.Run("with error level", func(t *testing.T) { t.Parallel() - assert.NotNil(t, buildZerologger(logging.ErrorLevel)) + test.NotNil(t, buildZerologger(logging.ErrorLevel)) }) T.Run("with nil level", func(t *testing.T) { t.Parallel() - assert.NotNil(t, buildZerologger(nil)) + test.NotNil(t, buildZerologger(nil)) }) } @@ -53,7 +53,7 @@ func TestNewLogger(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewZerologLogger(logging.DebugLevel)) + test.NotNil(t, NewZerologLogger(logging.DebugLevel)) }) } @@ -65,7 +65,7 @@ func Test_zerologLogger_WithName(T *testing.T) { l := NewZerologLogger(logging.DebugLevel) - assert.NotNil(t, l.WithName(t.Name())) + test.NotNil(t, l.WithName(t.Name())) }) } @@ -143,7 +143,7 @@ func Test_zerologLogger_Clone(T *testing.T) { l := NewZerologLogger(logging.DebugLevel) - assert.NotNil(t, l.Clone()) + test.NotNil(t, l.Clone()) }) } @@ -155,7 +155,7 @@ func Test_zerologLogger_WithValue(T *testing.T) { l := NewZerologLogger(logging.DebugLevel) - assert.NotNil(t, l.WithValue("name", t.Name())) + test.NotNil(t, l.WithValue("name", t.Name())) }) } @@ -167,7 +167,7 @@ func Test_zerologLogger_WithValues(T *testing.T) { l := NewZerologLogger(logging.DebugLevel) - assert.NotNil(t, l.WithValues(map[string]any{"name": t.Name()})) + test.NotNil(t, l.WithValues(map[string]any{"name": t.Name()})) }) } @@ -179,7 +179,7 @@ func Test_zerologLogger_WithError(T *testing.T) { l := NewZerologLogger(logging.DebugLevel) - assert.NotNil(t, l.WithError(errors.New("blah"))) + test.NotNil(t, l.WithError(errors.New("blah"))) }) } @@ -194,7 +194,7 @@ func Test_zerologLogger_WithSpan(T *testing.T) { span := trace.SpanFromContext(ctx) - assert.NotNil(t, l.WithSpan(span)) + test.NotNil(t, l.WithSpan(span)) }) } @@ -205,16 +205,16 @@ func Test_zerologLogger_WithRequest(T *testing.T) { t.Parallel() l, ok := NewZerologLogger(logging.DebugLevel).(*zerologLogger) - require.True(t, ok) + must.True(t, ok) l.requestIDFunc = func(*http.Request) string { return t.Name() } u, err := url.ParseRequestURI("https://whatever.whocares.gov/path?things=stuff") - require.NoError(t, err) + must.NoError(t, err) - assert.NotNil(t, l.WithRequest(&http.Request{ + test.NotNil(t, l.WithRequest(&http.Request{ URL: u, })) }) @@ -224,7 +224,7 @@ func Test_zerologLogger_WithRequest(T *testing.T) { l := NewZerologLogger(logging.DebugLevel) - assert.NotNil(t, l.WithRequest(nil)) + test.NotNil(t, l.WithRequest(nil)) }) } @@ -236,7 +236,7 @@ func Test_zerologLogger_WithResponse(T *testing.T) { l := NewZerologLogger(logging.DebugLevel) - assert.NotNil(t, l.WithResponse(&http.Response{})) + test.NotNil(t, l.WithResponse(&http.Response{})) }) T.Run("with nil response", func(t *testing.T) { @@ -244,6 +244,6 @@ func Test_zerologLogger_WithResponse(T *testing.T) { l := NewZerologLogger(logging.DebugLevel) - assert.NotNil(t, l.WithResponse(nil)) + test.NotNil(t, l.WithResponse(nil)) }) } diff --git a/observability/metrics/config/config_test.go b/observability/metrics/config/config_test.go index 6340449..df2fe0a 100644 --- a/observability/metrics/config/config_test.go +++ b/observability/metrics/config/config_test.go @@ -7,7 +7,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/otelgrpc" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_ProvideMetricsProvider(T *testing.T) { @@ -19,8 +19,8 @@ func TestConfig_ProvideMetricsProvider(T *testing.T) { cfg := &Config{} metricsProvider, err := cfg.ProvideMetricsProvider(t.Context(), logging.NewNoopLogger()) - assert.NoError(t, err) - assert.NotNil(t, metricsProvider) + test.NoError(t, err) + test.NotNil(t, metricsProvider) }) T.Run("enabled with otel provider", func(t *testing.T) { @@ -39,8 +39,8 @@ func TestConfig_ProvideMetricsProvider(T *testing.T) { metricsProvider, err := cfg.ProvideMetricsProvider(t.Context(), logging.NewNoopLogger()) - assert.NoError(t, err) - assert.NotNil(t, metricsProvider) + test.NoError(t, err) + test.NotNil(t, metricsProvider) }) T.Run("enabled with unknown provider falls back to noop", func(t *testing.T) { @@ -53,8 +53,8 @@ func TestConfig_ProvideMetricsProvider(T *testing.T) { metricsProvider, err := cfg.ProvideMetricsProvider(t.Context(), logging.NewNoopLogger()) - assert.NoError(t, err) - assert.NotNil(t, metricsProvider) + test.NoError(t, err) + test.NotNil(t, metricsProvider) }) T.Run("not enabled returns noop", func(t *testing.T) { @@ -66,8 +66,8 @@ func TestConfig_ProvideMetricsProvider(T *testing.T) { metricsProvider, err := cfg.ProvideMetricsProvider(t.Context(), logging.NewNoopLogger()) - assert.NoError(t, err) - assert.NotNil(t, metricsProvider) + test.NoError(t, err) + test.NotNil(t, metricsProvider) }) } @@ -86,7 +86,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { }, } - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("disabled is valid", func(t *testing.T) { @@ -96,7 +96,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Enabled: false, } - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("enabled with invalid provider", func(t *testing.T) { @@ -107,7 +107,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: "bogus", } - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("enabled with otel provider but nil otel config", func(t *testing.T) { @@ -119,7 +119,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Otel: nil, } - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) } @@ -132,7 +132,7 @@ func TestProvideMetricsProvider(T *testing.T) { cfg := &Config{} metricsProvider, err := ProvideMetricsProvider(t.Context(), logging.NewNoopLogger(), cfg) - assert.NoError(t, err) - assert.NotNil(t, metricsProvider) + test.NoError(t, err) + test.NotNil(t, metricsProvider) }) } diff --git a/observability/metrics/config/do_test.go b/observability/metrics/config/do_test.go index 400a31a..9c6f138 100644 --- a/observability/metrics/config/do_test.go +++ b/observability/metrics/config/do_test.go @@ -8,8 +8,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterMetricsProvider(T *testing.T) { @@ -26,7 +26,7 @@ func TestRegisterMetricsProvider(T *testing.T) { RegisterMetricsProvider(i) mp, err := do.Invoke[metrics.Provider](i) - require.NoError(t, err) - assert.NotNil(t, mp) + must.NoError(t, err) + test.NotNil(t, mp) }) } diff --git a/observability/metrics/metrics_test.go b/observability/metrics/metrics_test.go index 32a9e37..2c2875e 100644 --- a/observability/metrics/metrics_test.go +++ b/observability/metrics/metrics_test.go @@ -4,8 +4,8 @@ import ( "context" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel" ) @@ -17,14 +17,14 @@ func TestEnsureMetricsProvider(T *testing.T) { p := NewNoopMetricsProvider() actual := EnsureMetricsProvider(p) - assert.Equal(t, p, actual) + test.Eq(t, p, actual) }) T.Run("returns noop provider when nil", func(t *testing.T) { t.Parallel() actual := EnsureMetricsProvider(nil) - assert.NotNil(t, actual) + test.NotNil(t, actual) }) } @@ -35,8 +35,8 @@ func TestNoopProvider(T *testing.T) { t.Parallel() p := NewNoopMetricsProvider() c, err := p.NewFloat64Counter("test_counter") - require.NoError(t, err) - assert.NotNil(t, c) + must.NoError(t, err) + test.NotNil(t, c) c.Add(context.Background(), 1.0) }) @@ -44,8 +44,8 @@ func TestNoopProvider(T *testing.T) { t.Parallel() p := NewNoopMetricsProvider() g, err := p.NewFloat64Gauge("test_gauge") - require.NoError(t, err) - assert.NotNil(t, g) + must.NoError(t, err) + test.NotNil(t, g) g.Record(context.Background(), 1.0) }) @@ -53,8 +53,8 @@ func TestNoopProvider(T *testing.T) { t.Parallel() p := NewNoopMetricsProvider() c, err := p.NewFloat64UpDownCounter("test_updown") - require.NoError(t, err) - assert.NotNil(t, c) + must.NoError(t, err) + test.NotNil(t, c) c.Add(context.Background(), -1.0) }) @@ -62,8 +62,8 @@ func TestNoopProvider(T *testing.T) { t.Parallel() p := NewNoopMetricsProvider() h, err := p.NewFloat64Histogram("test_histogram") - require.NoError(t, err) - assert.NotNil(t, h) + must.NoError(t, err) + test.NotNil(t, h) h.Record(context.Background(), 1.0) }) @@ -71,8 +71,8 @@ func TestNoopProvider(T *testing.T) { t.Parallel() p := NewNoopMetricsProvider() c, err := p.NewInt64Counter("test_counter") - require.NoError(t, err) - assert.NotNil(t, c) + must.NoError(t, err) + test.NotNil(t, c) c.Add(context.Background(), 1) }) @@ -80,8 +80,8 @@ func TestNoopProvider(T *testing.T) { t.Parallel() p := NewNoopMetricsProvider() g, err := p.NewInt64Gauge("test_gauge") - require.NoError(t, err) - assert.NotNil(t, g) + must.NoError(t, err) + test.NotNil(t, g) g.Record(context.Background(), 1) }) @@ -89,8 +89,8 @@ func TestNoopProvider(T *testing.T) { t.Parallel() p := NewNoopMetricsProvider() c, err := p.NewInt64UpDownCounter("test_updown") - require.NoError(t, err) - assert.NotNil(t, c) + must.NoError(t, err) + test.NotNil(t, c) c.Add(context.Background(), -1) }) @@ -98,8 +98,8 @@ func TestNoopProvider(T *testing.T) { t.Parallel() p := NewNoopMetricsProvider() h, err := p.NewInt64Histogram("test_histogram") - require.NoError(t, err) - assert.NotNil(t, h) + must.NoError(t, err) + test.NotNil(t, h) h.Record(context.Background(), 1) }) @@ -107,14 +107,14 @@ func TestNoopProvider(T *testing.T) { t.Parallel() p := NewNoopMetricsProvider() err := p.Shutdown(context.Background()) - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("MeterProvider", func(t *testing.T) { t.Parallel() p := NewNoopMetricsProvider() mp := p.MeterProvider() - assert.NotNil(t, mp) + test.NotNil(t, mp) }) } @@ -124,7 +124,7 @@ func TestInt64CounterForTest(T *testing.T) { T.Run("returns a counter", func(t *testing.T) { t.Parallel() c := Int64CounterForTest(t, "test_counter") - assert.NotNil(t, c) + test.NotNil(t, c) }) } @@ -137,7 +137,7 @@ func TestImplWrappers(T *testing.T) { T.Run("Float64CounterImpl", func(t *testing.T) { t.Parallel() x, err := meter.Float64Counter("test_f64_counter") - require.NoError(t, err) + must.NoError(t, err) impl := &Float64CounterImpl{X: x} impl.Add(ctx, 1.0) }) @@ -145,7 +145,7 @@ func TestImplWrappers(T *testing.T) { T.Run("Float64GaugeImpl", func(t *testing.T) { t.Parallel() x, err := meter.Float64Gauge("test_f64_gauge") - require.NoError(t, err) + must.NoError(t, err) impl := &Float64GaugeImpl{X: x} impl.Record(ctx, 1.0) }) @@ -153,7 +153,7 @@ func TestImplWrappers(T *testing.T) { T.Run("Float64UpDownCounterImpl", func(t *testing.T) { t.Parallel() x, err := meter.Float64UpDownCounter("test_f64_updown") - require.NoError(t, err) + must.NoError(t, err) impl := &Float64UpDownCounterImpl{X: x} impl.Add(ctx, -1.0) }) @@ -161,7 +161,7 @@ func TestImplWrappers(T *testing.T) { T.Run("Float64HistogramImpl", func(t *testing.T) { t.Parallel() x, err := meter.Float64Histogram("test_f64_histogram") - require.NoError(t, err) + must.NoError(t, err) impl := &Float64HistogramImpl{X: x} impl.Record(ctx, 1.0) }) @@ -169,7 +169,7 @@ func TestImplWrappers(T *testing.T) { T.Run("Int64CounterImpl", func(t *testing.T) { t.Parallel() x, err := meter.Int64Counter("test_i64_counter") - require.NoError(t, err) + must.NoError(t, err) impl := &Int64CounterImpl{X: x} impl.Add(ctx, 1) }) @@ -177,7 +177,7 @@ func TestImplWrappers(T *testing.T) { T.Run("Int64GaugeImpl", func(t *testing.T) { t.Parallel() x, err := meter.Int64Gauge("test_i64_gauge") - require.NoError(t, err) + must.NoError(t, err) impl := &Int64GaugeImpl{X: x} impl.Record(ctx, 1) }) @@ -185,7 +185,7 @@ func TestImplWrappers(T *testing.T) { T.Run("Int64UpDownCounterImpl", func(t *testing.T) { t.Parallel() x, err := meter.Int64UpDownCounter("test_i64_updown") - require.NoError(t, err) + must.NoError(t, err) impl := &Int64UpDownCounterImpl{X: x} impl.Add(ctx, -1) }) @@ -193,7 +193,7 @@ func TestImplWrappers(T *testing.T) { T.Run("Int64HistogramImpl", func(t *testing.T) { t.Parallel() x, err := meter.Int64Histogram("test_i64_histogram") - require.NoError(t, err) + must.NoError(t, err) impl := &Int64HistogramImpl{X: x} impl.Record(ctx, 1) }) diff --git a/observability/metrics/otelgrpc/config/config_test.go b/observability/metrics/otelgrpc/config/config_test.go index bb3742e..0ea64d8 100644 --- a/observability/metrics/otelgrpc/config/config_test.go +++ b/observability/metrics/otelgrpc/config/config_test.go @@ -7,8 +7,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/otelgrpc" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -23,7 +23,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { ServiceName: t.Name(), } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("missing collector endpoint", func(t *testing.T) { @@ -34,7 +34,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { ServiceName: t.Name(), } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("missing service name", func(t *testing.T) { @@ -45,7 +45,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { CollectorEndpoint: t.Name(), } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } @@ -66,8 +66,8 @@ func TestProvideMetricsProvider(T *testing.T) { } provider, err := ProvideMetricsProvider(t.Context(), logging.NewNoopLogger(), cfg) - require.NoError(t, err) - assert.NotNil(t, provider) + must.NoError(t, err) + test.NotNil(t, provider) }) T.Run("with nil otel config", func(t *testing.T) { @@ -79,7 +79,7 @@ func TestProvideMetricsProvider(T *testing.T) { } provider, err := ProvideMetricsProvider(t.Context(), logging.NewNoopLogger(), cfg) - assert.Nil(t, provider) - assert.Error(t, err) + test.Nil(t, provider) + test.Error(t, err) }) } diff --git a/observability/metrics/otelgrpc/config/do_test.go b/observability/metrics/otelgrpc/config/do_test.go index dfb8e69..0403659 100644 --- a/observability/metrics/otelgrpc/config/do_test.go +++ b/observability/metrics/otelgrpc/config/do_test.go @@ -9,8 +9,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/otelgrpc" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterMetricsProvider(T *testing.T) { @@ -35,7 +35,7 @@ func TestRegisterMetricsProvider(T *testing.T) { RegisterMetricsProvider(i) provider, err := do.Invoke[metrics.Provider](i) - require.NoError(t, err) - assert.NotNil(t, provider) + must.NoError(t, err) + test.NotNil(t, provider) }) } diff --git a/observability/metrics/otelgrpc/metrics_test.go b/observability/metrics/otelgrpc/metrics_test.go index 964d0d9..30fe96f 100644 --- a/observability/metrics/otelgrpc/metrics_test.go +++ b/observability/metrics/otelgrpc/metrics_test.go @@ -7,8 +7,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -23,7 +23,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { } err := cfg.ValidateWithContext(t.Context()) - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("missing collector endpoint", func(t *testing.T) { @@ -34,8 +34,8 @@ func TestConfig_ValidateWithContext(T *testing.T) { } err := cfg.ValidateWithContext(t.Context()) - assert.Error(t, err) - assert.Contains(t, err.Error(), "metricsCollectorEndpoint") + test.Error(t, err) + test.StrContains(t, err.Error(), "metricsCollectorEndpoint") }) T.Run("missing collection interval", func(t *testing.T) { @@ -46,8 +46,8 @@ func TestConfig_ValidateWithContext(T *testing.T) { } err := cfg.ValidateWithContext(t.Context()) - assert.Error(t, err) - assert.Contains(t, err.Error(), "collectionInterval") + test.Error(t, err) + test.StrContains(t, err.Error(), "collectionInterval") }) T.Run("empty collector endpoint", func(t *testing.T) { @@ -59,8 +59,8 @@ func TestConfig_ValidateWithContext(T *testing.T) { } err := cfg.ValidateWithContext(t.Context()) - assert.Error(t, err) - assert.Contains(t, err.Error(), "metricsCollectorEndpoint") + test.Error(t, err) + test.StrContains(t, err.Error(), "metricsCollectorEndpoint") }) } @@ -74,10 +74,10 @@ func TestSetupMetricsProvider(T *testing.T) { logger := logging.NewNoopLogger() provider, shutdown, err := setupMetricsProvider(ctx, logger, "test-service", nil) - assert.Nil(t, provider) - assert.Nil(t, shutdown) - assert.Error(t, err) - assert.Equal(t, ErrNilConfig, err) + test.Nil(t, provider) + test.Nil(t, shutdown) + test.Error(t, err) + test.ErrorIs(t, err, ErrNilConfig) }) T.Run("valid config", func(t *testing.T) { @@ -94,9 +94,9 @@ func TestSetupMetricsProvider(T *testing.T) { } provider, shutdown, err := setupMetricsProvider(ctx, logger, "test-service", cfg) - assert.NoError(t, err) - assert.NotNil(t, provider) - assert.NotNil(t, shutdown) + test.NoError(t, err) + test.NotNil(t, provider) + test.NotNil(t, shutdown) }) T.Run("with runtime metrics enabled", func(t *testing.T) { @@ -113,9 +113,9 @@ func TestSetupMetricsProvider(T *testing.T) { } provider, shutdown, err := setupMetricsProvider(ctx, logger, "test-service", cfg) - assert.NoError(t, err) - assert.NotNil(t, provider) - assert.NotNil(t, shutdown) + test.NoError(t, err) + test.NotNil(t, provider) + test.NotNil(t, shutdown) }) T.Run("with host metrics enabled", func(t *testing.T) { @@ -132,9 +132,9 @@ func TestSetupMetricsProvider(T *testing.T) { } provider, shutdown, err := setupMetricsProvider(ctx, logger, "test-service", cfg) - assert.NoError(t, err) - assert.NotNil(t, provider) - assert.NotNil(t, shutdown) + test.NoError(t, err) + test.NotNil(t, provider) + test.NotNil(t, shutdown) }) } @@ -148,9 +148,9 @@ func TestProvideMetricsProvider(T *testing.T) { logger := logging.NewNoopLogger() provider, err := ProvideMetricsProvider(ctx, logger, "test-service", nil) - assert.Nil(t, provider) - assert.Error(t, err) - assert.Equal(t, ErrNilConfig, err) + test.Nil(t, provider) + test.Error(t, err) + test.ErrorIs(t, err, ErrNilConfig) }) T.Run("valid config", func(t *testing.T) { @@ -167,9 +167,10 @@ func TestProvideMetricsProvider(T *testing.T) { } provider, err := ProvideMetricsProvider(ctx, logger, "test-service", cfg) - assert.NoError(t, err) - assert.NotNil(t, provider) - assert.Implements(t, (*metrics.Provider)(nil), provider) + test.NoError(t, err) + test.NotNil(t, provider) + _, ok := any(provider).(metrics.Provider) + test.True(t, ok) }) } @@ -190,10 +191,10 @@ func TestProviderImpl_MeterProvider(T *testing.T) { } provider, err := ProvideMetricsProvider(ctx, logger, "test-service", cfg) - require.NoError(t, err) + must.NoError(t, err) meterProvider := provider.MeterProvider() - assert.NotNil(t, meterProvider) + test.NotNil(t, meterProvider) }) } @@ -226,12 +227,13 @@ func TestProviderImpl_NewFloat64Counter(T *testing.T) { } provider, err := ProvideMetricsProvider(ctx, logger, "test-service", cfg) - require.NoError(t, err) + must.NoError(t, err) counter, err := provider.NewFloat64Counter("test_counter") - assert.NoError(t, err) - assert.NotNil(t, counter) - assert.Implements(t, (*metrics.Float64Counter)(nil), counter) + test.NoError(t, err) + test.NotNil(t, counter) + _, ok := any(counter).(metrics.Float64Counter) + test.True(t, ok) }) } @@ -252,12 +254,13 @@ func TestProviderImpl_NewFloat64Gauge(T *testing.T) { } provider, err := ProvideMetricsProvider(ctx, logger, "test-service", cfg) - require.NoError(t, err) + must.NoError(t, err) gauge, err := provider.NewFloat64Gauge("test_gauge") - assert.NoError(t, err) - assert.NotNil(t, gauge) - assert.Implements(t, (*metrics.Float64Gauge)(nil), gauge) + test.NoError(t, err) + test.NotNil(t, gauge) + _, ok := any(gauge).(metrics.Float64Gauge) + test.True(t, ok) }) } @@ -278,12 +281,13 @@ func TestProviderImpl_NewFloat64UpDownCounter(T *testing.T) { } provider, err := ProvideMetricsProvider(ctx, logger, "test-service", cfg) - require.NoError(t, err) + must.NoError(t, err) counter, err := provider.NewFloat64UpDownCounter("test_updown_counter") - assert.NoError(t, err) - assert.NotNil(t, counter) - assert.Implements(t, (*metrics.Float64UpDownCounter)(nil), counter) + test.NoError(t, err) + test.NotNil(t, counter) + _, ok := any(counter).(metrics.Float64UpDownCounter) + test.True(t, ok) }) } @@ -304,12 +308,13 @@ func TestProviderImpl_NewFloat64Histogram(T *testing.T) { } provider, err := ProvideMetricsProvider(ctx, logger, "test-service", cfg) - require.NoError(t, err) + must.NoError(t, err) histogram, err := provider.NewFloat64Histogram("test_histogram") - assert.NoError(t, err) - assert.NotNil(t, histogram) - assert.Implements(t, (*metrics.Float64Histogram)(nil), histogram) + test.NoError(t, err) + test.NotNil(t, histogram) + _, ok := any(histogram).(metrics.Float64Histogram) + test.True(t, ok) }) } @@ -330,12 +335,13 @@ func TestProviderImpl_NewInt64Counter(T *testing.T) { } provider, err := ProvideMetricsProvider(ctx, logger, "test-service", cfg) - require.NoError(t, err) + must.NoError(t, err) counter, err := provider.NewInt64Counter("test_counter") - assert.NoError(t, err) - assert.NotNil(t, counter) - assert.Implements(t, (*metrics.Int64Counter)(nil), counter) + test.NoError(t, err) + test.NotNil(t, counter) + _, ok := any(counter).(metrics.Int64Counter) + test.True(t, ok) }) } @@ -356,12 +362,13 @@ func TestProviderImpl_NewInt64Gauge(T *testing.T) { } provider, err := ProvideMetricsProvider(ctx, logger, "test-service", cfg) - require.NoError(t, err) + must.NoError(t, err) gauge, err := provider.NewInt64Gauge("test_gauge") - assert.NoError(t, err) - assert.NotNil(t, gauge) - assert.Implements(t, (*metrics.Int64Gauge)(nil), gauge) + test.NoError(t, err) + test.NotNil(t, gauge) + _, ok := any(gauge).(metrics.Int64Gauge) + test.True(t, ok) }) } @@ -382,12 +389,13 @@ func TestProviderImpl_NewInt64UpDownCounter(T *testing.T) { } provider, err := ProvideMetricsProvider(ctx, logger, "test-service", cfg) - require.NoError(t, err) + must.NoError(t, err) counter, err := provider.NewInt64UpDownCounter("test_updown_counter") - assert.NoError(t, err) - assert.NotNil(t, counter) - assert.Implements(t, (*metrics.Int64UpDownCounter)(nil), counter) + test.NoError(t, err) + test.NotNil(t, counter) + _, ok := any(counter).(metrics.Int64UpDownCounter) + test.True(t, ok) }) } @@ -408,12 +416,13 @@ func TestProviderImpl_NewInt64Histogram(T *testing.T) { } provider, err := ProvideMetricsProvider(ctx, logger, "test-service", cfg) - require.NoError(t, err) + must.NoError(t, err) histogram, err := provider.NewInt64Histogram("test_histogram") - assert.NoError(t, err) - assert.NotNil(t, histogram) - assert.Implements(t, (*metrics.Int64Histogram)(nil), histogram) + test.NoError(t, err) + test.NotNil(t, histogram) + _, ok := any(histogram).(metrics.Int64Histogram) + test.True(t, ok) }) } @@ -434,12 +443,12 @@ func TestProviderImpl_ServiceNamePrefixing(T *testing.T) { } provider, err := ProvideMetricsProvider(ctx, logger, "my-service", cfg) - require.NoError(t, err) + must.NoError(t, err) // Test that metrics are created with service name prefix counter, err := provider.NewInt64Counter("test_metric") - assert.NoError(t, err) - assert.NotNil(t, counter) + test.NoError(t, err) + test.NotNil(t, counter) // The actual metric name should be "my-service.test_metric" but we can't easily test that // without accessing internal OpenTelemetry state, so we just verify the metric was created diff --git a/observability/metrics/testing.go b/observability/metrics/testing.go index 5b61200..6825106 100644 --- a/observability/metrics/testing.go +++ b/observability/metrics/testing.go @@ -3,7 +3,7 @@ package metrics import ( "testing" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/metric" ) @@ -12,7 +12,7 @@ func Int64CounterForTest(t *testing.T, name string) metric.Int64Counter { t.Helper() x, err := otel.Meter("testing").Int64Counter(name) - require.NoError(t, err) + must.NoError(t, err) return x } diff --git a/observability/profiling/config/config_test.go b/observability/profiling/config/config_test.go index 6189441..a340811 100644 --- a/observability/profiling/config/config_test.go +++ b/observability/profiling/config/config_test.go @@ -8,8 +8,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/profiling/pprof" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/profiling/pyroscope" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -18,7 +18,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { T.Run("valid empty provider", func(t *testing.T) { t.Parallel() c := &Config{Provider: ""} - assert.NoError(t, c.ValidateWithContext(t.Context())) + test.NoError(t, c.ValidateWithContext(t.Context())) }) T.Run("valid pprof provider", func(t *testing.T) { @@ -27,7 +27,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderPprof, Pprof: &pprof.Config{Port: 6060}, } - assert.NoError(t, c.ValidateWithContext(t.Context())) + test.NoError(t, c.ValidateWithContext(t.Context())) }) T.Run("valid pyroscope provider", func(t *testing.T) { @@ -39,19 +39,19 @@ func TestConfig_ValidateWithContext(T *testing.T) { UploadRate: 1, }, } - assert.NoError(t, c.ValidateWithContext(t.Context())) + test.NoError(t, c.ValidateWithContext(t.Context())) }) T.Run("invalid provider string", func(t *testing.T) { t.Parallel() c := &Config{Provider: "invalid"} - assert.Error(t, c.ValidateWithContext(t.Context())) + test.Error(t, c.ValidateWithContext(t.Context())) }) T.Run("pyroscope provider without config", func(t *testing.T) { t.Parallel() c := &Config{Provider: ProviderPyroscope} - assert.Error(t, c.ValidateWithContext(t.Context())) + test.Error(t, c.ValidateWithContext(t.Context())) }) T.Run("pprof config present with empty provider is invalid", func(t *testing.T) { @@ -60,7 +60,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: "", Pprof: &pprof.Config{Port: 6060}, } - assert.Error(t, c.ValidateWithContext(t.Context())) + test.Error(t, c.ValidateWithContext(t.Context())) }) T.Run("pyroscope config present with pprof provider is invalid", func(t *testing.T) { @@ -72,7 +72,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { UploadRate: 1, }, } - assert.Error(t, c.ValidateWithContext(t.Context())) + test.Error(t, c.ValidateWithContext(t.Context())) }) } @@ -85,25 +85,25 @@ func TestConfig_ProvideProfilingProvider(T *testing.T) { t.Parallel() c := &Config{Provider: ""} p, err := c.ProvideProfilingProvider(t.Context(), logger) - require.NoError(t, err) - assert.NotNil(t, p) + must.NoError(t, err) + test.NotNil(t, p) }) T.Run("unknown provider returns noop", func(t *testing.T) { t.Parallel() c := &Config{Provider: "unknown"} p, err := c.ProvideProfilingProvider(t.Context(), logger) - require.NoError(t, err) - assert.NotNil(t, p) + must.NoError(t, err) + test.NotNil(t, p) }) T.Run("pprof with nil config uses defaults", func(t *testing.T) { t.Parallel() c := &Config{Provider: ProviderPprof} p, err := c.ProvideProfilingProvider(t.Context(), logger) - require.NoError(t, err) - assert.NotNil(t, p) - require.NoError(t, p.Shutdown(t.Context())) + must.NoError(t, err) + test.NotNil(t, p) + must.NoError(t, p.Shutdown(t.Context())) }) T.Run("pprof with config", func(t *testing.T) { @@ -113,17 +113,17 @@ func TestConfig_ProvideProfilingProvider(T *testing.T) { Pprof: &pprof.Config{Port: 16060}, } p, err := c.ProvideProfilingProvider(t.Context(), logger) - require.NoError(t, err) - assert.NotNil(t, p) - require.NoError(t, p.Shutdown(t.Context())) + must.NoError(t, err) + test.NotNil(t, p) + must.NoError(t, p.Shutdown(t.Context())) }) T.Run("pyroscope with nil config returns noop", func(t *testing.T) { t.Parallel() c := &Config{Provider: ProviderPyroscope} p, err := c.ProvideProfilingProvider(t.Context(), logger) - require.NoError(t, err) - assert.NotNil(t, p) + must.NoError(t, err) + test.NotNil(t, p) }) T.Run("pyroscope with config sets default upload rate", func(t *testing.T) { @@ -136,10 +136,10 @@ func TestConfig_ProvideProfilingProvider(T *testing.T) { }, } p, err := c.ProvideProfilingProvider(t.Context(), logger) - require.NoError(t, err) - assert.NotNil(t, p) - assert.Equal(t, 15*time.Second, c.Pyroscope.UploadRate) - require.NoError(t, p.Shutdown(t.Context())) + must.NoError(t, err) + test.NotNil(t, p) + test.EqOp(t, 15*time.Second, c.Pyroscope.UploadRate) + must.NoError(t, p.Shutdown(t.Context())) }) T.Run("pyroscope with non-default upload rate", func(t *testing.T) { @@ -153,9 +153,9 @@ func TestConfig_ProvideProfilingProvider(T *testing.T) { }, } p, err := c.ProvideProfilingProvider(t.Context(), logger) - require.NoError(t, err) - assert.NotNil(t, p) - assert.Equal(t, 5*time.Second, c.Pyroscope.UploadRate) - require.NoError(t, p.Shutdown(t.Context())) + must.NoError(t, err) + test.NotNil(t, p) + test.EqOp(t, 5*time.Second, c.Pyroscope.UploadRate) + must.NoError(t, p.Shutdown(t.Context())) }) } diff --git a/observability/profiling/config/do_test.go b/observability/profiling/config/do_test.go index 137b76f..31a51c1 100644 --- a/observability/profiling/config/do_test.go +++ b/observability/profiling/config/do_test.go @@ -8,8 +8,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/profiling" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterProfilingProvider(T *testing.T) { @@ -26,7 +26,7 @@ func TestRegisterProfilingProvider(T *testing.T) { RegisterProfilingProvider(i) p, err := do.Invoke[profiling.Provider](i) - require.NoError(t, err) - assert.NotNil(t, p) + must.NoError(t, err) + test.NotNil(t, p) }) } diff --git a/observability/profiling/noop_test.go b/observability/profiling/noop_test.go index 0214d09..b5dd827 100644 --- a/observability/profiling/noop_test.go +++ b/observability/profiling/noop_test.go @@ -4,7 +4,7 @@ import ( "context" "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestNewNoopProvider(T *testing.T) { @@ -13,18 +13,18 @@ func TestNewNoopProvider(T *testing.T) { T.Run("returns non-nil provider", func(t *testing.T) { t.Parallel() p := NewNoopProvider() - assert.NotNil(t, p) + test.NotNil(t, p) }) T.Run("Start returns nil", func(t *testing.T) { t.Parallel() p := NewNoopProvider() - assert.NoError(t, p.Start(context.Background())) + test.NoError(t, p.Start(context.Background())) }) T.Run("Shutdown returns nil", func(t *testing.T) { t.Parallel() p := NewNoopProvider() - assert.NoError(t, p.Shutdown(context.Background())) + test.NoError(t, p.Shutdown(context.Background())) }) } diff --git a/observability/profiling/pprof/config_test.go b/observability/profiling/pprof/config_test.go index 0d68a45..b20e70e 100644 --- a/observability/profiling/pprof/config_test.go +++ b/observability/profiling/pprof/config_test.go @@ -4,7 +4,7 @@ import ( "context" "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -15,12 +15,12 @@ func TestConfig_ValidateWithContext(T *testing.T) { T.Run("valid config", func(t *testing.T) { t.Parallel() c := &Config{Port: 6060} - assert.NoError(t, c.ValidateWithContext(ctx)) + test.NoError(t, c.ValidateWithContext(ctx)) }) T.Run("default port is valid", func(t *testing.T) { t.Parallel() c := &Config{Port: DefaultPort} - assert.NoError(t, c.ValidateWithContext(ctx)) + test.NoError(t, c.ValidateWithContext(ctx)) }) } diff --git a/observability/profiling/pprof/provider_test.go b/observability/profiling/pprof/provider_test.go index 78a4861..907de65 100644 --- a/observability/profiling/pprof/provider_test.go +++ b/observability/profiling/pprof/provider_test.go @@ -6,8 +6,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestProvideProfilingProvider(T *testing.T) { @@ -16,17 +16,17 @@ func TestProvideProfilingProvider(T *testing.T) { T.Run("nil config returns noop", func(t *testing.T) { t.Parallel() p, err := ProvideProfilingProvider(context.Background(), logging.NewNoopLogger(), nil) - require.NoError(t, err) - assert.NotNil(t, p) + must.NoError(t, err) + test.NotNil(t, p) }) T.Run("zero port uses default", func(t *testing.T) { t.Parallel() cfg := &Config{Port: 0} p, err := ProvideProfilingProvider(context.Background(), logging.NewNoopLogger(), cfg) - require.NoError(t, err) - assert.NotNil(t, p) - require.NoError(t, p.Shutdown(context.Background())) + must.NoError(t, err) + test.NotNil(t, p) + must.NoError(t, p.Shutdown(context.Background())) }) T.Run("with mutex and block profiling", func(t *testing.T) { @@ -37,17 +37,17 @@ func TestProvideProfilingProvider(T *testing.T) { EnableBlockProfile: true, } p, err := ProvideProfilingProvider(context.Background(), logging.NewNoopLogger(), cfg) - require.NoError(t, err) - assert.NotNil(t, p) - require.NoError(t, p.Shutdown(context.Background())) + must.NoError(t, err) + test.NotNil(t, p) + must.NoError(t, p.Shutdown(context.Background())) }) T.Run("start and shutdown", func(t *testing.T) { t.Parallel() cfg := &Config{Port: 16062} p, err := ProvideProfilingProvider(context.Background(), logging.NewNoopLogger(), cfg) - require.NoError(t, err) - require.NoError(t, p.Start(context.Background())) - require.NoError(t, p.Shutdown(context.Background())) + must.NoError(t, err) + must.NoError(t, p.Start(context.Background())) + must.NoError(t, p.Shutdown(context.Background())) }) } diff --git a/observability/profiling/pyroscope/config_test.go b/observability/profiling/pyroscope/config_test.go index 3424ac4..05e4e67 100644 --- a/observability/profiling/pyroscope/config_test.go +++ b/observability/profiling/pyroscope/config_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -19,7 +19,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { ServerAddress: "http://localhost:4040", UploadRate: 15 * time.Second, } - assert.NoError(t, c.ValidateWithContext(ctx)) + test.NoError(t, c.ValidateWithContext(ctx)) }) T.Run("missing server address", func(t *testing.T) { @@ -27,7 +27,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { c := &Config{ UploadRate: 15 * time.Second, } - assert.Error(t, c.ValidateWithContext(ctx)) + test.Error(t, c.ValidateWithContext(ctx)) }) T.Run("missing upload rate", func(t *testing.T) { @@ -35,6 +35,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { c := &Config{ ServerAddress: "http://localhost:4040", } - assert.Error(t, c.ValidateWithContext(ctx)) + test.Error(t, c.ValidateWithContext(ctx)) }) } diff --git a/observability/profiling/pyroscope/provider_test.go b/observability/profiling/pyroscope/provider_test.go index 600e854..2f3254d 100644 --- a/observability/profiling/pyroscope/provider_test.go +++ b/observability/profiling/pyroscope/provider_test.go @@ -7,8 +7,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/profiling" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestProvideProfilingProvider(T *testing.T) { @@ -21,8 +21,8 @@ func TestProvideProfilingProvider(T *testing.T) { logger := logging.NewNoopLogger() p, err := ProvideProfilingProvider(ctx, logger, "test-service", nil) - require.NoError(t, err) - assert.NotNil(t, p) + must.NoError(t, err) + test.NotNil(t, p) }) T.Run("standard", func(t *testing.T) { @@ -36,10 +36,10 @@ func TestProvideProfilingProvider(T *testing.T) { } p, err := ProvideProfilingProvider(ctx, logger, "test-service", cfg) - require.NoError(t, err) - require.NotNil(t, p) + must.NoError(t, err) + must.NotNil(t, p) - require.NoError(t, p.Shutdown(ctx)) + must.NoError(t, p.Shutdown(ctx)) }) T.Run("with mutex and block profiles", func(t *testing.T) { @@ -55,10 +55,10 @@ func TestProvideProfilingProvider(T *testing.T) { } p, err := ProvideProfilingProvider(ctx, logger, "test-service", cfg) - require.NoError(t, err) - require.NotNil(t, p) + must.NoError(t, err) + must.NotNil(t, p) - require.NoError(t, p.Shutdown(ctx)) + must.NoError(t, p.Shutdown(ctx)) }) T.Run("with tags", func(t *testing.T) { @@ -73,10 +73,10 @@ func TestProvideProfilingProvider(T *testing.T) { } p, err := ProvideProfilingProvider(ctx, logger, "test-service", cfg) - require.NoError(t, err) - require.NotNil(t, p) + must.NoError(t, err) + must.NotNil(t, p) - require.NoError(t, p.Shutdown(ctx)) + must.NoError(t, p.Shutdown(ctx)) }) } @@ -94,10 +94,10 @@ func TestProvider_Start(T *testing.T) { } p, err := ProvideProfilingProvider(ctx, logger, "test-service", cfg) - require.NoError(t, err) + must.NoError(t, err) - assert.NoError(t, p.Start(ctx)) - require.NoError(t, p.Shutdown(ctx)) + test.NoError(t, p.Start(ctx)) + must.NoError(t, p.Shutdown(ctx)) }) } @@ -115,9 +115,9 @@ func TestProvider_Shutdown(T *testing.T) { } p, err := ProvideProfilingProvider(ctx, logger, "test-service", cfg) - require.NoError(t, err) + must.NoError(t, err) - assert.NoError(t, p.Shutdown(ctx)) + test.NoError(t, p.Shutdown(ctx)) }) } diff --git a/observability/tracing/caller_test.go b/observability/tracing/caller_test.go index 03f7648..57cfb48 100644 --- a/observability/tracing/caller_test.go +++ b/observability/tracing/caller_test.go @@ -3,7 +3,7 @@ package tracing import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestGetCallerName(T *testing.T) { @@ -12,6 +12,6 @@ func TestGetCallerName(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotEmpty(t, GetCallerName()) + test.NotEq(t, "", GetCallerName()) }) } diff --git a/observability/tracing/cloudtrace/config_test.go b/observability/tracing/cloudtrace/config_test.go index e96c490..f7f7c90 100644 --- a/observability/tracing/cloudtrace/config_test.go +++ b/observability/tracing/cloudtrace/config_test.go @@ -3,7 +3,7 @@ package cloudtrace import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestCloudTraceConfig_ValidateWithContext(T *testing.T) { @@ -17,6 +17,6 @@ func TestCloudTraceConfig_ValidateWithContext(T *testing.T) { ProjectID: t.Name(), } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) } diff --git a/observability/tracing/config/config_test.go b/observability/tracing/config/config_test.go index 2957a00..9606bac 100644 --- a/observability/tracing/config/config_test.go +++ b/observability/tracing/config/config_test.go @@ -9,8 +9,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing/cloudtrace" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing/oteltrace" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestConfig_ProvideTracerProvider(T *testing.T) { @@ -26,8 +26,8 @@ func TestConfig_ProvideTracerProvider(T *testing.T) { logging.NewNoopLogger(), ) - assert.NoError(t, err) - assert.NotNil(t, tracerProvider) + test.NoError(t, err) + test.NotNil(t, tracerProvider) }) T.Run("with otel provider", func(t *testing.T) { @@ -48,8 +48,8 @@ func TestConfig_ProvideTracerProvider(T *testing.T) { logging.NewNoopLogger(), ) - assert.NoError(t, err) - assert.NotNil(t, tracerProvider) + test.NoError(t, err) + test.NotNil(t, tracerProvider) }) } @@ -58,7 +58,7 @@ func TestConfig_ProvideTracerProvider(T *testing.T) { func TestConfig_ProvideTracerProvider_CloudTrace(t *testing.T) { dir := t.TempDir() credPath := filepath.Join(dir, "creds.json") - require.NoError(t, os.WriteFile(credPath, []byte(`{"type":"authorized_user","client_id":"x","client_secret":"y","refresh_token":"z"}`), 0o600)) + must.NoError(t, os.WriteFile(credPath, []byte(`{"type":"authorized_user","client_id":"x","client_secret":"y","refresh_token":"z"}`), 0o600)) t.Setenv("GOOGLE_APPLICATION_CREDENTIALS", credPath) cfg := &Config{ @@ -75,8 +75,8 @@ func TestConfig_ProvideTracerProvider_CloudTrace(t *testing.T) { logging.NewNoopLogger(), ) - require.NoError(t, err) - assert.NotNil(t, tracerProvider) + must.NoError(t, err) + test.NotNil(t, tracerProvider) } // TestConfig_ProvideTracerProvider_CloudTraceError covers the cloudtrace error branch. @@ -100,8 +100,8 @@ func TestConfig_ProvideTracerProvider_CloudTraceError(t *testing.T) { logging.NewNoopLogger(), ) - assert.Error(t, err) - assert.Nil(t, tracerProvider) + test.Error(t, err) + test.Nil(t, tracerProvider) } // TestConfig_ProvideTracerProvider_OtelError covers the otelgrpc error branch. @@ -122,8 +122,8 @@ func TestConfig_ProvideTracerProvider_OtelError(T *testing.T) { } tracerProvider, err := cfg.ProvideTracerProvider(t.Context(), logging.NewNoopLogger()) - assert.Error(t, err) - assert.Nil(t, tracerProvider) + test.Error(t, err) + test.Nil(t, tracerProvider) }) } @@ -144,8 +144,8 @@ func TestConfig_ProvideTracer_Error(T *testing.T) { } tracer, err := cfg.ProvideTracer(t.Context(), logging.NewNoopLogger(), t.Name()) - assert.Error(t, err) - assert.Nil(t, tracer) + test.Error(t, err) + test.Nil(t, tracer) }) } @@ -158,8 +158,8 @@ func TestConfig_ProvideTracer(T *testing.T) { cfg := &Config{} tracer, err := cfg.ProvideTracer(t.Context(), logging.NewNoopLogger(), t.Name()) - assert.NoError(t, err) - assert.NotNil(t, tracer) + test.NoError(t, err) + test.NotNil(t, tracer) }) T.Run("with otel provider", func(t *testing.T) { @@ -176,8 +176,8 @@ func TestConfig_ProvideTracer(T *testing.T) { } tracer, err := cfg.ProvideTracer(t.Context(), logging.NewNoopLogger(), t.Name()) - assert.NoError(t, err) - assert.NotNil(t, tracer) + test.NoError(t, err) + test.NotNil(t, tracer) }) } @@ -196,7 +196,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { }, } - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("with cloudtrace provider", func(t *testing.T) { @@ -211,7 +211,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { }, } - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("missing required service name", func(t *testing.T) { @@ -225,7 +225,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { }, } - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("invalid provider", func(t *testing.T) { @@ -237,7 +237,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { SpanCollectionProbability: 1, } - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) } @@ -250,7 +250,7 @@ func TestProvideTracerProvider(T *testing.T) { cfg := &Config{} tracerProvider, err := ProvideTracerProvider(t.Context(), cfg, logging.NewNoopLogger()) - assert.NoError(t, err) - assert.NotNil(t, tracerProvider) + test.NoError(t, err) + test.NotNil(t, tracerProvider) }) } diff --git a/observability/tracing/config/do_test.go b/observability/tracing/config/do_test.go index b8b558f..cd81a61 100644 --- a/observability/tracing/config/do_test.go +++ b/observability/tracing/config/do_test.go @@ -8,8 +8,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterTracerProvider(T *testing.T) { @@ -26,7 +26,7 @@ func TestRegisterTracerProvider(T *testing.T) { RegisterTracerProvider(i) tp, err := do.Invoke[tracing.TracerProvider](i) - require.NoError(t, err) - assert.NotNil(t, tp) + must.NoError(t, err) + test.NotNil(t, tp) }) } diff --git a/observability/tracing/instrumentedsql_test.go b/observability/tracing/instrumentedsql_test.go index 074deed..13b0648 100644 --- a/observability/tracing/instrumentedsql_test.go +++ b/observability/tracing/instrumentedsql_test.go @@ -3,7 +3,7 @@ package tracing import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestNewInstrumentedSQLTracer(T *testing.T) { @@ -12,7 +12,7 @@ func TestNewInstrumentedSQLTracer(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewInstrumentedSQLTracer(NewNoopTracerProvider(), t.Name())) + test.NotNil(t, NewInstrumentedSQLTracer(NewNoopTracerProvider(), t.Name())) }) } @@ -25,6 +25,6 @@ func Test_instrumentedSQLTracerWrapper_GetSpan(T *testing.T) { ctx := t.Context() w := NewInstrumentedSQLTracer(NewNoopTracerProvider(), t.Name()) - assert.NotNil(t, w.GetSpan(ctx)) + test.NotNil(t, w.GetSpan(ctx)) }) } diff --git a/observability/tracing/oteltrace/config_test.go b/observability/tracing/oteltrace/config_test.go index 24ef333..5345ba2 100644 --- a/observability/tracing/oteltrace/config_test.go +++ b/observability/tracing/oteltrace/config_test.go @@ -3,7 +3,7 @@ package oteltrace import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -17,6 +17,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { CollectorEndpoint: t.Name(), } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) } diff --git a/observability/tracing/oteltrace/tracer_test.go b/observability/tracing/oteltrace/tracer_test.go index 7a62558..339b070 100644 --- a/observability/tracing/oteltrace/tracer_test.go +++ b/observability/tracing/oteltrace/tracer_test.go @@ -6,7 +6,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func Test_tracingErrorHandler_Handle(T *testing.T) { @@ -31,7 +31,7 @@ func TestConfig_SetupOtelHTTP(T *testing.T) { } actual, err := SetupOtelGRPC(ctx, t.Name(), 0, cfg) - assert.NoError(t, err) - assert.NotNil(t, actual) + test.NoError(t, err) + test.NotNil(t, actual) }) } diff --git a/observability/tracing/span_attachers_test.go b/observability/tracing/span_attachers_test.go index b121596..dc15cd7 100644 --- a/observability/tracing/span_attachers_test.go +++ b/observability/tracing/span_attachers_test.go @@ -7,7 +7,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/database/filtering" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" ) // mockServicePermissionChecker implements ServicePermissionChecker for tests. @@ -63,7 +63,7 @@ func TestAttachRequestToSpan(T *testing.T) { ctx, span := StartSpan(t.Context()) req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/", http.NoBody) req.Header.Set(t.Name(), "blah") - require.NoError(t, err) + must.NoError(t, err) AttachRequestToSpan(span, req) }) diff --git a/observability/tracing/span_manager_test.go b/observability/tracing/span_manager_test.go index 3b337fa..c22bc4c 100644 --- a/observability/tracing/span_manager_test.go +++ b/observability/tracing/span_manager_test.go @@ -3,7 +3,7 @@ package tracing import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestNewTracer(T *testing.T) { @@ -12,7 +12,7 @@ func TestNewTracer(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewTracerForTest(t.Name())) + test.NotNil(t, NewTracerForTest(t.Name())) }) } @@ -22,13 +22,13 @@ func TestNewNamedTracer(T *testing.T) { T.Run("with nil provider", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewNamedTracer(nil, t.Name())) + test.NotNil(t, NewNamedTracer(nil, t.Name())) }) T.Run("with valid provider", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewNamedTracer(NewNoopTracerProvider(), t.Name())) + test.NotNil(t, NewNamedTracer(NewNoopTracerProvider(), t.Name())) }) } diff --git a/observability/tracing/spans_test.go b/observability/tracing/spans_test.go index f3578ac..8dc7631 100644 --- a/observability/tracing/spans_test.go +++ b/observability/tracing/spans_test.go @@ -5,7 +5,7 @@ import ( "net/url" "testing" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" ) func TestStartCustomSpan(T *testing.T) { @@ -53,7 +53,7 @@ func TestFormatSpan(T *testing.T) { t.Parallel() u, err := url.ParseRequestURI("https://whatever.whocares.gov") - require.NoError(t, err) + must.NoError(t, err) FormatSpan(t.Name(), &http.Request{URL: u}) }) diff --git a/observability/utils/otel_test.go b/observability/utils/otel_test.go index 622edf4..2e96ecd 100644 --- a/observability/utils/otel_test.go +++ b/observability/utils/otel_test.go @@ -4,7 +4,7 @@ import ( "context" "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestMustOtelResource(T *testing.T) { @@ -13,12 +13,12 @@ func TestMustOtelResource(T *testing.T) { T.Run("with service name", func(t *testing.T) { t.Parallel() res := MustOtelResource(context.Background(), "test-service") - assert.NotNil(t, res) + test.NotNil(t, res) }) T.Run("without service name", func(t *testing.T) { t.Parallel() res := MustOtelResource(context.Background(), "") - assert.NotNil(t, res) + test.NotNil(t, res) }) } diff --git a/panicking/panicker_test.go b/panicking/panicker_test.go index 8e06df6..b0c2bf6 100644 --- a/panicking/panicker_test.go +++ b/panicking/panicker_test.go @@ -3,7 +3,7 @@ package panicking import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestNewProductionPanicker(T *testing.T) { @@ -12,7 +12,7 @@ func TestNewProductionPanicker(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewProductionPanicker()) + test.NotNil(t, NewProductionPanicker()) }) } @@ -25,7 +25,7 @@ func Test_stdLibPanicker_Panic(T *testing.T) { p := NewProductionPanicker() defer func() { - assert.NotNil(t, recover(), "expected panic to occur") + test.NotNil(t, recover(), test.Sprint("expected panic to occur")) }() p.Panic("blah") @@ -41,7 +41,7 @@ func Test_stdLibPanicker_Panicf(T *testing.T) { p := NewProductionPanicker() defer func() { - assert.NotNil(t, recover(), "expected panic to occur") + test.NotNil(t, recover(), test.Sprint("expected panic to occur")) }() p.Panicf("blah") diff --git a/pointer/pointers_test.go b/pointer/pointers_test.go index f6d975b..7abcd86 100644 --- a/pointer/pointers_test.go +++ b/pointer/pointers_test.go @@ -3,8 +3,8 @@ package pointer import ( "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestTo(T *testing.T) { @@ -16,8 +16,8 @@ func TestTo(T *testing.T) { expected := "things" actual := To(expected) - require.NotNil(t, actual) - assert.Equal(t, expected, *actual) + must.NotNil(t, actual) + test.EqOp(t, expected, *actual) }) T.Run("with int", func(t *testing.T) { @@ -26,8 +26,8 @@ func TestTo(T *testing.T) { expected := 42 actual := To(expected) - require.NotNil(t, actual) - assert.Equal(t, expected, *actual) + must.NotNil(t, actual) + test.EqOp(t, expected, *actual) }) T.Run("with zero value", func(t *testing.T) { @@ -35,8 +35,8 @@ func TestTo(T *testing.T) { actual := To(0) - require.NotNil(t, actual) - assert.Equal(t, 0, *actual) + must.NotNil(t, actual) + test.EqOp(t, 0, *actual) }) T.Run("with struct", func(t *testing.T) { @@ -46,8 +46,8 @@ func TestTo(T *testing.T) { expected := example{Name: "test"} actual := To(expected) - require.NotNil(t, actual) - assert.Equal(t, expected, *actual) + must.NotNil(t, actual) + test.EqOp(t, expected, *actual) }) } @@ -60,10 +60,10 @@ func TestToSlice(T *testing.T) { input := []string{"a", "b", "c"} actual := ToSlice(input) - require.Len(t, actual, 3) - assert.Equal(t, "a", *actual[0]) - assert.Equal(t, "b", *actual[1]) - assert.Equal(t, "c", *actual[2]) + must.SliceLen(t, 3, actual) + test.EqOp(t, "a", *actual[0]) + test.EqOp(t, "b", *actual[1]) + test.EqOp(t, "c", *actual[2]) }) T.Run("with int slice", func(t *testing.T) { @@ -72,10 +72,10 @@ func TestToSlice(T *testing.T) { input := []int{1, 2, 3} actual := ToSlice(input) - require.Len(t, actual, 3) - assert.Equal(t, 1, *actual[0]) - assert.Equal(t, 2, *actual[1]) - assert.Equal(t, 3, *actual[2]) + must.SliceLen(t, 3, actual) + test.EqOp(t, 1, *actual[0]) + test.EqOp(t, 2, *actual[1]) + test.EqOp(t, 3, *actual[2]) }) T.Run("with nil slice", func(t *testing.T) { @@ -83,8 +83,8 @@ func TestToSlice(T *testing.T) { actual := ToSlice[string](nil) - assert.NotNil(t, actual) - assert.Empty(t, actual) + test.NotNil(t, actual) + test.SliceEmpty(t, actual) }) T.Run("with empty slice", func(t *testing.T) { @@ -92,8 +92,8 @@ func TestToSlice(T *testing.T) { actual := ToSlice([]string{}) - assert.NotNil(t, actual) - assert.Empty(t, actual) + test.NotNil(t, actual) + test.SliceEmpty(t, actual) }) } @@ -106,7 +106,7 @@ func TestDereference(T *testing.T) { rawExpected := "things" actual := Dereference(&rawExpected) - assert.Equal(t, rawExpected, actual) + test.EqOp(t, rawExpected, actual) }) T.Run("with int pointer", func(t *testing.T) { @@ -115,7 +115,7 @@ func TestDereference(T *testing.T) { expected := 42 actual := Dereference(&expected) - assert.Equal(t, 42, actual) + test.EqOp(t, 42, actual) }) T.Run("with nil string pointer", func(t *testing.T) { @@ -123,7 +123,7 @@ func TestDereference(T *testing.T) { actual := Dereference[string](nil) - assert.Equal(t, "", actual) + test.EqOp(t, "", actual) }) T.Run("with nil int pointer", func(t *testing.T) { @@ -131,7 +131,7 @@ func TestDereference(T *testing.T) { actual := Dereference[int](nil) - assert.Equal(t, 0, actual) + test.EqOp(t, 0, actual) }) T.Run("with nil bool pointer", func(t *testing.T) { @@ -139,7 +139,7 @@ func TestDereference(T *testing.T) { actual := Dereference[bool](nil) - assert.False(t, actual) + test.False(t, actual) }) } @@ -153,10 +153,10 @@ func TestDereferenceSlice(T *testing.T) { input := []*string{&a, &b, &c} actual := DereferenceSlice(input) - require.Len(t, actual, 3) - assert.Equal(t, "a", actual[0]) - assert.Equal(t, "b", actual[1]) - assert.Equal(t, "c", actual[2]) + must.SliceLen(t, 3, actual) + test.EqOp(t, "a", actual[0]) + test.EqOp(t, "b", actual[1]) + test.EqOp(t, "c", actual[2]) }) T.Run("with int pointer slice", func(t *testing.T) { @@ -166,9 +166,9 @@ func TestDereferenceSlice(T *testing.T) { input := []*int{&a, &b} actual := DereferenceSlice(input) - require.Len(t, actual, 2) - assert.Equal(t, 1, actual[0]) - assert.Equal(t, 2, actual[1]) + must.SliceLen(t, 2, actual) + test.EqOp(t, 1, actual[0]) + test.EqOp(t, 2, actual[1]) }) T.Run("with nil slice", func(t *testing.T) { @@ -176,8 +176,8 @@ func TestDereferenceSlice(T *testing.T) { actual := DereferenceSlice[string](nil) - assert.NotNil(t, actual) - assert.Empty(t, actual) + test.NotNil(t, actual) + test.SliceEmpty(t, actual) }) T.Run("with empty slice", func(t *testing.T) { @@ -185,7 +185,7 @@ func TestDereferenceSlice(T *testing.T) { actual := DereferenceSlice([]*string{}) - assert.NotNil(t, actual) - assert.Empty(t, actual) + test.NotNil(t, actual) + test.SliceEmpty(t, actual) }) } diff --git a/qrcodes/do_test.go b/qrcodes/do_test.go index 5c722f1..cb7e6b4 100644 --- a/qrcodes/do_test.go +++ b/qrcodes/do_test.go @@ -7,8 +7,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterBuilder(T *testing.T) { @@ -25,7 +25,7 @@ func TestRegisterBuilder(T *testing.T) { RegisterBuilder(i) b, err := do.Invoke[Builder](i) - require.NoError(t, err) - assert.NotNil(t, b) + must.NoError(t, err) + test.NotNil(t, b) }) } diff --git a/qrcodes/qrcodes_test.go b/qrcodes/qrcodes_test.go index 7e3b3e0..6ffbc02 100644 --- a/qrcodes/qrcodes_test.go +++ b/qrcodes/qrcodes_test.go @@ -7,8 +7,8 @@ import ( "testing" "github.com/boombuler/barcode" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestNewBuilder(T *testing.T) { @@ -18,7 +18,7 @@ func TestNewBuilder(T *testing.T) { t.Parallel() b := NewBuilder("test-issuer", nil, nil) - assert.NotNil(t, b) + test.NotNil(t, b) }) } @@ -32,8 +32,8 @@ func Test_builder_BuildQRCode(T *testing.T) { b := NewBuilder("test-issuer", nil, nil) actual, err := b.BuildQRCode(ctx, "username", "two-factor-secret") - require.NoError(t, err) - assert.NotEmpty(t, actual) + must.NoError(t, err) + test.NotEq(t, "", actual) }) T.Run("with content exceeding QR capacity", func(t *testing.T) { @@ -44,8 +44,8 @@ func Test_builder_BuildQRCode(T *testing.T) { // A username longer than the maximum QR code capacity forces qr.Encode to fail. actual, err := b.BuildQRCode(ctx, strings.Repeat("a", 4000), "two-factor-secret") - assert.Empty(t, actual) - assert.Error(t, err) + test.EqOp(t, "", actual) + test.Error(t, err) }) T.Run("with scale error", func(t *testing.T) { @@ -58,8 +58,8 @@ func Test_builder_BuildQRCode(T *testing.T) { } actual, err := b.BuildQRCode(ctx, "username", "two-factor-secret") - assert.Empty(t, actual) - assert.Error(t, err) + test.EqOp(t, "", actual) + test.Error(t, err) }) T.Run("with png encode error", func(t *testing.T) { @@ -72,7 +72,7 @@ func Test_builder_BuildQRCode(T *testing.T) { } actual, err := b.BuildQRCode(ctx, "username", "two-factor-secret") - assert.Empty(t, actual) - assert.Error(t, err) + test.EqOp(t, "", actual) + test.Error(t, err) }) } diff --git a/random/do_test.go b/random/do_test.go index b8b2d38..8a6a73c 100644 --- a/random/do_test.go +++ b/random/do_test.go @@ -7,8 +7,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterGenerator(T *testing.T) { @@ -24,7 +24,7 @@ func TestRegisterGenerator(T *testing.T) { RegisterGenerator(i) g, err := do.Invoke[Generator](i) - require.NoError(t, err) - assert.NotNil(t, g) + must.NoError(t, err) + test.NotNil(t, g) }) } diff --git a/random/random_test.go b/random/random_test.go index 310844c..8f11896 100644 --- a/random/random_test.go +++ b/random/random_test.go @@ -6,8 +6,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) type erroneousReader struct{} @@ -25,8 +25,8 @@ func TestGenerateBase32EncodedString(T *testing.T) { ctx := t.Context() actual, err := GenerateBase32EncodedString(ctx, 32) - assert.NoError(t, err) - assert.NotEmpty(t, actual) + test.NoError(t, err) + test.NotEq(t, "", actual) }) } @@ -39,8 +39,8 @@ func TestGenerateBase64EncodedString(T *testing.T) { ctx := t.Context() actual, err := GenerateBase64EncodedString(ctx, 32) - assert.NoError(t, err) - assert.NotEmpty(t, actual) + test.NoError(t, err) + test.NotEq(t, "", actual) }) } @@ -53,8 +53,8 @@ func TestGenerateRawBytes(T *testing.T) { ctx := t.Context() actual, err := GenerateRawBytes(ctx, 32) - assert.NoError(t, err) - assert.NotEmpty(t, actual) + test.NoError(t, err) + test.SliceNotEmpty(t, actual) }) } @@ -70,9 +70,9 @@ func TestStandardSecretGenerator_GenerateBase32EncodedString(T *testing.T) { s := NewGenerator(nil, tracing.NewNoopTracerProvider()) value, err := s.GenerateBase32EncodedString(ctx, exampleLength) - assert.NotEmpty(t, value) - assert.Greater(t, len(value), exampleLength) - assert.NoError(t, err) + test.NotEq(t, "", value) + test.Greater(t, exampleLength, len(value)) + test.NoError(t, err) }) T.Run("with error reading from secure PRNG", func(t *testing.T) { @@ -82,12 +82,12 @@ func TestStandardSecretGenerator_GenerateBase32EncodedString(T *testing.T) { exampleLength := 123 s, ok := NewGenerator(nil, tracing.NewNoopTracerProvider()).(*standardGenerator) - require.True(t, ok) + must.True(t, ok) s.randReader = &erroneousReader{} value, err := s.GenerateBase32EncodedString(ctx, exampleLength) - assert.Empty(t, value) - assert.Error(t, err) + test.EqOp(t, "", value) + test.Error(t, err) }) } @@ -103,9 +103,9 @@ func TestStandardSecretGenerator_GenerateBase64EncodedString(T *testing.T) { s := NewGenerator(nil, tracing.NewNoopTracerProvider()) value, err := s.GenerateBase64EncodedString(ctx, exampleLength) - assert.NotEmpty(t, value) - assert.Greater(t, len(value), exampleLength) - assert.NoError(t, err) + test.NotEq(t, "", value) + test.Greater(t, exampleLength, len(value)) + test.NoError(t, err) }) T.Run("with error reading from secure PRNG", func(t *testing.T) { @@ -115,12 +115,12 @@ func TestStandardSecretGenerator_GenerateBase64EncodedString(T *testing.T) { exampleLength := 123 s, ok := NewGenerator(nil, tracing.NewNoopTracerProvider()).(*standardGenerator) - require.True(t, ok) + must.True(t, ok) s.randReader = &erroneousReader{} value, err := s.GenerateBase64EncodedString(ctx, exampleLength) - assert.Empty(t, value) - assert.Error(t, err) + test.EqOp(t, "", value) + test.Error(t, err) }) } @@ -136,9 +136,9 @@ func TestStandardSecretGenerator_GenerateRawBytes(T *testing.T) { s := NewGenerator(nil, tracing.NewNoopTracerProvider()) value, err := s.GenerateRawBytes(ctx, exampleLength) - assert.NotEmpty(t, value) - assert.Equal(t, len(value), exampleLength) - assert.NoError(t, err) + test.SliceNotEmpty(t, value) + test.EqOp(t, exampleLength, len(value)) + test.NoError(t, err) }) T.Run("with error reading from secure PRNG", func(t *testing.T) { @@ -148,12 +148,12 @@ func TestStandardSecretGenerator_GenerateRawBytes(T *testing.T) { exampleLength := 123 s, ok := NewGenerator(nil, tracing.NewNoopTracerProvider()).(*standardGenerator) - require.True(t, ok) + must.True(t, ok) s.randReader = &erroneousReader{} value, err := s.GenerateRawBytes(ctx, exampleLength) - assert.Empty(t, value) - assert.Error(t, err) + test.SliceEmpty(t, value) + test.Error(t, err) }) } @@ -165,7 +165,7 @@ func TestMustGenerateRawBytes(T *testing.T) { ctx := t.Context() result := MustGenerateRawBytes(ctx, 32) - assert.NotEmpty(t, result) + test.SliceNotEmpty(t, result) }) } @@ -177,7 +177,7 @@ func TestGenerateHexEncodedString(T *testing.T) { ctx := t.Context() result, err := GenerateHexEncodedString(ctx, 32) - assert.NoError(t, err) - assert.NotEmpty(t, result) + test.NoError(t, err) + test.NotEq(t, "", result) }) } diff --git a/random/slices_test.go b/random/slices_test.go index e3ada15..55b36a8 100644 --- a/random/slices_test.go +++ b/random/slices_test.go @@ -4,7 +4,7 @@ import ( "slices" "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestElement(T *testing.T) { @@ -64,6 +64,6 @@ func TestElement(T *testing.T) { "signal.", } - assert.True(t, slices.Contains(exampleArray, Element(exampleArray))) + test.True(t, slices.Contains(exampleArray, Element(exampleArray))) }) } diff --git a/ratelimiting/config/config_test.go b/ratelimiting/config/config_test.go index 847c7c6..af7cc2e 100644 --- a/ratelimiting/config/config_test.go +++ b/ratelimiting/config/config_test.go @@ -6,8 +6,8 @@ import ( redisrl "github.com/verygoodsoftwarenotvirus/platform/v5/ratelimiting/redis" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestConfig_EnsureDefaults(T *testing.T) { @@ -19,8 +19,8 @@ func TestConfig_EnsureDefaults(T *testing.T) { cfg := &Config{} cfg.EnsureDefaults() - assert.Equal(t, 10.0, cfg.RequestsPerSec) - assert.Equal(t, 20, cfg.BurstSize) + test.EqOp(t, 10.0, cfg.RequestsPerSec) + test.EqOp(t, 20, cfg.BurstSize) }) T.Run("preserves non-zero values", func(t *testing.T) { @@ -32,8 +32,8 @@ func TestConfig_EnsureDefaults(T *testing.T) { } cfg.EnsureDefaults() - assert.Equal(t, 5.0, cfg.RequestsPerSec) - assert.Equal(t, 10, cfg.BurstSize) + test.EqOp(t, 5.0, cfg.RequestsPerSec) + test.EqOp(t, 10, cfg.BurstSize) }) } @@ -45,12 +45,12 @@ func TestConfig_ProvideRateLimiter(T *testing.T) { var cfg *Config limiter, err := cfg.ProvideRateLimiter(nil) - require.NoError(t, err) - require.NotNil(t, limiter) + must.NoError(t, err) + must.NotNil(t, limiter) allowed, err := limiter.Allow(context.Background(), "x") - require.NoError(t, err) - assert.True(t, allowed) + must.NoError(t, err) + test.True(t, allowed) }) T.Run("empty provider returns noop", func(t *testing.T) { @@ -58,12 +58,12 @@ func TestConfig_ProvideRateLimiter(T *testing.T) { cfg := &Config{Provider: ""} limiter, err := cfg.ProvideRateLimiter(nil) - require.NoError(t, err) - require.NotNil(t, limiter) + must.NoError(t, err) + must.NotNil(t, limiter) allowed, err := limiter.Allow(context.Background(), "x") - require.NoError(t, err) - assert.True(t, allowed) + must.NoError(t, err) + test.True(t, allowed) }) T.Run("noop provider returns noop", func(t *testing.T) { @@ -71,12 +71,12 @@ func TestConfig_ProvideRateLimiter(T *testing.T) { cfg := &Config{Provider: ProviderNoop} limiter, err := cfg.ProvideRateLimiter(nil) - require.NoError(t, err) - require.NotNil(t, limiter) + must.NoError(t, err) + must.NotNil(t, limiter) allowed, err := limiter.Allow(context.Background(), "x") - require.NoError(t, err) - assert.True(t, allowed) + must.NoError(t, err) + test.True(t, allowed) }) T.Run("memory provider returns in-memory limiter", func(t *testing.T) { @@ -88,16 +88,16 @@ func TestConfig_ProvideRateLimiter(T *testing.T) { BurstSize: 1, } limiter, err := cfg.ProvideRateLimiter(nil) - require.NoError(t, err) - require.NotNil(t, limiter) + must.NoError(t, err) + must.NotNil(t, limiter) allowed, err := limiter.Allow(context.Background(), "x") - require.NoError(t, err) - assert.True(t, allowed) + must.NoError(t, err) + test.True(t, allowed) allowed, err = limiter.Allow(context.Background(), "x") - require.NoError(t, err) - assert.False(t, allowed) + must.NoError(t, err) + test.False(t, allowed) }) T.Run("redis provider returns redis limiter", func(t *testing.T) { @@ -110,8 +110,8 @@ func TestConfig_ProvideRateLimiter(T *testing.T) { BurstSize: 1, } limiter, err := cfg.ProvideRateLimiter(nil) - require.NoError(t, err) - assert.NotNil(t, limiter) + must.NoError(t, err) + test.NotNil(t, limiter) }) T.Run("unknown provider returns error", func(t *testing.T) { @@ -119,9 +119,9 @@ func TestConfig_ProvideRateLimiter(T *testing.T) { cfg := &Config{Provider: "unknown"} limiter, err := cfg.ProvideRateLimiter(nil) - require.Error(t, err) - assert.Nil(t, limiter) - assert.Contains(t, err.Error(), "unknown") + must.Error(t, err) + test.Nil(t, limiter) + test.StrContains(t, err.Error(), "unknown") }) } @@ -132,8 +132,8 @@ func TestProvideRateLimiterFromConfig(T *testing.T) { t.Parallel() limiter, err := ProvideRateLimiterFromConfig(nil, nil) - require.NoError(t, err) - require.NotNil(t, limiter) + must.NoError(t, err) + must.NotNil(t, limiter) }) T.Run("noop provider returns noop", func(t *testing.T) { @@ -141,8 +141,8 @@ func TestProvideRateLimiterFromConfig(T *testing.T) { cfg := &Config{Provider: ProviderNoop} limiter, err := ProvideRateLimiterFromConfig(cfg, nil) - require.NoError(t, err) - require.NotNil(t, limiter) + must.NoError(t, err) + must.NotNil(t, limiter) }) T.Run("unknown provider wraps error", func(t *testing.T) { @@ -150,9 +150,9 @@ func TestProvideRateLimiterFromConfig(T *testing.T) { cfg := &Config{Provider: "unknown"} limiter, err := ProvideRateLimiterFromConfig(cfg, nil) - require.Error(t, err) - assert.Nil(t, limiter) - assert.Contains(t, err.Error(), "provide rate limiter") + must.Error(t, err) + test.Nil(t, limiter) + test.StrContains(t, err.Error(), "provide rate limiter") }) } @@ -169,7 +169,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { } err := cfg.ValidateWithContext(ctx) - require.NoError(t, err) + must.NoError(t, err) }) T.Run("invalid RequestsPerSec", func(t *testing.T) { @@ -182,7 +182,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { } err := cfg.ValidateWithContext(ctx) - require.Error(t, err) + must.Error(t, err) }) T.Run("invalid BurstSize", func(t *testing.T) { @@ -195,6 +195,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { } err := cfg.ValidateWithContext(ctx) - require.Error(t, err) + must.Error(t, err) }) } diff --git a/ratelimiting/config/do_test.go b/ratelimiting/config/do_test.go index 6f10fef..875ef46 100644 --- a/ratelimiting/config/do_test.go +++ b/ratelimiting/config/do_test.go @@ -7,8 +7,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/ratelimiting" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterRateLimiter(T *testing.T) { @@ -24,7 +24,7 @@ func TestRegisterRateLimiter(T *testing.T) { RegisterRateLimiter(i) limiter, err := do.Invoke[ratelimiting.RateLimiter](i) - require.NoError(t, err) - assert.NotNil(t, limiter) + must.NoError(t, err) + test.NotNil(t, limiter) }) } diff --git a/ratelimiting/noop/noop_test.go b/ratelimiting/noop/noop_test.go index f3484e0..f34dada 100644 --- a/ratelimiting/noop/noop_test.go +++ b/ratelimiting/noop/noop_test.go @@ -4,8 +4,8 @@ import ( "context" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRateLimiter_Allow(T *testing.T) { @@ -19,8 +19,8 @@ func TestRateLimiter_Allow(T *testing.T) { for range 100 { allowed, err := limiter.Allow(ctx, "any") - require.NoError(t, err) - assert.True(t, allowed) + must.NoError(t, err) + test.True(t, allowed) } }) } @@ -33,6 +33,6 @@ func TestRateLimiter_Close(T *testing.T) { limiter := NewRateLimiter() err := limiter.Close() - require.NoError(t, err) + must.NoError(t, err) }) } diff --git a/ratelimiting/ratelimiting_test.go b/ratelimiting/ratelimiting_test.go index be9f922..716fde3 100644 --- a/ratelimiting/ratelimiting_test.go +++ b/ratelimiting/ratelimiting_test.go @@ -4,8 +4,8 @@ import ( "context" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestInMemoryRateLimiter_Allow(T *testing.T) { @@ -15,60 +15,60 @@ func TestInMemoryRateLimiter_Allow(T *testing.T) { t.Parallel() limiter, err := NewInMemoryRateLimiter(nil, 10, 3) - require.NoError(t, err) + must.NoError(t, err) defer limiter.Close() ctx := context.Background() allowed, err := limiter.Allow(ctx, "key1") - require.NoError(t, err) - assert.True(t, allowed) + must.NoError(t, err) + test.True(t, allowed) allowed, err = limiter.Allow(ctx, "key1") - require.NoError(t, err) - assert.True(t, allowed) + must.NoError(t, err) + test.True(t, allowed) allowed, err = limiter.Allow(ctx, "key1") - require.NoError(t, err) - assert.True(t, allowed) + must.NoError(t, err) + test.True(t, allowed) allowed, err = limiter.Allow(ctx, "key1") - require.NoError(t, err) - assert.False(t, allowed) + must.NoError(t, err) + test.False(t, allowed) }) T.Run("different keys have independent limits", func(t *testing.T) { t.Parallel() limiter, err := NewInMemoryRateLimiter(nil, 10, 1) - require.NoError(t, err) + must.NoError(t, err) defer limiter.Close() ctx := context.Background() allowed, err := limiter.Allow(ctx, "key1") - require.NoError(t, err) - assert.True(t, allowed) + must.NoError(t, err) + test.True(t, allowed) allowed, err = limiter.Allow(ctx, "key2") - require.NoError(t, err) - assert.True(t, allowed) + must.NoError(t, err) + test.True(t, allowed) allowed, err = limiter.Allow(ctx, "key1") - require.NoError(t, err) - assert.False(t, allowed) + must.NoError(t, err) + test.False(t, allowed) allowed, err = limiter.Allow(ctx, "key2") - require.NoError(t, err) - assert.False(t, allowed) + must.NoError(t, err) + test.False(t, allowed) }) T.Run("Close is safe", func(t *testing.T) { t.Parallel() limiter, err := NewInMemoryRateLimiter(nil, 10, 1) - require.NoError(t, err) + must.NoError(t, err) err = limiter.Close() - require.NoError(t, err) + must.NoError(t, err) }) } diff --git a/ratelimiting/redis/redis_test.go b/ratelimiting/redis/redis_test.go index 7c2ecea..e4755ba 100644 --- a/ratelimiting/redis/redis_test.go +++ b/ratelimiting/redis/redis_test.go @@ -10,8 +10,7 @@ import ( "github.com/go-redis/redis/v8" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -46,13 +45,13 @@ func buildTestRateLimiter(t *testing.T) (*rateLimiter, *mockRedisClient) { mp := metrics.NewNoopMetricsProvider() allowedCounter, err := mp.NewInt64Counter(redisName + "_allowed") - require.NoError(t, err) + must.NoError(t, err) rejectedCounter, err := mp.NewInt64Counter(redisName + "_rejected") - require.NoError(t, err) + must.NoError(t, err) errorCounter, err := mp.NewInt64Counter(redisName + "_errors") - require.NoError(t, err) + must.NoError(t, err) return &rateLimiter{ client: client, @@ -75,7 +74,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Addresses: []string{"localhost:6379"}, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with empty addresses", func(t *testing.T) { @@ -87,7 +86,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Addresses: []string{}, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with nil addresses", func(t *testing.T) { @@ -97,7 +96,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { cfg := &Config{} - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } @@ -114,8 +113,8 @@ func TestNewRedisRateLimiter(T *testing.T) { } rl, err := NewRedisRateLimiter(cfg, nil, 10) - assert.NoError(t, err) - assert.NotNil(t, rl) + test.NoError(t, err) + test.NotNil(t, rl) }) T.Run("with multiple addresses", func(t *testing.T) { @@ -128,8 +127,8 @@ func TestNewRedisRateLimiter(T *testing.T) { } rl, err := NewRedisRateLimiter(cfg, nil, 10) - assert.NoError(t, err) - assert.NotNil(t, rl) + test.NoError(t, err) + test.NotNil(t, rl) }) T.Run("with error creating allowed counter", func(t *testing.T) { @@ -147,8 +146,8 @@ func TestNewRedisRateLimiter(T *testing.T) { } rl, err := NewRedisRateLimiter(cfg, mp, 10) - assert.Error(t, err) - assert.Nil(t, rl) + test.Error(t, err) + test.Nil(t, rl) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -174,8 +173,8 @@ func TestNewRedisRateLimiter(T *testing.T) { } rl, err := NewRedisRateLimiter(cfg, mp, 10) - assert.Error(t, err) - assert.Nil(t, rl) + test.Error(t, err) + test.Nil(t, rl) test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) @@ -201,8 +200,8 @@ func TestNewRedisRateLimiter(T *testing.T) { } rl, err := NewRedisRateLimiter(cfg, mp, 10) - assert.Error(t, err) - assert.Nil(t, rl) + test.Error(t, err) + test.Nil(t, rl) test.SliceLen(t, 3, mp.NewInt64CounterCalls()) }) @@ -222,11 +221,11 @@ func Test_rateLimiter_Allow(T *testing.T) { client.evalFunc = func(_ context.Context, _ string, _ []string, _ ...any) *redis.Cmd { return cmd } allowed, err := rl.Allow(ctx, "test-key") - assert.NoError(t, err) - assert.True(t, allowed) + test.NoError(t, err) + test.True(t, allowed) - require.Len(t, client.evalCalls, 1) - assert.Equal(t, slidingWindowScript, client.evalCalls[0].script) + must.SliceLen(t, 1, client.evalCalls) + test.EqOp(t, slidingWindowScript, client.evalCalls[0].script) }) T.Run("rejected", func(t *testing.T) { @@ -240,10 +239,10 @@ func Test_rateLimiter_Allow(T *testing.T) { client.evalFunc = func(_ context.Context, _ string, _ []string, _ ...any) *redis.Cmd { return cmd } allowed, err := rl.Allow(ctx, "test-key") - assert.NoError(t, err) - assert.False(t, allowed) + test.NoError(t, err) + test.False(t, allowed) - require.Len(t, client.evalCalls, 1) + must.SliceLen(t, 1, client.evalCalls) }) T.Run("with eval error", func(t *testing.T) { @@ -257,10 +256,10 @@ func Test_rateLimiter_Allow(T *testing.T) { client.evalFunc = func(_ context.Context, _ string, _ []string, _ ...any) *redis.Cmd { return cmd } allowed, err := rl.Allow(ctx, "test-key") - assert.Error(t, err) - assert.False(t, allowed) + test.Error(t, err) + test.False(t, allowed) - require.Len(t, client.evalCalls, 1) + must.SliceLen(t, 1, client.evalCalls) }) } @@ -274,8 +273,8 @@ func Test_rateLimiter_Close(T *testing.T) { client.closeFunc = func() error { return nil } err := rl.Close() - assert.NoError(t, err) - assert.Equal(t, 1, client.closeCalls) + test.NoError(t, err) + test.EqOp(t, 1, client.closeCalls) }) T.Run("with close error", func(t *testing.T) { @@ -285,7 +284,7 @@ func Test_rateLimiter_Close(T *testing.T) { client.closeFunc = func() error { return errors.New("close failed") } err := rl.Close() - assert.Error(t, err) - assert.Equal(t, 1, client.closeCalls) + test.Error(t, err) + test.EqOp(t, 1, client.closeCalls) }) } diff --git a/reflection/ast/helpers_test.go b/reflection/ast/helpers_test.go index de980ce..c9b56f8 100644 --- a/reflection/ast/helpers_test.go +++ b/reflection/ast/helpers_test.go @@ -7,8 +7,8 @@ import ( "path/filepath" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestGetModulePath(T *testing.T) { @@ -19,12 +19,12 @@ func TestGetModulePath(T *testing.T) { dir := t.TempDir() err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte("module github.com/example/test\n\ngo 1.21\n"), 0o600) - require.NoError(t, err) + must.NoError(t, err) path, err := GetModulePath(dir) - require.NoError(t, err) - assert.Equal(t, "github.com/example/test", path) + must.NoError(t, err) + test.EqOp(t, "github.com/example/test", path) }) T.Run("returns error when go.mod does not exist", func(t *testing.T) { @@ -32,8 +32,8 @@ func TestGetModulePath(T *testing.T) { path, err := GetModulePath(t.TempDir()) - assert.Empty(t, path) - assert.Error(t, err) + test.EqOp(t, "", path) + test.Error(t, err) }) T.Run("returns error when no module directive found", func(t *testing.T) { @@ -41,12 +41,12 @@ func TestGetModulePath(T *testing.T) { dir := t.TempDir() err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte("go 1.21\n"), 0o600) - require.NoError(t, err) + must.NoError(t, err) path, err := GetModulePath(dir) - assert.Empty(t, path) - assert.Error(t, err) + test.EqOp(t, "", path) + test.Error(t, err) }) } @@ -65,8 +65,8 @@ func TestBuildImportMap(T *testing.T) { result := BuildImportMap(file) - assert.Equal(t, "fmt", result["fmt"]) - assert.Equal(t, "github.com/example/pkg", result["pkg"]) + test.EqOp(t, "fmt", result["fmt"]) + test.EqOp(t, "github.com/example/pkg", result["pkg"]) }) T.Run("handles aliased imports", func(t *testing.T) { @@ -83,7 +83,7 @@ func TestBuildImportMap(T *testing.T) { result := BuildImportMap(file) - assert.Equal(t, "fmt", result["myfmt"]) + test.EqOp(t, "fmt", result["myfmt"]) }) T.Run("excludes blank and dot imports", func(t *testing.T) { @@ -104,7 +104,7 @@ func TestBuildImportMap(T *testing.T) { result := BuildImportMap(file) - assert.Empty(t, result) + test.MapEmpty(t, result) }) T.Run("skips imports with nil path", func(t *testing.T) { @@ -118,7 +118,7 @@ func TestBuildImportMap(T *testing.T) { result := BuildImportMap(file) - assert.Empty(t, result) + test.MapEmpty(t, result) }) } @@ -136,9 +136,9 @@ func TestFilterModuleImports(T *testing.T) { result := FilterModuleImports(imports, "github.com/example/mod") - assert.Len(t, result, 2) - assert.Equal(t, "observability/logging", result["logging"]) - assert.Equal(t, "errors", result["errors"]) + test.MapLen(t, 2, result) + test.EqOp(t, "observability/logging", result["logging"]) + test.EqOp(t, "errors", result["errors"]) }) T.Run("returns empty map when no module imports", func(t *testing.T) { @@ -150,7 +150,7 @@ func TestFilterModuleImports(T *testing.T) { result := FilterModuleImports(imports, "github.com/example/mod") - assert.Empty(t, result) + test.MapEmpty(t, result) }) } @@ -160,33 +160,33 @@ func TestGetTagValue(T *testing.T) { T.Run("extracts tag value", func(t *testing.T) { t.Parallel() - assert.Equal(t, "name", GetTagValue(`json:"name"`, "json")) + test.EqOp(t, "name", GetTagValue(`json:"name"`, "json")) }) T.Run("extracts tag value with omitempty", func(t *testing.T) { t.Parallel() - assert.Equal(t, "name", GetTagValue(`json:"name,omitempty"`, "json")) + test.EqOp(t, "name", GetTagValue(`json:"name,omitempty"`, "json")) }) T.Run("extracts from multiple tags", func(t *testing.T) { t.Parallel() tag := `json:"name" env:"MY_VAR"` - assert.Equal(t, "name", GetTagValue(tag, "json")) - assert.Equal(t, "MY_VAR", GetTagValue(tag, "env")) + test.EqOp(t, "name", GetTagValue(tag, "json")) + test.EqOp(t, "MY_VAR", GetTagValue(tag, "env")) }) T.Run("returns empty for missing key", func(t *testing.T) { t.Parallel() - assert.Equal(t, "", GetTagValue(`json:"name"`, "xml")) + test.EqOp(t, "", GetTagValue(`json:"name"`, "xml")) }) T.Run("handles backtick-wrapped tags", func(t *testing.T) { t.Parallel() - assert.Equal(t, "name", GetTagValue("`json:\"name\"`", "json")) + test.EqOp(t, "name", GetTagValue("`json:\"name\"`", "json")) }) } @@ -216,8 +216,8 @@ func TestGetStructFields(T *testing.T) { fields := GetStructFields(st) - assert.Equal(t, "string", fields["Name"]) - assert.Equal(t, "logging.Logger", fields["Logger"]) + test.EqOp(t, "string", fields["Name"]) + test.EqOp(t, "logging.Logger", fields["Logger"]) }) T.Run("excludes underscore fields", func(t *testing.T) { @@ -236,7 +236,7 @@ func TestGetStructFields(T *testing.T) { fields := GetStructFields(st) - assert.Empty(t, fields) + test.MapEmpty(t, fields) }) T.Run("handles multiple names per field", func(t *testing.T) { @@ -258,7 +258,7 @@ func TestGetStructFields(T *testing.T) { fields := GetStructFields(st) - assert.Equal(t, "int", fields["X"]) - assert.Equal(t, "int", fields["Y"]) + test.EqOp(t, "int", fields["X"]) + test.EqOp(t, "int", fields["Y"]) }) } diff --git a/reflection/utils_test.go b/reflection/utils_test.go index 15f0a48..0b461e0 100644 --- a/reflection/utils_test.go +++ b/reflection/utils_test.go @@ -4,8 +4,8 @@ import ( "reflect" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) type exampleStruct struct { @@ -44,8 +44,8 @@ func TestGetTagNameByValue(T *testing.T) { expected := "field1" actual, err := GetTagNameByValue(x, x.Field1, "json") - assert.NoError(t, err) - assert.Equal(t, expected, actual) + test.NoError(t, err) + test.EqOp(t, expected, actual) }) T.Run("unpointered", func(t *testing.T) { @@ -55,16 +55,16 @@ func TestGetTagNameByValue(T *testing.T) { expected := "field1" actual, err := GetTagNameByValue(x, x.Field1, "json") - assert.NoError(t, err) - assert.Equal(t, expected, actual) + test.NoError(t, err) + test.EqOp(t, expected, actual) }) T.Run("with nil value", func(t *testing.T) { t.Parallel() actual, err := GetTagNameByValue(nil, "blah", "json") - assert.Error(t, err) - assert.Empty(t, actual) + test.Error(t, err) + test.EqOp(t, "", actual) }) T.Run("with nil pointer", func(t *testing.T) { @@ -73,16 +73,16 @@ func TestGetTagNameByValue(T *testing.T) { var x *exampleStruct actual, err := GetTagNameByValue(x, "blah", "json") - assert.Error(t, err) - assert.Empty(t, actual) + test.Error(t, err) + test.EqOp(t, "", actual) }) T.Run("with non-struct value", func(t *testing.T) { t.Parallel() actual, err := GetTagNameByValue("not a struct", "blah", "json") - assert.Error(t, err) - assert.Empty(t, actual) + test.Error(t, err) + test.EqOp(t, "", actual) }) T.Run("with no matching field", func(t *testing.T) { @@ -94,8 +94,8 @@ func TestGetTagNameByValue(T *testing.T) { } actual, err := GetTagNameByValue(x, "nonexistent", "json") - assert.Error(t, err) - assert.Empty(t, actual) + test.Error(t, err) + test.EqOp(t, "", actual) }) T.Run("with embedded struct", func(t *testing.T) { @@ -109,8 +109,8 @@ func TestGetTagNameByValue(T *testing.T) { } actual, err := GetTagNameByValue(x, "unique_value", "json") - assert.NoError(t, err) - assert.Equal(t, "field1", actual) + test.NoError(t, err) + test.EqOp(t, "field1", actual) }) T.Run("with pointer-embedded struct", func(t *testing.T) { @@ -124,8 +124,8 @@ func TestGetTagNameByValue(T *testing.T) { } actual, err := GetTagNameByValue(x, "unique_ptr_value", "json") - assert.NoError(t, err) - assert.Equal(t, "field1", actual) + test.NoError(t, err) + test.EqOp(t, "field1", actual) }) T.Run("with nil pointer-embedded struct", func(t *testing.T) { @@ -137,8 +137,8 @@ func TestGetTagNameByValue(T *testing.T) { } actual, err := GetTagNameByValue(x, "unique_nil_embed", "json") - assert.NoError(t, err) - assert.Equal(t, "field3", actual) + test.NoError(t, err) + test.EqOp(t, "field3", actual) }) T.Run("with second field match", func(t *testing.T) { @@ -150,8 +150,8 @@ func TestGetTagNameByValue(T *testing.T) { } actual, err := GetTagNameByValue(x, "bbb", "json") - assert.NoError(t, err) - assert.Equal(t, "field2", actual) + test.NoError(t, err) + test.EqOp(t, "field2", actual) }) T.Run("with unexported fields", func(t *testing.T) { @@ -162,8 +162,8 @@ func TestGetTagNameByValue(T *testing.T) { } actual, err := GetTagNameByValue(x, "unique_exported_val", "json") - assert.NoError(t, err) - assert.Equal(t, "exported", actual) + test.NoError(t, err) + test.EqOp(t, "exported", actual) }) T.Run("with pointer to non-struct", func(t *testing.T) { @@ -171,8 +171,8 @@ func TestGetTagNameByValue(T *testing.T) { s := "hello" actual, err := GetTagNameByValue(&s, "hello", "json") - assert.Error(t, err) - assert.Empty(t, actual) + test.Error(t, err) + test.EqOp(t, "", actual) }) } @@ -183,7 +183,7 @@ func TestGetMethodName(T *testing.T) { t.Parallel() actual := GetMethodName(TestGetMethodName) - assert.Equal(t, "TestGetMethodName", actual) + test.EqOp(t, "TestGetMethodName", actual) }) T.Run("with a method", func(t *testing.T) { @@ -192,14 +192,14 @@ func TestGetMethodName(T *testing.T) { // exampleStruct has no exported methods, so use a known interface method r := reflect.TypeFor[*exampleStruct]() actual := GetMethodName(r.Kind) - assert.Equal(t, "Kind", actual) + test.EqOp(t, "Kind", actual) }) T.Run("with non-function value", func(t *testing.T) { t.Parallel() actual := GetMethodName("not a function") - assert.Empty(t, actual) + test.EqOp(t, "", actual) }) T.Run("with anonymous function", func(t *testing.T) { @@ -207,7 +207,7 @@ func TestGetMethodName(T *testing.T) { fn := func() {} actual := GetMethodName(fn) - assert.NotEmpty(t, actual) + test.NotEq(t, "", actual) }) } @@ -219,10 +219,10 @@ func TestGetFieldTypes(T *testing.T) { x := exampleStruct{} result, err := GetFieldTypes(x) - require.NoError(t, err) + must.NoError(t, err) - assert.Equal(t, "string", result["Field1"]) - assert.Equal(t, "string", result["Field2"]) + test.Eq(t, "string", result["Field1"]) + test.Eq(t, "string", result["Field2"]) }) T.Run("with pointer to struct", func(t *testing.T) { @@ -230,10 +230,10 @@ func TestGetFieldTypes(T *testing.T) { x := &exampleStruct{} result, err := GetFieldTypes(x) - require.NoError(t, err) + must.NoError(t, err) - assert.Equal(t, "string", result["Field1"]) - assert.Equal(t, "string", result["Field2"]) + test.Eq(t, "string", result["Field1"]) + test.Eq(t, "string", result["Field2"]) }) T.Run("with nil pointer to struct", func(t *testing.T) { @@ -241,46 +241,46 @@ func TestGetFieldTypes(T *testing.T) { var x *exampleStruct result, err := GetFieldTypes(x) - require.NoError(t, err) + must.NoError(t, err) - assert.Equal(t, "string", result["Field1"]) - assert.Equal(t, "string", result["Field2"]) + test.Eq(t, "string", result["Field1"]) + test.Eq(t, "string", result["Field2"]) }) T.Run("with nil value", func(t *testing.T) { t.Parallel() result, err := GetFieldTypes(nil) - assert.Error(t, err) - assert.Nil(t, result) + test.Error(t, err) + test.Nil(t, result) }) T.Run("with non-struct value", func(t *testing.T) { t.Parallel() result, err := GetFieldTypes("not a struct") - assert.Error(t, err) - assert.Nil(t, result) + test.Error(t, err) + test.Nil(t, result) }) T.Run("with reflect.Type directly", func(t *testing.T) { t.Parallel() result, err := GetFieldTypes(reflect.TypeFor[exampleStruct]()) - require.NoError(t, err) + must.NoError(t, err) - assert.Equal(t, "string", result["Field1"]) - assert.Equal(t, "string", result["Field2"]) + test.Eq(t, "string", result["Field1"]) + test.Eq(t, "string", result["Field2"]) }) T.Run("with reflect.Type of pointer", func(t *testing.T) { t.Parallel() result, err := GetFieldTypes(reflect.TypeFor[exampleStruct]()) - require.NoError(t, err) + must.NoError(t, err) - assert.Equal(t, "string", result["Field1"]) - assert.Equal(t, "string", result["Field2"]) + test.Eq(t, "string", result["Field1"]) + test.Eq(t, "string", result["Field2"]) }) T.Run("with nested struct", func(t *testing.T) { @@ -288,18 +288,18 @@ func TestGetFieldTypes(T *testing.T) { x := nestedStruct{} result, err := GetFieldTypes(x) - require.NoError(t, err) + must.NoError(t, err) innerMap, ok := result["Inner"].(map[string]any) - require.True(t, ok) - assert.Equal(t, "string", innerMap["Field1"]) - assert.Equal(t, "string", innerMap["Field2"]) + must.True(t, ok) + test.Eq(t, "string", innerMap["Field1"]) + test.Eq(t, "string", innerMap["Field2"]) innerPtrMap, ok := result["InnerPtr"].(map[string]any) - require.True(t, ok) - assert.Equal(t, "string", innerPtrMap["Field1"]) + must.True(t, ok) + test.Eq(t, "string", innerPtrMap["Field1"]) - assert.Equal(t, "string", result["Name"]) + test.Eq(t, "string", result["Name"]) }) T.Run("with unexported fields skipped", func(t *testing.T) { @@ -307,27 +307,27 @@ func TestGetFieldTypes(T *testing.T) { x := unexportedFieldStruct{Exported: "val"} result, err := GetFieldTypes(x) - require.NoError(t, err) + must.NoError(t, err) - assert.Equal(t, "string", result["Exported"]) + test.Eq(t, "string", result["Exported"]) _, hasUnexported := result["unexported"] - assert.False(t, hasUnexported) + test.False(t, hasUnexported) }) T.Run("with non-struct reflect.Type", func(t *testing.T) { t.Parallel() result, err := GetFieldTypes(reflect.TypeFor[string]()) - assert.Error(t, err) - assert.Nil(t, result) + test.Error(t, err) + test.Nil(t, result) }) T.Run("with pointer reflect.Type", func(t *testing.T) { t.Parallel() result, err := GetFieldTypes(reflect.TypeFor[*exampleStruct]()) - require.NoError(t, err) + must.NoError(t, err) - assert.Equal(t, "string", result["Field1"]) + test.Eq(t, "string", result["Field1"]) }) } diff --git a/retry/config_test.go b/retry/config_test.go index d219a63..14f2075 100644 --- a/retry/config_test.go +++ b/retry/config_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestConfig_EnsureDefaults(T *testing.T) { @@ -18,10 +18,10 @@ func TestConfig_EnsureDefaults(T *testing.T) { cfg := &Config{} cfg.EnsureDefaults() - assert.Equal(t, uint(3), cfg.MaxAttempts) - assert.Equal(t, 100*time.Millisecond, cfg.InitialDelay) - assert.Equal(t, 5*time.Second, cfg.MaxDelay) - assert.Equal(t, 2.0, cfg.Multiplier) + test.EqOp(t, uint(3), cfg.MaxAttempts) + test.EqOp(t, 100*time.Millisecond, cfg.InitialDelay) + test.EqOp(t, 5*time.Second, cfg.MaxDelay) + test.EqOp(t, 2.0, cfg.Multiplier) }) T.Run("preserves non-zero values", func(t *testing.T) { @@ -35,10 +35,10 @@ func TestConfig_EnsureDefaults(T *testing.T) { } cfg.EnsureDefaults() - assert.Equal(t, uint(7), cfg.MaxAttempts) - assert.Equal(t, 1*time.Second, cfg.InitialDelay) - assert.Equal(t, 10*time.Second, cfg.MaxDelay) - assert.Equal(t, 3.0, cfg.Multiplier) + test.EqOp(t, uint(7), cfg.MaxAttempts) + test.EqOp(t, 1*time.Second, cfg.InitialDelay) + test.EqOp(t, 10*time.Second, cfg.MaxDelay) + test.EqOp(t, 3.0, cfg.Multiplier) }) } @@ -57,7 +57,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { } err := cfg.ValidateWithContext(ctx) - require.NoError(t, err) + must.NoError(t, err) }) T.Run("invalid MaxAttempts", func(t *testing.T) { @@ -72,7 +72,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { } err := cfg.ValidateWithContext(ctx) - require.Error(t, err) + must.Error(t, err) }) T.Run("invalid InitialDelay", func(t *testing.T) { @@ -87,7 +87,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { } err := cfg.ValidateWithContext(ctx) - require.Error(t, err) + must.Error(t, err) }) T.Run("invalid Multiplier", func(t *testing.T) { @@ -102,6 +102,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { } err := cfg.ValidateWithContext(ctx) - require.Error(t, err) + must.Error(t, err) }) } diff --git a/retry/noop/noop_test.go b/retry/noop/noop_test.go index 2da3c56..3480e83 100644 --- a/retry/noop/noop_test.go +++ b/retry/noop/noop_test.go @@ -5,8 +5,8 @@ import ( "errors" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestPolicy_Execute(T *testing.T) { @@ -24,8 +24,8 @@ func TestPolicy_Execute(T *testing.T) { return nil }) - require.NoError(t, err) - assert.Equal(t, 1, attempts) + must.NoError(t, err) + test.EqOp(t, 1, attempts) }) T.Run("executes exactly once on failure", func(t *testing.T) { @@ -41,8 +41,8 @@ func TestPolicy_Execute(T *testing.T) { return expectedErr }) - require.Error(t, err) - assert.Equal(t, expectedErr, err) - assert.Equal(t, 1, attempts) + must.Error(t, err) + test.ErrorIs(t, err, expectedErr) + test.EqOp(t, 1, attempts) }) } diff --git a/retry/retry_test.go b/retry/retry_test.go index 44fa605..c8cd055 100644 --- a/retry/retry_test.go +++ b/retry/retry_test.go @@ -6,8 +6,8 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestExponentialBackoffPolicy_Execute(T *testing.T) { @@ -25,8 +25,8 @@ func TestExponentialBackoffPolicy_Execute(T *testing.T) { return nil }) - require.NoError(t, err) - assert.Equal(t, 1, attempts) + must.NoError(t, err) + test.EqOp(t, 1, attempts) }) T.Run("success after retries", func(t *testing.T) { @@ -49,8 +49,8 @@ func TestExponentialBackoffPolicy_Execute(T *testing.T) { return nil }) - require.NoError(t, err) - assert.Equal(t, 3, attempts) + must.NoError(t, err) + test.EqOp(t, 3, attempts) }) T.Run("returns last error after max attempts", func(t *testing.T) { @@ -74,9 +74,9 @@ func TestExponentialBackoffPolicy_Execute(T *testing.T) { return expectedErr }) - require.Error(t, err) - assert.Equal(t, expectedErr, err) - assert.Equal(t, 3, attempts) + must.Error(t, err) + test.ErrorIs(t, err, expectedErr) + test.EqOp(t, 3, attempts) }) T.Run("respects context cancellation", func(t *testing.T) { @@ -94,7 +94,7 @@ func TestExponentialBackoffPolicy_Execute(T *testing.T) { return errors.New("fail") }) - require.Error(t, err) - assert.ErrorIs(t, err, context.Canceled) + must.Error(t, err) + test.ErrorIs(t, err, context.Canceled) }) } diff --git a/routing/chi/config_test.go b/routing/chi/config_test.go index 1e1c525..f5e5d4d 100644 --- a/routing/chi/config_test.go +++ b/routing/chi/config_test.go @@ -3,7 +3,7 @@ package chi import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -17,6 +17,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { ServiceName: t.Name(), } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) } diff --git a/routing/chi/middleware_test.go b/routing/chi/middleware_test.go index 420c632..814ad58 100644 --- a/routing/chi/middleware_test.go +++ b/routing/chi/middleware_test.go @@ -8,7 +8,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestBuildLoggingMiddleware(T *testing.T) { @@ -21,7 +21,7 @@ func TestBuildLoggingMiddleware(T *testing.T) { tracer := tracing.NewTracerForTest("") middleware := buildLoggingMiddleware(logging.NewNoopLogger(), tracer, false) - assert.NotNil(t, middleware) + test.NotNil(t, middleware) hf := http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {}) diff --git a/routing/chi/request_id_test.go b/routing/chi/request_id_test.go index 94eff25..6788c05 100644 --- a/routing/chi/request_id_test.go +++ b/routing/chi/request_id_test.go @@ -8,8 +8,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/identifiers" chimiddleware "github.com/go-chi/chi/v5/middleware" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRequestIDFunc(T *testing.T) { @@ -22,9 +22,9 @@ func TestRequestIDFunc(T *testing.T) { ctx := context.WithValue(t.Context(), chimiddleware.RequestIDKey, expected) req, err := http.NewRequestWithContext(ctx, http.MethodPost, "/", http.NoBody) - require.NoError(t, err) + must.NoError(t, err) actual := RequestIDFunc(req) - assert.Equal(t, expected, actual) + test.EqOp(t, expected, actual) }) } diff --git a/routing/chi/routeparams_test.go b/routing/chi/routeparams_test.go index bfc9a3b..fe41ba7 100644 --- a/routing/chi/routeparams_test.go +++ b/routing/chi/routeparams_test.go @@ -9,7 +9,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/testutils" "github.com/go-chi/chi/v5" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestNewRouteParamManager(T *testing.T) { @@ -18,7 +18,7 @@ func TestNewRouteParamManager(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewRouteParamManager()) + test.NotNil(t, NewRouteParamManager()) }) } @@ -48,7 +48,7 @@ func Test_BuildRouteParamIDFetcher(T *testing.T) { ) actual := fn(req) - assert.Equal(t, expected, actual) + test.EqOp(t, expected, actual) }) T.Run("with invalid value somehow", func(t *testing.T) { @@ -77,7 +77,7 @@ func Test_BuildRouteParamIDFetcher(T *testing.T) { ) actual := fn(req) - assert.Equal(t, expected, actual) + test.EqOp(t, expected, actual) }) } @@ -108,6 +108,6 @@ func Test_BuildRouteParamStringIDFetcher(T *testing.T) { ) actual := fn(req) - assert.Equal(t, expected, actual) + test.EqOp(t, expected, actual) }) } diff --git a/routing/chi/router_test.go b/routing/chi/router_test.go index bf0ab60..2e60b1e 100644 --- a/routing/chi/router_test.go +++ b/routing/chi/router_test.go @@ -11,8 +11,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/routing" "github.com/go-chi/chi/v5" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func buildRouterForTest() routing.Router { @@ -25,7 +25,7 @@ func TestNewRouter(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotNil(t, NewRouter(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider(), &Config{})) + test.NotNil(t, NewRouter(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider(), &Config{})) }) } @@ -35,7 +35,7 @@ func Test_buildChiMux(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotNil(t, buildChiMux(logging.NewNoopLogger(), tracing.NewTracerForTest(t.Name()), metrics.NewNoopMetricsProvider(), &Config{})) + test.NotNil(t, buildChiMux(logging.NewNoopLogger(), tracing.NewTracerForTest(t.Name()), metrics.NewNoopMetricsProvider(), &Config{})) }) } @@ -45,7 +45,7 @@ func Test_convertMiddleware(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotNil(t, convertMiddleware(func(http.Handler) http.Handler { return nil })) + test.NotNil(t, convertMiddleware(func(http.Handler) http.Handler { return nil })) }) } @@ -70,7 +70,7 @@ func Test_router_AddRoute(T *testing.T) { } for _, method := range methods { - assert.NoError(t, r.AddRoute(method, "/path", nil)) + test.NoError(t, r.AddRoute(method, "/path", nil)) } }) @@ -79,7 +79,7 @@ func Test_router_AddRoute(T *testing.T) { r := buildRouterForTest() - assert.Error(t, r.AddRoute("blah", "/path", nil)) + test.Error(t, r.AddRoute("blah", "/path", nil)) }) } @@ -151,7 +151,7 @@ func Test_router_Handler(T *testing.T) { r := buildRouterForTest() - assert.NotNil(t, r.Handler()) + test.NotNil(t, r.Handler()) }) } @@ -175,7 +175,7 @@ func Test_router_LogRoutes(T *testing.T) { r := buildRouterForTest() - assert.NoError(t, r.AddRoute(http.MethodGet, "/path", nil)) + test.NoError(t, r.AddRoute(http.MethodGet, "/path", nil)) r.Routes() }) @@ -237,7 +237,7 @@ func Test_router_Route(T *testing.T) { r := buildRouterForTest() - assert.NotNil(t, r.Route("/test", func(routing.Router) {})) + test.NotNil(t, r.Route("/test", func(routing.Router) {})) }) } @@ -261,7 +261,7 @@ func Test_router_WithMiddleware(T *testing.T) { r := buildRouterForTest() - assert.NotNil(t, r.WithMiddleware()) + test.NotNil(t, r.WithMiddleware()) }) } @@ -273,7 +273,7 @@ func Test_router_clone(T *testing.T) { r := buildRouter(nil, nil, tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider(), &Config{}) - assert.NotNil(t, r.clone()) + test.NotNil(t, r.clone()) }) } @@ -289,11 +289,11 @@ func Test_router_BuildRouteParamIDFetcher(T *testing.T) { exampleKey := "blah" rf := r.BuildRouteParamIDFetcher(l, exampleKey, "desc") - assert.NotNil(t, rf) + test.NotNil(t, rf) req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/blah", http.NoBody) - assert.NoError(t, err) - require.NotNil(t, req) + test.NoError(t, err) + must.NotNil(t, req) expected := uint64(123456) @@ -305,7 +305,7 @@ func Test_router_BuildRouteParamIDFetcher(T *testing.T) { })) actual := rf(req) - assert.Equal(t, expected, actual) + test.EqOp(t, expected, actual) }) T.Run("without appropriate value attached to context", func(t *testing.T) { @@ -317,14 +317,14 @@ func Test_router_BuildRouteParamIDFetcher(T *testing.T) { exampleKey := "blah" rf := r.BuildRouteParamIDFetcher(l, exampleKey, "desc") - assert.NotNil(t, rf) + test.NotNil(t, rf) req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/blah", http.NoBody) - assert.NoError(t, err) - require.NotNil(t, req) + test.NoError(t, err) + must.NotNil(t, req) actual := rf(req) - assert.Zero(t, actual) + test.EqOp(t, uint64(0), actual) }) } @@ -339,11 +339,11 @@ func Test_router_BuildRouteParamStringIDFetcher(T *testing.T) { exampleKey := "blah" rf := r.BuildRouteParamStringIDFetcher(exampleKey) - assert.NotNil(t, rf) + test.NotNil(t, rf) req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/blah", http.NoBody) - assert.NoError(t, err) - require.NotNil(t, req) + test.NoError(t, err) + must.NotNil(t, req) expected := "fake_user_id" @@ -355,6 +355,6 @@ func Test_router_BuildRouteParamStringIDFetcher(T *testing.T) { })) actual := rf(req) - assert.Equal(t, expected, actual) + test.EqOp(t, expected, actual) }) } diff --git a/routing/config/config_test.go b/routing/config/config_test.go index b601d65..d8a2772 100644 --- a/routing/config/config_test.go +++ b/routing/config/config_test.go @@ -8,8 +8,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/verygoodsoftwarenotvirus/platform/v5/routing/chi" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -23,7 +23,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ProviderChi, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with invalid provider", func(t *testing.T) { @@ -34,7 +34,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: "bogus", } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } @@ -50,8 +50,8 @@ func TestProvideRouter(T *testing.T) { } router, err := ProvideRouter(cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - require.NoError(t, err) - assert.NotNil(t, router) + must.NoError(t, err) + test.NotNil(t, router) }) T.Run("with unknown provider", func(t *testing.T) { @@ -62,8 +62,8 @@ func TestProvideRouter(T *testing.T) { } router, err := ProvideRouter(cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - assert.Nil(t, router) - assert.Error(t, err) + test.Nil(t, router) + test.Error(t, err) }) } @@ -79,8 +79,8 @@ func TestConfig_ProvideRouter(T *testing.T) { } router, err := cfg.ProvideRouter(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - require.NoError(t, err) - assert.NotNil(t, router) + must.NoError(t, err) + test.NotNil(t, router) }) T.Run("with unknown provider", func(t *testing.T) { @@ -91,8 +91,8 @@ func TestConfig_ProvideRouter(T *testing.T) { } router, err := cfg.ProvideRouter(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - assert.Nil(t, router) - assert.Error(t, err) + test.Nil(t, router) + test.Error(t, err) }) } @@ -107,8 +107,8 @@ func TestProvideRouteParamManager(T *testing.T) { } manager, err := ProvideRouteParamManager(cfg) - require.NoError(t, err) - assert.NotNil(t, manager) + must.NoError(t, err) + test.NotNil(t, manager) }) T.Run("with unknown provider", func(t *testing.T) { @@ -119,8 +119,8 @@ func TestProvideRouteParamManager(T *testing.T) { } manager, err := ProvideRouteParamManager(cfg) - assert.Nil(t, manager) - assert.Error(t, err) + test.Nil(t, manager) + test.Error(t, err) }) } @@ -136,8 +136,8 @@ func TestProvideRouterViaConfig(T *testing.T) { } router, err := ProvideRouterViaConfig(cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - require.NoError(t, err) - assert.NotNil(t, router) + must.NoError(t, err) + test.NotNil(t, router) }) T.Run("with unknown provider", func(t *testing.T) { @@ -148,7 +148,7 @@ func TestProvideRouterViaConfig(T *testing.T) { } router, err := ProvideRouterViaConfig(cfg, logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider()) - assert.Nil(t, router) - assert.Error(t, err) + test.Nil(t, router) + test.Error(t, err) }) } diff --git a/routing/config/do_test.go b/routing/config/do_test.go index a574a3e..47c75df 100644 --- a/routing/config/do_test.go +++ b/routing/config/do_test.go @@ -6,8 +6,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/routing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterRouteParamManager(T *testing.T) { @@ -24,7 +24,7 @@ func TestRegisterRouteParamManager(T *testing.T) { RegisterRouteParamManager(i) manager, err := do.Invoke[routing.RouteParamManager](i) - require.NoError(t, err) - assert.NotNil(t, manager) + must.NoError(t, err) + test.NotNil(t, manager) }) } diff --git a/search/text/algolia/algolia_test.go b/search/text/algolia/algolia_test.go index f7f2640..fee0b82 100644 --- a/search/text/algolia/algolia_test.go +++ b/search/text/algolia/algolia_test.go @@ -7,7 +7,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) type example struct { @@ -25,8 +25,8 @@ func TestProvideIndexManager(T *testing.T) { tracerProvider := tracing.NewNoopTracerProvider() im, err := ProvideIndexManager[example](logger, tracerProvider, &Config{}, "test", cbnoop.NewCircuitBreaker()) - assert.NoError(t, err) - assert.NotNil(t, im) + test.NoError(t, err) + test.NotNil(t, im) }) T.Run("with nil config", func(t *testing.T) { @@ -36,7 +36,7 @@ func TestProvideIndexManager(T *testing.T) { tracerProvider := tracing.NewNoopTracerProvider() im, err := ProvideIndexManager[example](logger, tracerProvider, nil, "test", cbnoop.NewCircuitBreaker()) - assert.Error(t, err) - assert.Nil(t, im) + test.Error(t, err) + test.Nil(t, im) }) } diff --git a/search/text/algolia/config_test.go b/search/text/algolia/config_test.go index 3059753..b690945 100644 --- a/search/text/algolia/config_test.go +++ b/search/text/algolia/config_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig(T *testing.T) { @@ -14,9 +14,9 @@ func TestConfig(T *testing.T) { t.Parallel() cfg := &Config{} - assert.Empty(t, cfg.AppID) - assert.Empty(t, cfg.APIKey) - assert.Equal(t, time.Duration(0), cfg.Timeout) + test.EqOp(t, "", cfg.AppID) + test.EqOp(t, "", cfg.APIKey) + test.EqOp(t, time.Duration(0), cfg.Timeout) }) T.Run("with values", func(t *testing.T) { @@ -28,8 +28,8 @@ func TestConfig(T *testing.T) { Timeout: 5 * time.Second, } - assert.Equal(t, "test-app-id", cfg.AppID) - assert.Equal(t, "test-api-key", cfg.APIKey) - assert.Equal(t, 5*time.Second, cfg.Timeout) + test.EqOp(t, "test-app-id", cfg.AppID) + test.EqOp(t, "test-api-key", cfg.APIKey) + test.EqOp(t, 5*time.Second, cfg.Timeout) }) } diff --git a/search/text/algolia/index_test.go b/search/text/algolia/index_test.go index 05aada2..1fb35d1 100644 --- a/search/text/algolia/index_test.go +++ b/search/text/algolia/index_test.go @@ -15,7 +15,7 @@ import ( algoliasearch "github.com/algolia/algoliasearch-client-go/v3/algolia/search" algoliatransport "github.com/algolia/algoliasearch-client-go/v3/algolia/transport" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) var _ algoliatransport.Requester = (*testRequester)(nil) @@ -122,8 +122,8 @@ func TestIndexManager_Index(T *testing.T) { im := buildTestIndexManagerWithCircuitBreaker(t, cb) err := im.Index(context.Background(), "id", map[string]string{"id": "test"}) - assert.Error(t, err) - assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) + test.Error(t, err) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) }) T.Run("with unmarshalable value", func(t *testing.T) { @@ -132,7 +132,7 @@ func TestIndexManager_Index(T *testing.T) { im := buildTestIndexManager(t) err := im.Index(context.Background(), "id", make(chan int)) - assert.Error(t, err) + test.Error(t, err) }) T.Run("with valid value but invalid credentials", func(t *testing.T) { @@ -141,7 +141,7 @@ func TestIndexManager_Index(T *testing.T) { im := buildTestIndexManager(t) err := im.Index(context.Background(), "id", map[string]string{"id": "test", "name": "example"}) - assert.Error(t, err) + test.Error(t, err) }) T.Run("with non-object JSON value", func(t *testing.T) { @@ -150,7 +150,7 @@ func TestIndexManager_Index(T *testing.T) { im := buildTestIndexManager(t) err := im.Index(context.Background(), "id", "just a string") - assert.Error(t, err) + test.Error(t, err) }) T.Run("with successful index", func(t *testing.T) { @@ -168,7 +168,7 @@ func TestIndexManager_Index(T *testing.T) { im := buildTestIndexManagerWithMockServer(t, handler, cb) err := im.Index(context.Background(), "123", map[string]string{"id": "123", "name": "example"}) - assert.NoError(t, err) + test.NoError(t, err) }) } @@ -185,9 +185,9 @@ func TestIndexManager_Search(T *testing.T) { im := buildTestIndexManagerWithCircuitBreaker(t, cb) results, err := im.Search(context.Background(), "query") - assert.Error(t, err) - assert.Nil(t, results) - assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) + test.Error(t, err) + test.Nil(t, results) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) }) T.Run("with empty query", func(t *testing.T) { @@ -200,9 +200,9 @@ func TestIndexManager_Search(T *testing.T) { im := buildTestIndexManagerWithCircuitBreaker(t, cb) results, err := im.Search(context.Background(), "") - assert.Error(t, err) - assert.Nil(t, results) - assert.Equal(t, ErrEmptyQueryProvided, err) + test.Error(t, err) + test.Nil(t, results) + test.ErrorIs(t, err, ErrEmptyQueryProvided) }) T.Run("with valid query but invalid credentials", func(t *testing.T) { @@ -216,8 +216,8 @@ func TestIndexManager_Search(T *testing.T) { im := buildTestIndexManagerWithCircuitBreaker(t, cb) results, err := im.Search(context.Background(), "test query") - assert.Error(t, err) - assert.Nil(t, results) + test.Error(t, err) + test.Nil(t, results) }) T.Run("with successful search results", func(t *testing.T) { @@ -236,9 +236,9 @@ func TestIndexManager_Search(T *testing.T) { im := buildTestIndexManagerWithMockServer(t, handler, cb) results, err := im.Search(context.Background(), "test query") - assert.NoError(t, err) - assert.NotNil(t, results) - assert.Len(t, results, 1) + test.NoError(t, err) + test.NotNil(t, results) + test.SliceLen(t, 1, results) }) T.Run("with empty search results", func(t *testing.T) { @@ -257,9 +257,9 @@ func TestIndexManager_Search(T *testing.T) { im := buildTestIndexManagerWithMockServer(t, handler, cb) results, err := im.Search(context.Background(), "test query") - assert.NoError(t, err) - assert.NotNil(t, results) - assert.Empty(t, results) + test.NoError(t, err) + test.NotNil(t, results) + test.SliceEmpty(t, results) }) T.Run("with multiple search results", func(t *testing.T) { @@ -278,14 +278,14 @@ func TestIndexManager_Search(T *testing.T) { im := buildTestIndexManagerWithMockServer(t, handler, cb) results, err := im.Search(context.Background(), "test query") - assert.NoError(t, err) - assert.Len(t, results, 3) - assert.Equal(t, "abc", results[0].ID) - assert.Equal(t, "first", results[0].Name) - assert.Equal(t, "def", results[1].ID) - assert.Equal(t, "second", results[1].Name) - assert.Equal(t, "ghi", results[2].ID) - assert.Equal(t, "third", results[2].Name) + test.NoError(t, err) + test.SliceLen(t, 3, results) + test.EqOp(t, "abc", results[0].ID) + test.EqOp(t, "first", results[0].Name) + test.EqOp(t, "def", results[1].ID) + test.EqOp(t, "second", results[1].Name) + test.EqOp(t, "ghi", results[2].ID) + test.EqOp(t, "third", results[2].Name) }) T.Run("when unmarshalling search result fails", func(t *testing.T) { @@ -303,8 +303,8 @@ func TestIndexManager_Search(T *testing.T) { im := buildTestIndexManagerWithMockServer(t, handler, cb) results, err := im.Search(context.Background(), "test query") - assert.Error(t, err) - assert.Nil(t, results) + test.Error(t, err) + test.Nil(t, results) }) T.Run("with successful search results without objectID", func(t *testing.T) { @@ -323,9 +323,9 @@ func TestIndexManager_Search(T *testing.T) { im := buildTestIndexManagerWithMockServer(t, handler, cb) results, err := im.Search(context.Background(), "test query") - assert.NoError(t, err) - assert.NotNil(t, results) - assert.Len(t, results, 1) + test.NoError(t, err) + test.NotNil(t, results) + test.SliceLen(t, 1, results) }) } @@ -342,8 +342,8 @@ func TestIndexManager_Delete(T *testing.T) { im := buildTestIndexManagerWithCircuitBreaker(t, cb) err := im.Delete(context.Background(), "id") - assert.Error(t, err) - assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) + test.Error(t, err) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) }) T.Run("with invalid credentials", func(t *testing.T) { @@ -357,7 +357,7 @@ func TestIndexManager_Delete(T *testing.T) { im := buildTestIndexManagerWithCircuitBreaker(t, cb) err := im.Delete(context.Background(), "some-id") - assert.Error(t, err) + test.Error(t, err) }) T.Run("with successful deletion", func(t *testing.T) { @@ -376,7 +376,7 @@ func TestIndexManager_Delete(T *testing.T) { im := buildTestIndexManagerWithMockServer(t, handler, cb) err := im.Delete(context.Background(), "some-id") - assert.NoError(t, err) + test.NoError(t, err) }) } @@ -393,8 +393,8 @@ func TestIndexManager_Wipe(T *testing.T) { im := buildTestIndexManagerWithCircuitBreaker(t, cb) err := im.Wipe(context.Background()) - assert.Error(t, err) - assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) + test.Error(t, err) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) }) T.Run("with invalid credentials", func(t *testing.T) { @@ -408,7 +408,7 @@ func TestIndexManager_Wipe(T *testing.T) { im := buildTestIndexManagerWithCircuitBreaker(t, cb) err := im.Wipe(context.Background()) - assert.Error(t, err) + test.Error(t, err) }) T.Run("with successful wipe", func(t *testing.T) { @@ -427,6 +427,6 @@ func TestIndexManager_Wipe(T *testing.T) { im := buildTestIndexManagerWithMockServer(t, handler, cb) err := im.Wipe(context.Background()) - assert.NoError(t, err) + test.NoError(t, err) }) } diff --git a/search/text/config/config_test.go b/search/text/config/config_test.go index 61221a7..e7e8de3 100644 --- a/search/text/config/config_test.go +++ b/search/text/config/config_test.go @@ -14,7 +14,6 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/search/text/elasticsearch" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" "go.opentelemetry.io/otel/metric" ) @@ -32,7 +31,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { }, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("algolia provider", func(t *testing.T) { @@ -47,7 +46,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { }, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("invalid provider", func(t *testing.T) { @@ -58,7 +57,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: "invalid-provider", } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("elasticsearch provider without elasticsearch config", func(t *testing.T) { @@ -69,7 +68,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: ElasticsearchProvider, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("algolia provider without algolia config", func(t *testing.T) { @@ -80,7 +79,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: AlgoliaProvider, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("empty provider", func(t *testing.T) { @@ -92,7 +91,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { } // Empty provider should be valid (it will default to noop) - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("provider with extra whitespace", func(t *testing.T) { @@ -107,7 +106,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { } // Provider with whitespace should be invalid (validation is strict) - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("provider case insensitive", func(t *testing.T) { @@ -122,7 +121,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { } // Provider should be case sensitive (validation is strict) - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("nil context", func(t *testing.T) { @@ -135,7 +134,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { }, } - assert.NoError(t, cfg.ValidateWithContext(context.TODO())) + test.NoError(t, cfg.ValidateWithContext(context.TODO())) }) } @@ -149,16 +148,16 @@ func TestConfig_ZeroValue(T *testing.T) { cfg := &Config{} // Zero value should be valid (it will default to noop) - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("zero value fields", func(t *testing.T) { t.Parallel() cfg := &Config{} - assert.Equal(t, "", cfg.Provider) - assert.Nil(t, cfg.Algolia) - assert.Nil(t, cfg.Elasticsearch) + test.EqOp(t, "", cfg.Provider) + test.Nil(t, cfg.Algolia) + test.Nil(t, cfg.Elasticsearch) }) } @@ -168,21 +167,21 @@ func TestConfig_Constants(T *testing.T) { T.Run("provider constants have expected values", func(t *testing.T) { t.Parallel() - assert.Equal(t, "elasticsearch", ElasticsearchProvider) - assert.Equal(t, "algolia", AlgoliaProvider) + test.EqOp(t, "elasticsearch", ElasticsearchProvider) + test.EqOp(t, "algolia", AlgoliaProvider) }) T.Run("provider constants are not empty", func(t *testing.T) { t.Parallel() - assert.NotEmpty(t, ElasticsearchProvider) - assert.NotEmpty(t, AlgoliaProvider) + test.NotEq(t, "", ElasticsearchProvider) + test.NotEq(t, "", AlgoliaProvider) }) T.Run("provider constants are different", func(t *testing.T) { t.Parallel() - assert.NotEqual(t, ElasticsearchProvider, AlgoliaProvider) + test.NotEq(t, ElasticsearchProvider, AlgoliaProvider) }) } @@ -206,8 +205,8 @@ func TestConfig_ProvideIndex(T *testing.T) { tracerProvider := tracing.NewNoopTracerProvider() metricsProvider := metrics.NewNoopMetricsProvider() index, err := ProvideIndex[testStruct](ctx, logger, tracerProvider, metricsProvider, cfg, "test-index") - assert.Error(t, err) - assert.Nil(t, index) + test.Error(t, err) + test.Nil(t, index) }) T.Run("algolia provider", func(t *testing.T) { @@ -227,8 +226,8 @@ func TestConfig_ProvideIndex(T *testing.T) { tracerProvider := tracing.NewNoopTracerProvider() metricsProvider := metrics.NewNoopMetricsProvider() index, err := ProvideIndex[testStruct](ctx, logger, tracerProvider, metricsProvider, cfg, "test-index") - assert.NoError(t, err) - assert.NotNil(t, index) + test.NoError(t, err) + test.NotNil(t, index) }) T.Run("unknown provider returns noop", func(t *testing.T) { @@ -243,8 +242,8 @@ func TestConfig_ProvideIndex(T *testing.T) { tracerProvider := tracing.NewNoopTracerProvider() metricsProvider := metrics.NewNoopMetricsProvider() index, err := ProvideIndex[testStruct](ctx, logger, tracerProvider, metricsProvider, cfg, "test-index") - assert.NoError(t, err) - assert.NotNil(t, index) + test.NoError(t, err) + test.NotNil(t, index) }) T.Run("empty provider returns noop", func(t *testing.T) { @@ -259,8 +258,8 @@ func TestConfig_ProvideIndex(T *testing.T) { tracerProvider := tracing.NewNoopTracerProvider() metricsProvider := metrics.NewNoopMetricsProvider() index, err := ProvideIndex[testStruct](ctx, logger, tracerProvider, metricsProvider, cfg, "test-index") - assert.NoError(t, err) - assert.NotNil(t, index) + test.NoError(t, err) + test.NotNil(t, index) }) T.Run("provider with whitespace returns noop", func(t *testing.T) { @@ -275,8 +274,8 @@ func TestConfig_ProvideIndex(T *testing.T) { tracerProvider := tracing.NewNoopTracerProvider() metricsProvider := metrics.NewNoopMetricsProvider() index, err := ProvideIndex[testStruct](ctx, logger, tracerProvider, metricsProvider, cfg, "test-index") - assert.NoError(t, err) - assert.NotNil(t, index) + test.NoError(t, err) + test.NotNil(t, index) }) T.Run("circuit breaker init failure", func(t *testing.T) { @@ -304,9 +303,9 @@ func TestConfig_ProvideIndex(T *testing.T) { logger := logging.NewNoopLogger() tracerProvider := tracing.NewNoopTracerProvider() index, err := ProvideIndex[testStruct](ctx, logger, tracerProvider, mp, cfg, "test-index") - assert.Error(t, err) - assert.Nil(t, index) - assert.Contains(t, err.Error(), "circuit breaker") + test.Error(t, err) + test.Nil(t, index) + test.StrContains(t, err.Error(), "circuit breaker") test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) diff --git a/search/text/elasticsearch/config_test.go b/search/text/elasticsearch/config_test.go index 8070733..54a787a 100644 --- a/search/text/elasticsearch/config_test.go +++ b/search/text/elasticsearch/config_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig(T *testing.T) { @@ -14,11 +14,11 @@ func TestConfig(T *testing.T) { t.Parallel() cfg := &Config{} - assert.Empty(t, cfg.Address) - assert.Empty(t, cfg.Username) - assert.Empty(t, cfg.Password) - assert.Nil(t, cfg.CACert) - assert.Equal(t, time.Duration(0), cfg.IndexOperationTimeout) + test.EqOp(t, "", cfg.Address) + test.EqOp(t, "", cfg.Username) + test.EqOp(t, "", cfg.Password) + test.Nil(t, cfg.CACert) + test.EqOp(t, time.Duration(0), cfg.IndexOperationTimeout) }) T.Run("with values", func(t *testing.T) { @@ -32,10 +32,10 @@ func TestConfig(T *testing.T) { IndexOperationTimeout: 5 * time.Second, } - assert.Equal(t, "http://localhost:9200", cfg.Address) - assert.Equal(t, "elastic", cfg.Username) - assert.Equal(t, "password", cfg.Password) - assert.Equal(t, []byte("cert"), cfg.CACert) - assert.Equal(t, 5*time.Second, cfg.IndexOperationTimeout) + test.EqOp(t, "http://localhost:9200", cfg.Address) + test.EqOp(t, "elastic", cfg.Username) + test.EqOp(t, "password", cfg.Password) + test.Eq(t, []byte("cert"), cfg.CACert) + test.EqOp(t, 5*time.Second, cfg.IndexOperationTimeout) }) } diff --git a/search/text/elasticsearch/elasticsearch_test.go b/search/text/elasticsearch/elasticsearch_test.go index 6aa0577..7eb88ce 100644 --- a/search/text/elasticsearch/elasticsearch_test.go +++ b/search/text/elasticsearch/elasticsearch_test.go @@ -17,8 +17,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" elasticsearchcontainers "github.com/testcontainers/testcontainers-go/modules/elasticsearch" ) @@ -41,8 +41,8 @@ func buildEsTestInfra(t *testing.T) *esTestInfra { "elasticsearch:8.10.2", elasticsearchcontainers.WithPassword("arbitraryPassword"), ) - require.NoError(t, err) - require.NotNil(t, elasticsearchContainer) + must.NoError(t, err) + must.NotNil(t, elasticsearchContainer) cfg := &Config{ Address: elasticsearchContainer.Settings.Address, @@ -80,15 +80,15 @@ func TestElasticsearch_Container(T *testing.T) { ctx := t.Context() indexName := "ensure_create_" + identifiers.New() im, err := ProvideIndexManager[example](ctx, nil, nil, infra.cfg, indexName, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - assert.NotNil(t, im) + must.NoError(t, err) + test.NotNil(t, im) searchable := &example{ ID: identifiers.New(), Name: "test document", } - assert.NoError(t, im.Index(ctx, searchable.ID, searchable)) + test.NoError(t, im.Index(ctx, searchable.ID, searchable)) }) T.Run("ensureIndices handles existing index", func(t *testing.T) { @@ -97,21 +97,21 @@ func TestElasticsearch_Container(T *testing.T) { ctx := t.Context() indexName := "ensure_existing_" + identifiers.New() im1, err := ProvideIndexManager[example](ctx, nil, nil, infra.cfg, indexName, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) + must.NoError(t, err) im2, err := ProvideIndexManager[example](ctx, nil, nil, infra.cfg, indexName, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) + must.NoError(t, err) - assert.NotNil(t, im1) - assert.NotNil(t, im2) + test.NotNil(t, im1) + test.NotNil(t, im2) searchable := &example{ ID: identifiers.New(), Name: "test document", } - assert.NoError(t, im1.Index(ctx, searchable.ID, searchable)) - assert.NoError(t, im2.Index(ctx, searchable.ID+"_2", searchable)) + test.NoError(t, im1.Index(ctx, searchable.ID, searchable)) + test.NoError(t, im2.Index(ctx, searchable.ID+"_2", searchable)) }) // --- ProvideIndexManager --- @@ -121,8 +121,8 @@ func TestElasticsearch_Container(T *testing.T) { ctx := t.Context() im, err := ProvideIndexManager[example](ctx, nil, nil, infra.cfg, "provide_"+identifiers.New(), cbnoop.NewCircuitBreaker()) - assert.NoError(t, err) - assert.NotNil(t, im) + test.NoError(t, err) + test.NotNil(t, im) }) T.Run("ProvideIndexManager with logger and tracer", func(t *testing.T) { @@ -133,8 +133,8 @@ func TestElasticsearch_Container(T *testing.T) { tracerProvider := tracing.NewNoopTracerProvider() im, err := ProvideIndexManager[example](ctx, logger, tracerProvider, infra.cfg, "provide_lt_"+identifiers.New(), cbnoop.NewCircuitBreaker()) - assert.NoError(t, err) - assert.NotNil(t, im) + test.NoError(t, err) + test.NotNil(t, im) }) // --- elasticsearchIsReadyToInit --- @@ -146,7 +146,7 @@ func TestElasticsearch_Container(T *testing.T) { logger := logging.NewNoopLogger() ready := elasticsearchIsReadyToInit(ctx, infra.cfg, logger, 5) - assert.True(t, ready) + test.True(t, ready) }) // --- provideElasticsearchClient --- @@ -155,8 +155,8 @@ func TestElasticsearch_Container(T *testing.T) { t.Parallel() client, err := provideElasticsearchClient(infra.cfg) - assert.NoError(t, err) - assert.NotNil(t, client) + test.NoError(t, err) + test.NotNil(t, client) }) // --- complete lifecycle --- @@ -166,24 +166,24 @@ func TestElasticsearch_Container(T *testing.T) { ctx := t.Context() im, err := ProvideIndexManager[example](ctx, nil, nil, infra.cfg, "lifecycle_"+identifiers.New(), cbnoop.NewCircuitBreaker()) - assert.NoError(t, err) - assert.NotNil(t, im) + test.NoError(t, err) + test.NotNil(t, im) searchable := &example{ ID: identifiers.New(), Name: t.Name(), } - assert.NoError(t, im.Index(ctx, searchable.ID, searchable)) + test.NoError(t, im.Index(ctx, searchable.ID, searchable)) time.Sleep(5 * time.Second) results, err := im.Search(ctx, searchable.Name[0:2]) - assert.NoError(t, err) - assert.Len(t, results, 1) - assert.Equal(t, searchable, results[0]) + test.NoError(t, err) + test.SliceLen(t, 1, results) + test.Eq(t, searchable, results[0]) - assert.NoError(t, im.Delete(ctx, searchable.ID)) + test.NoError(t, im.Delete(ctx, searchable.ID)) }) // --- Index --- @@ -193,14 +193,14 @@ func TestElasticsearch_Container(T *testing.T) { ctx := t.Context() im, err := ProvideIndexManager[example](ctx, nil, nil, infra.cfg, "idx_ok_"+identifiers.New(), cbnoop.NewCircuitBreaker()) - require.NoError(t, err) + must.NoError(t, err) searchable := &example{ ID: identifiers.New(), Name: t.Name(), } - assert.NoError(t, im.Index(ctx, searchable.ID, searchable)) + test.NoError(t, im.Index(ctx, searchable.ID, searchable)) }) T.Run("Index json marshaling error", func(t *testing.T) { @@ -208,13 +208,13 @@ func TestElasticsearch_Container(T *testing.T) { ctx := t.Context() im, err := ProvideIndexManager[example](ctx, nil, nil, infra.cfg, "idx_json_"+identifiers.New(), cbnoop.NewCircuitBreaker()) - require.NoError(t, err) + must.NoError(t, err) invalid := &invalidJSON{ Channel: make(chan int), } - assert.Error(t, im.Index(ctx, "test-id", invalid)) + test.Error(t, im.Index(ctx, "test-id", invalid)) }) T.Run("Index with noop circuit breaker", func(t *testing.T) { @@ -223,14 +223,14 @@ func TestElasticsearch_Container(T *testing.T) { ctx := t.Context() cb := cbnoop.NewCircuitBreaker() im, err := ProvideIndexManager[example](ctx, nil, nil, infra.cfg, "idx_cb_"+identifiers.New(), cb) - require.NoError(t, err) + must.NoError(t, err) searchable := &example{ ID: identifiers.New(), Name: t.Name(), } - assert.NoError(t, im.Index(ctx, searchable.ID, searchable)) + test.NoError(t, im.Index(ctx, searchable.ID, searchable)) }) // --- Search --- @@ -240,20 +240,20 @@ func TestElasticsearch_Container(T *testing.T) { ctx := t.Context() im, err := ProvideIndexManager[example](ctx, nil, nil, infra.cfg, "search_ok_"+identifiers.New(), cbnoop.NewCircuitBreaker()) - require.NoError(t, err) + must.NoError(t, err) searchable := &example{ ID: identifiers.New(), Name: "test search document", } - require.NoError(t, im.Index(ctx, searchable.ID, searchable)) + must.NoError(t, im.Index(ctx, searchable.ID, searchable)) time.Sleep(2 * time.Second) results, err := im.Search(ctx, "test") - assert.NoError(t, err) - assert.Len(t, results, 1) - assert.Equal(t, searchable.ID, results[0].ID) + test.NoError(t, err) + test.SliceLen(t, 1, results) + test.EqOp(t, searchable.ID, results[0].ID) }) T.Run("Search empty query error", func(t *testing.T) { @@ -261,12 +261,12 @@ func TestElasticsearch_Container(T *testing.T) { ctx := t.Context() im, err := ProvideIndexManager[example](ctx, nil, nil, infra.cfg, "search_empty_"+identifiers.New(), cbnoop.NewCircuitBreaker()) - require.NoError(t, err) + must.NoError(t, err) results, err := im.Search(ctx, "") - assert.Error(t, err) - assert.Nil(t, results) - assert.Equal(t, ErrEmptyQueryProvided, err) + test.Error(t, err) + test.Nil(t, results) + test.ErrorIs(t, err, ErrEmptyQueryProvided) }) T.Run("Search no results found", func(t *testing.T) { @@ -274,11 +274,11 @@ func TestElasticsearch_Container(T *testing.T) { ctx := t.Context() im, err := ProvideIndexManager[example](ctx, nil, nil, infra.cfg, "search_noresult_"+identifiers.New(), cbnoop.NewCircuitBreaker()) - require.NoError(t, err) + must.NoError(t, err) results, err := im.Search(ctx, "nonexistent document") - assert.NoError(t, err) - assert.Len(t, results, 0) + test.NoError(t, err) + test.SliceLen(t, 0, results) }) // --- Delete --- @@ -288,15 +288,15 @@ func TestElasticsearch_Container(T *testing.T) { ctx := t.Context() im, err := ProvideIndexManager[example](ctx, nil, nil, infra.cfg, "del_ok_"+identifiers.New(), cbnoop.NewCircuitBreaker()) - require.NoError(t, err) + must.NoError(t, err) searchable := &example{ ID: identifiers.New(), Name: "test delete document", } - require.NoError(t, im.Index(ctx, searchable.ID, searchable)) + must.NoError(t, im.Index(ctx, searchable.ID, searchable)) - assert.NoError(t, im.Delete(ctx, searchable.ID)) + test.NoError(t, im.Delete(ctx, searchable.ID)) }) T.Run("Delete non-existent document", func(t *testing.T) { @@ -304,9 +304,9 @@ func TestElasticsearch_Container(T *testing.T) { ctx := t.Context() im, err := ProvideIndexManager[example](ctx, nil, nil, infra.cfg, "del_nf_"+identifiers.New(), cbnoop.NewCircuitBreaker()) - require.NoError(t, err) + must.NoError(t, err) - assert.NoError(t, im.Delete(ctx, "non-existent-id")) + test.NoError(t, im.Delete(ctx, "non-existent-id")) }) // --- Wipe --- @@ -316,8 +316,8 @@ func TestElasticsearch_Container(T *testing.T) { im := &indexManager[example]{} - assert.Error(t, im.Wipe(t.Context())) - assert.Equal(t, "unimplemented", im.Wipe(t.Context()).Error()) + test.Error(t, im.Wipe(t.Context())) + test.EqOp(t, "unimplemented", im.Wipe(t.Context()).Error()) }) } @@ -334,8 +334,8 @@ func TestIndexManager_ensureIndices_CircuitBroken(T *testing.T) { im := buildTestIndexManagerForUnit(t, cb) err := im.ensureIndices(context.Background()) - assert.Error(t, err) - assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) + test.Error(t, err) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) }) T.Run("with unreachable server", func(t *testing.T) { @@ -349,7 +349,7 @@ func TestIndexManager_ensureIndices_CircuitBroken(T *testing.T) { im := buildTestIndexManagerForUnit(t, cb) err := im.ensureIndices(context.Background()) - assert.Error(t, err) + test.Error(t, err) }) } @@ -377,7 +377,7 @@ func TestIndexManager_ensureIndices_Unit(T *testing.T) { im := buildTestIndexManagerWithServer(t, server, cb) err := im.ensureIndices(context.Background()) - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("index does not exist and create succeeds", func(t *testing.T) { @@ -406,7 +406,7 @@ func TestIndexManager_ensureIndices_Unit(T *testing.T) { im := buildTestIndexManagerWithServer(t, server, cb) err := im.ensureIndices(context.Background()) - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("index does not exist and create fails", func(t *testing.T) { @@ -439,7 +439,7 @@ func TestIndexManager_ensureIndices_Unit(T *testing.T) { im := buildTestIndexManagerWithServer(t, server, cb) err := im.ensureIndices(context.Background()) - assert.Error(t, err) + test.Error(t, err) }) } @@ -454,8 +454,8 @@ func Test_provideElasticsearchClient_Unit(T *testing.T) { } client, err := provideElasticsearchClient(cfg) - assert.NoError(t, err) - assert.NotNil(t, client) + test.NoError(t, err) + test.NotNil(t, client) }) T.Run("with credentials", func(t *testing.T) { @@ -468,8 +468,8 @@ func Test_provideElasticsearchClient_Unit(T *testing.T) { } client, err := provideElasticsearchClient(cfg) - assert.NoError(t, err) - assert.NotNil(t, client) + test.NoError(t, err) + test.NotNil(t, client) }) } @@ -490,7 +490,7 @@ func Test_elasticsearchIsReadyToInit_Unit(T *testing.T) { // err != nil && res != nil && !res.IsError() which won't match when res is nil, // so it falls through to the else branch and returns true. // This is actually a bug in the code but we test the actual behavior. - assert.True(t, ready) + test.True(t, ready) }) T.Run("returns true with reachable server", func(t *testing.T) { @@ -510,7 +510,7 @@ func Test_elasticsearchIsReadyToInit_Unit(T *testing.T) { logger := logging.NewNoopLogger() ready := elasticsearchIsReadyToInit(context.Background(), cfg, logger, 3) - assert.True(t, ready) + test.True(t, ready) }) } @@ -553,8 +553,8 @@ func TestProvideIndexManager_Unit(T *testing.T) { } im, err := ProvideIndexManager[example](context.Background(), logger, tracerProvider, cfg, "test", cb) - assert.NoError(t, err) - assert.NotNil(t, im) + test.NoError(t, err) + test.NotNil(t, im) }) T.Run("fails when ensureIndices fails", func(t *testing.T) { @@ -603,7 +603,7 @@ func TestProvideIndexManager_Unit(T *testing.T) { } im, err := ProvideIndexManager[example](context.Background(), logger, tracerProvider, cfg, "test", cb) - assert.Error(t, err) - assert.Nil(t, im) + test.Error(t, err) + test.Nil(t, im) }) } diff --git a/search/text/elasticsearch/index_test.go b/search/text/elasticsearch/index_test.go index cfb1983..7a0d620 100644 --- a/search/text/elasticsearch/index_test.go +++ b/search/text/elasticsearch/index_test.go @@ -13,8 +13,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/elastic/go-elasticsearch/v8" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) type example struct { @@ -77,8 +77,8 @@ func TestIndexManager_Index_CircuitBroken(T *testing.T) { im := buildTestIndexManagerForUnit(t, cb) err := im.Index(context.Background(), "id", map[string]string{"id": "test"}) - assert.Error(t, err) - assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) + test.Error(t, err) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) }) T.Run("with unmarshalable value", func(t *testing.T) { @@ -91,7 +91,7 @@ func TestIndexManager_Index_CircuitBroken(T *testing.T) { im := buildTestIndexManagerForUnit(t, cb) err := im.Index(context.Background(), "id", make(chan int)) - assert.Error(t, err) + test.Error(t, err) }) T.Run("with unreachable server", func(t *testing.T) { @@ -105,7 +105,7 @@ func TestIndexManager_Index_CircuitBroken(T *testing.T) { im := buildTestIndexManagerForUnit(t, cb) err := im.Index(context.Background(), "id", map[string]string{"id": "test"}) - assert.Error(t, err) + test.Error(t, err) }) } @@ -131,7 +131,7 @@ func TestIndexManager_Index_Unit(T *testing.T) { im := buildTestIndexManagerWithServer(t, server, cb) err := im.Index(context.Background(), "123", &example{ID: "123", Name: "test"}) - assert.NoError(t, err) + test.NoError(t, err) }) T.Run("with non-success status code", func(t *testing.T) { @@ -153,7 +153,7 @@ func TestIndexManager_Index_Unit(T *testing.T) { im := buildTestIndexManagerWithServer(t, server, cb) err := im.Index(context.Background(), "123", &example{ID: "123", Name: "test"}) - assert.Error(t, err) + test.Error(t, err) }) } @@ -170,9 +170,9 @@ func TestIndexManager_Search_CircuitBroken(T *testing.T) { im := buildTestIndexManagerForUnit(t, cb) results, err := im.Search(context.Background(), "query") - assert.Error(t, err) - assert.Nil(t, results) - assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) + test.Error(t, err) + test.Nil(t, results) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) }) T.Run("with empty query", func(t *testing.T) { @@ -185,9 +185,9 @@ func TestIndexManager_Search_CircuitBroken(T *testing.T) { im := buildTestIndexManagerForUnit(t, cb) results, err := im.Search(context.Background(), "") - assert.Error(t, err) - assert.Nil(t, results) - assert.Equal(t, ErrEmptyQueryProvided, err) + test.Error(t, err) + test.Nil(t, results) + test.ErrorIs(t, err, ErrEmptyQueryProvided) }) T.Run("with unreachable server", func(t *testing.T) { @@ -201,8 +201,8 @@ func TestIndexManager_Search_CircuitBroken(T *testing.T) { im := buildTestIndexManagerForUnit(t, cb) results, err := im.Search(context.Background(), "test query") - assert.Error(t, err) - assert.Nil(t, results) + test.Error(t, err) + test.Nil(t, results) }) } @@ -228,10 +228,10 @@ func TestIndexManager_Search_Unit(T *testing.T) { im := buildTestIndexManagerWithServer(t, server, cb) results, err := im.Search(context.Background(), "test") - assert.NoError(t, err) - require.Len(t, results, 1) - assert.Equal(t, "123", results[0].ID) - assert.Equal(t, "test", results[0].Name) + test.NoError(t, err) + must.SliceLen(t, 1, results) + test.EqOp(t, "123", results[0].ID) + test.EqOp(t, "test", results[0].Name) }) T.Run("with error response", func(t *testing.T) { @@ -257,8 +257,8 @@ func TestIndexManager_Search_Unit(T *testing.T) { // does exercise the IsError() branch and calls circuitBreaker.Failed(), // but ultimately returns nil error due to the defer clobbering it. results, err := im.Search(context.Background(), "test") - assert.NoError(t, err) - assert.Nil(t, results) + test.NoError(t, err) + test.Nil(t, results) }) T.Run("with invalid JSON in success response", func(t *testing.T) { @@ -282,8 +282,8 @@ func TestIndexManager_Search_Unit(T *testing.T) { // NOTE: same issue as error response test - the deferred res.Body.Close() // overwrites the named return 'err' with nil. results, err := im.Search(context.Background(), "test") - assert.NoError(t, err) - assert.Nil(t, results) + test.NoError(t, err) + test.Nil(t, results) }) } @@ -311,8 +311,8 @@ func TestIndexManager_Search_ErrorResponseDecodeFailure_Unit(T *testing.T) { // NOTE: the named return 'err' from the deferred res.Body.Close() clobbers // the error, so this returns nil error despite the decode failure. results, err := im.Search(context.Background(), "test") - assert.NoError(t, err) - assert.Nil(t, results) + test.NoError(t, err) + test.Nil(t, results) }) } @@ -340,8 +340,8 @@ func TestIndexManager_Search_SourceUnmarshalError_Unit(T *testing.T) { // NOTE: the named return 'err' from the deferred res.Body.Close() clobbers // the error, so this returns nil error despite the unmarshal failure. results, err := im.Search(context.Background(), "test") - assert.NoError(t, err) - assert.Nil(t, results) + test.NoError(t, err) + test.Nil(t, results) }) } @@ -358,8 +358,8 @@ func TestIndexManager_Delete_CircuitBroken(T *testing.T) { im := buildTestIndexManagerForUnit(t, cb) err := im.Delete(context.Background(), "id") - assert.Error(t, err) - assert.Equal(t, circuitbreaking.ErrCircuitBroken, err) + test.Error(t, err) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) }) T.Run("with unreachable server", func(t *testing.T) { @@ -373,7 +373,7 @@ func TestIndexManager_Delete_CircuitBroken(T *testing.T) { im := buildTestIndexManagerForUnit(t, cb) err := im.Delete(context.Background(), "some-id") - assert.Error(t, err) + test.Error(t, err) }) } @@ -399,7 +399,7 @@ func TestIndexManager_Delete_Unit(T *testing.T) { im := buildTestIndexManagerWithServer(t, server, cb) err := im.Delete(context.Background(), "123") - assert.NoError(t, err) + test.NoError(t, err) }) } @@ -412,7 +412,7 @@ func TestIndexManager_Wipe_Unit(T *testing.T) { im := &indexManager[example]{} err := im.Wipe(context.Background()) - assert.Error(t, err) - assert.Equal(t, "unimplemented", err.Error()) + test.Error(t, err) + test.EqOp(t, "unimplemented", err.Error()) }) } diff --git a/search/text/indexing/do_test.go b/search/text/indexing/do_test.go index 6df1cf7..9683856 100644 --- a/search/text/indexing/do_test.go +++ b/search/text/indexing/do_test.go @@ -14,8 +14,7 @@ import ( "github.com/samber/do/v2" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" otelmetric "go.opentelemetry.io/otel/metric" ) @@ -56,8 +55,8 @@ func TestRegisterIndexScheduler(T *testing.T) { RegisterIndexScheduler(i) scheduler, err := do.Invoke[*IndexScheduler](i) - require.NoError(t, err) - assert.NotNil(t, scheduler) + must.NoError(t, err) + test.NotNil(t, scheduler) test.SliceLen(t, 1, metricsProvider.NewInt64CounterCalls()) test.SliceLen(t, 1, messageQueueProvider.ProvidePublisherCalls()) diff --git a/search/text/indexing/indexer_test.go b/search/text/indexing/indexer_test.go index 41b54bd..726ca9c 100644 --- a/search/text/indexing/indexer_test.go +++ b/search/text/indexing/indexer_test.go @@ -16,8 +16,7 @@ import ( textsearch "github.com/verygoodsoftwarenotvirus/platform/v5/search/text" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -59,10 +58,10 @@ func TestNewIndexScheduler(T *testing.T) { } scheduler, err := NewIndexScheduler(ctx, logger, tracerProvider, metricsProvider, messageQueueProvider, testQueuesConfig, indexFunctions) - assert.NoError(t, err) - assert.NotNil(t, scheduler) - assert.Equal(t, []string{"test_type"}, scheduler.allIndexTypes) - assert.Len(t, scheduler.indexFunctions, 1) + test.NoError(t, err) + test.NotNil(t, scheduler) + test.Eq(t, []string{"test_type"}, scheduler.allIndexTypes) + test.MapLen(t, 1, scheduler.indexFunctions) test.SliceLen(t, 1, metricsProvider.NewInt64CounterCalls()) test.SliceLen(t, 1, messageQueueProvider.ProvidePublisherCalls()) @@ -96,11 +95,11 @@ func TestNewIndexScheduler(T *testing.T) { } scheduler, err := NewIndexScheduler(ctx, logger, tracerProvider, metricsProvider, messageQueueProvider, testQueuesConfig, nil) - assert.NoError(t, err) - assert.NotNil(t, scheduler) - assert.Empty(t, scheduler.allIndexTypes) - assert.NotNil(t, scheduler.indexFunctions) - assert.Len(t, scheduler.indexFunctions, 0) + test.NoError(t, err) + test.NotNil(t, scheduler) + test.SliceEmpty(t, scheduler.allIndexTypes) + test.NotNil(t, scheduler.indexFunctions) + test.MapLen(t, 0, scheduler.indexFunctions) test.SliceLen(t, 1, metricsProvider.NewInt64CounterCalls()) test.SliceLen(t, 1, messageQueueProvider.ProvidePublisherCalls()) @@ -123,9 +122,9 @@ func TestNewIndexScheduler(T *testing.T) { } scheduler, err := NewIndexScheduler(ctx, logger, tracerProvider, metricsProvider, messageQueueProvider, testQueuesConfig, nil) - assert.Error(t, err) - assert.Nil(t, scheduler) - assert.Contains(t, err.Error(), "metrics error") + test.Error(t, err) + test.Nil(t, scheduler) + test.StrContains(t, err.Error(), "metrics error") test.SliceLen(t, 1, metricsProvider.NewInt64CounterCalls()) test.SliceLen(t, 0, messageQueueProvider.ProvidePublisherCalls()) @@ -157,9 +156,9 @@ func TestNewIndexScheduler(T *testing.T) { } scheduler, err := NewIndexScheduler(ctx, logger, tracerProvider, metricsProvider, messageQueueProvider, testQueuesConfig, nil) - assert.Error(t, err) - assert.Nil(t, scheduler) - assert.Contains(t, err.Error(), "message queue error") + test.Error(t, err) + test.Nil(t, scheduler) + test.StrContains(t, err.Error(), "message queue error") test.SliceLen(t, 1, metricsProvider.NewInt64CounterCalls()) test.SliceLen(t, 1, messageQueueProvider.ProvidePublisherCalls()) @@ -191,7 +190,7 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { publisher := &mockpublishers.PublisherMock{ PublishFunc: func(_ context.Context, data any) error { req, ok := data.(*textsearch.IndexRequest) - require.True(t, ok) + must.True(t, ok) test.EqOp(t, "test_type", req.IndexType) return nil }, @@ -210,11 +209,11 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { } scheduler, err := NewIndexScheduler(ctx, logger, tracerProvider, metricsProvider, messageQueueProvider, testQueuesConfig, indexFunctions) - require.NoError(t, err) + must.NoError(t, err) // Since we only have one index type, it will always be chosen err = scheduler.IndexTypes(ctx) - assert.NoError(t, err) + test.NoError(t, err) publishedIDs := collectPublishedRowIDs(t, publisher.PublishCalls()) test.SliceContainsAll(t, publishedIDs, []string{"id1", "id2", "id3"}) @@ -258,12 +257,12 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { } scheduler, err := NewIndexScheduler(ctx, logger, tracerProvider, metricsProvider, messageQueueProvider, testQueuesConfig, indexFunctions) - require.NoError(t, err) + must.NoError(t, err) // No publisher calls expected for empty results // But metrics counter is still called with 0 err = scheduler.IndexTypes(ctx) - assert.NoError(t, err) + test.NoError(t, err) test.SliceLen(t, 0, publisher.PublishCalls()) @@ -306,11 +305,11 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { } scheduler, err := NewIndexScheduler(ctx, logger, tracerProvider, metricsProvider, messageQueueProvider, testQueuesConfig, indexFunctions) - require.NoError(t, err) + must.NoError(t, err) // sql.ErrNoRows should be handled gracefully and return nil err = scheduler.IndexTypes(ctx) - assert.NoError(t, err) + test.NoError(t, err) test.SliceLen(t, 0, publisher.PublishCalls()) }) @@ -349,11 +348,11 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { } scheduler, err := NewIndexScheduler(ctx, logger, tracerProvider, metricsProvider, messageQueueProvider, testQueuesConfig, indexFunctions) - require.NoError(t, err) + must.NoError(t, err) err = scheduler.IndexTypes(ctx) - assert.Error(t, err) - assert.Contains(t, err.Error(), "database connection failed") + test.Error(t, err) + test.StrContains(t, err.Error(), "database connection failed") test.SliceLen(t, 0, publisher.PublishCalls()) }) @@ -386,15 +385,15 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { // Create scheduler with empty index functions scheduler, err := NewIndexScheduler(ctx, logger, tracerProvider, metricsProvider, messageQueueProvider, testQueuesConfig, map[string]Function{}) - require.NoError(t, err) + must.NoError(t, err) // This should not happen in normal operation since random.Element would return empty string // But we can test the error handling by directly calling with a non-existent type scheduler.allIndexTypes = []string{"unknown_type"} err = scheduler.IndexTypes(ctx) - assert.Error(t, err) - assert.Contains(t, err.Error(), "unknown index type unknown_type") + test.Error(t, err) + test.StrContains(t, err.Error(), "unknown index type unknown_type") test.SliceLen(t, 0, publisher.PublishCalls()) }) @@ -426,7 +425,7 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { publisher := &mockpublishers.PublisherMock{ PublishFunc: func(_ context.Context, data any) error { req, ok := data.(*textsearch.IndexRequest) - require.True(t, ok) + must.True(t, ok) return publishResults[req.RowID] }, } @@ -444,10 +443,10 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { } scheduler, err := NewIndexScheduler(ctx, logger, tracerProvider, metricsProvider, messageQueueProvider, testQueuesConfig, indexFunctions) - require.NoError(t, err) + must.NoError(t, err) err = scheduler.IndexTypes(ctx) - assert.NoError(t, err) // Partial failures don't cause the method to return an error + test.NoError(t, err) // Partial failures don't cause the method to return an error publishedIDs := collectPublishedRowIDs(t, publisher.PublishCalls()) test.SliceContainsAll(t, publishedIDs, []string{"id1", "id2", "id3"}) @@ -496,10 +495,10 @@ func TestIndexScheduler_IndexTypes(T *testing.T) { } scheduler, err := NewIndexScheduler(ctx, logger, tracerProvider, metricsProvider, messageQueueProvider, testQueuesConfig, indexFunctions) - require.NoError(t, err) + must.NoError(t, err) err = scheduler.IndexTypes(ctx) - assert.NoError(t, err) // Even all failures don't cause the method to return an error + test.NoError(t, err) // Even all failures don't cause the method to return an error test.SliceLen(t, 2, publisher.PublishCalls()) @@ -519,7 +518,7 @@ func collectPublishedRowIDs(t *testing.T, calls []struct { ids := make([]string, 0, len(calls)) for i := range calls { req, ok := calls[i].Data.(*textsearch.IndexRequest) - require.True(t, ok) + must.True(t, ok) ids = append(ids, req.RowID) } return ids diff --git a/search/text/noop/noop_test.go b/search/text/noop/noop_test.go index c88c2ea..7620503 100644 --- a/search/text/noop/noop_test.go +++ b/search/text/noop/noop_test.go @@ -4,8 +4,8 @@ import ( "context" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestIndexManager_Search(T *testing.T) { @@ -17,9 +17,9 @@ func TestIndexManager_Search(T *testing.T) { m := NewIndexManager[string]() results, err := m.Search(context.Background(), "query") - require.NoError(t, err) - assert.Empty(t, results) - assert.NotNil(t, results) + must.NoError(t, err) + test.SliceEmpty(t, results) + test.NotNil(t, results) }) } @@ -32,7 +32,7 @@ func TestIndexManager_Index(T *testing.T) { m := NewIndexManager[string]() err := m.Index(context.Background(), "id", "value") - assert.NoError(t, err) + test.NoError(t, err) }) } @@ -45,7 +45,7 @@ func TestIndexManager_Delete(T *testing.T) { m := NewIndexManager[string]() err := m.Delete(context.Background(), "id") - assert.NoError(t, err) + test.NoError(t, err) }) } @@ -58,6 +58,6 @@ func TestIndexManager_Wipe(T *testing.T) { m := NewIndexManager[string]() err := m.Wipe(context.Background()) - assert.NoError(t, err) + test.NoError(t, err) }) } diff --git a/search/vector/config/config_test.go b/search/vector/config/config_test.go index 88205ea..d906608 100644 --- a/search/vector/config/config_test.go +++ b/search/vector/config/config_test.go @@ -19,8 +19,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/search/vector/qdrant" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -43,7 +42,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { }, } - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("qdrant provider", func(t *testing.T) { @@ -58,35 +57,35 @@ func TestConfig_ValidateWithContext(T *testing.T) { }, } - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("invalid provider", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: "made-up"} - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("pgvector provider without config", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: PgvectorProvider} - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("qdrant provider without config", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: QdrantProvider} - assert.Error(t, cfg.ValidateWithContext(t.Context())) + test.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("empty provider is valid (defaults to noop)", func(t *testing.T) { t.Parallel() cfg := &Config{} - assert.NoError(t, cfg.ValidateWithContext(t.Context())) + test.NoError(t, cfg.ValidateWithContext(t.Context())) }) } @@ -105,8 +104,8 @@ func TestConfig_ProvideIndex(T *testing.T) { nil, "idx", ) - assert.ErrorIs(t, err, vectorsearch.ErrNilConfig) - assert.Nil(t, idx) + test.ErrorIs(t, err, vectorsearch.ErrNilConfig) + test.Nil(t, idx) }) T.Run("unknown provider returns noop", func(t *testing.T) { @@ -121,9 +120,9 @@ func TestConfig_ProvideIndex(T *testing.T) { nil, "idx", ) - require.NoError(t, err) - require.NotNil(t, idx) - assert.NoError(t, idx.Wipe(t.Context())) + must.NoError(t, err) + must.NotNil(t, idx) + test.NoError(t, idx.Wipe(t.Context())) }) T.Run("empty provider returns noop", func(t *testing.T) { @@ -138,8 +137,8 @@ func TestConfig_ProvideIndex(T *testing.T) { nil, "idx", ) - require.NoError(t, err) - require.NotNil(t, idx) + must.NoError(t, err) + must.NotNil(t, idx) }) T.Run("provider with whitespace returns noop", func(t *testing.T) { @@ -154,8 +153,8 @@ func TestConfig_ProvideIndex(T *testing.T) { nil, "idx", ) - require.NoError(t, err) - require.NotNil(t, idx) + must.NoError(t, err) + must.NotNil(t, idx) }) T.Run("pgvector provider with nil db returns error", func(t *testing.T) { @@ -178,8 +177,8 @@ func TestConfig_ProvideIndex(T *testing.T) { nil, "idx", ) - assert.Error(t, err) - assert.Nil(t, idx) + test.Error(t, err) + test.Nil(t, idx) }) T.Run("qdrant provider via httptest server", func(t *testing.T) { @@ -218,8 +217,8 @@ func TestConfig_ProvideIndex(T *testing.T) { nil, "stub", ) - require.NoError(t, err) - require.NotNil(t, idx) + must.NoError(t, err) + must.NotNil(t, idx) }) T.Run("circuit breaker init failure", func(t *testing.T) { @@ -251,9 +250,9 @@ func TestConfig_ProvideIndex(T *testing.T) { nil, "idx", ) - assert.Error(t, err) - assert.Nil(t, idx) - assert.Contains(t, err.Error(), "circuit breaker") + test.Error(t, err) + test.Nil(t, idx) + test.StrContains(t, err.Error(), "circuit breaker") test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) diff --git a/search/vector/noop/noop_test.go b/search/vector/noop/noop_test.go index b603f5c..88ceaff 100644 --- a/search/vector/noop/noop_test.go +++ b/search/vector/noop/noop_test.go @@ -5,8 +5,8 @@ import ( vectorsearch "github.com/verygoodsoftwarenotvirus/platform/v5/search/vector" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) type example struct { @@ -20,7 +20,7 @@ func TestNewIndex(T *testing.T) { t.Parallel() idx := NewIndex[example]() - assert.NotNil(t, idx) + test.NotNil(t, idx) }) } @@ -31,7 +31,7 @@ func TestIndexManager_Upsert(T *testing.T) { t.Parallel() idx := NewIndex[example]() - require.NoError(t, idx.Upsert(t.Context(), vectorsearch.Vector[example]{ + must.NoError(t, idx.Upsert(t.Context(), vectorsearch.Vector[example]{ ID: "abc", Embedding: []float32{0.1, 0.2, 0.3}, Metadata: &example{Name: "doc"}, @@ -46,7 +46,7 @@ func TestIndexManager_Delete(T *testing.T) { t.Parallel() idx := NewIndex[example]() - require.NoError(t, idx.Delete(t.Context(), "abc", "def")) + must.NoError(t, idx.Delete(t.Context(), "abc", "def")) }) } @@ -57,7 +57,7 @@ func TestIndexManager_Wipe(T *testing.T) { t.Parallel() idx := NewIndex[example]() - require.NoError(t, idx.Wipe(t.Context())) + must.NoError(t, idx.Wipe(t.Context())) }) } @@ -72,7 +72,7 @@ func TestIndexManager_Query(T *testing.T) { Embedding: []float32{0.1, 0.2, 0.3}, TopK: 10, }) - require.NoError(t, err) - assert.Empty(t, results) + must.NoError(t, err) + test.SliceEmpty(t, results) }) } diff --git a/search/vector/pgvector/pgvector_test.go b/search/vector/pgvector/pgvector_test.go index 244e67f..759a47f 100644 --- a/search/vector/pgvector/pgvector_test.go +++ b/search/vector/pgvector/pgvector_test.go @@ -17,8 +17,8 @@ import ( vectorsearch "github.com/verygoodsoftwarenotvirus/platform/v5/search/vector" _ "github.com/jackc/pgx/v5/stdlib" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "github.com/testcontainers/testcontainers-go" postgrescontainer "github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/wait" @@ -80,15 +80,15 @@ func buildContainerBackedPgvector(t *testing.T) (client *testDBClient, shutdown postgrescontainer.WithPassword("vectortest"), testcontainers.WithWaitStrategyAndDeadline(2*time.Minute, wait.ForLog("database system is ready to accept connections").WithOccurrence(2)), ) - require.NoError(t, err) - require.NotNil(t, container) + must.NoError(t, err) + must.NotNil(t, container) connStr, err := container.ConnectionString(ctx, "sslmode=disable") - require.NoError(t, err) + must.NoError(t, err) db, err := sql.Open("pgx", connStr) - require.NoError(t, err) - require.NoError(t, db.PingContext(ctx)) + must.NoError(t, err) + must.NoError(t, db.PingContext(ctx)) return &testDBClient{db: db}, func(ctx context.Context) error { _ = db.Close() @@ -109,8 +109,8 @@ func provideTestIndex(t *testing.T, client database.Client, indexName string, di Metric: distanceMetric, } im, err := ProvideIndex[doc](t.Context(), nil, nil, nil, cfg, client, indexName, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) - require.NotNil(t, im) + must.NoError(t, err) + must.NotNil(t, im) return im } @@ -121,42 +121,42 @@ func TestProvideIndex(T *testing.T) { t.Parallel() _, err := ProvideIndex[doc](t.Context(), nil, nil, nil, nil, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) - require.ErrorIs(t, err, vectorsearch.ErrNilConfig) + must.ErrorIs(t, err, vectorsearch.ErrNilConfig) }) T.Run("nil database", func(t *testing.T) { t.Parallel() _, err := ProvideIndex[doc](t.Context(), nil, nil, nil, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, nil, "idx", cbnoop.NewCircuitBreaker()) - require.ErrorIs(t, err, vectorsearch.ErrNilDatabaseClient) + must.ErrorIs(t, err, vectorsearch.ErrNilDatabaseClient) }) T.Run("invalid dimension", func(t *testing.T) { t.Parallel() _, err := ProvideIndex[doc](t.Context(), nil, nil, nil, &Config{Dimension: 0, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) - require.Error(t, err) + must.Error(t, err) }) T.Run("invalid metric", func(t *testing.T) { t.Parallel() _, err := ProvideIndex[doc](t.Context(), nil, nil, nil, &Config{Dimension: 3, Metric: "weird"}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) - require.Error(t, err) + must.Error(t, err) }) T.Run("invalid index name", func(t *testing.T) { t.Parallel() _, err := ProvideIndex[doc](t.Context(), nil, nil, nil, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "no-dashes", cbnoop.NewCircuitBreaker()) - require.ErrorIs(t, err, ErrInvalidIdentifier) + must.ErrorIs(t, err, ErrInvalidIdentifier) }) T.Run("invalid metadata column", func(t *testing.T) { t.Parallel() _, err := ProvideIndex[doc](t.Context(), nil, nil, nil, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine, MetadataColumn: "weird-col"}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) - require.ErrorIs(t, err, ErrInvalidIdentifier) + must.ErrorIs(t, err, ErrInvalidIdentifier) }) T.Run("error creating upsert counter", func(t *testing.T) { @@ -167,7 +167,7 @@ func TestProvideIndex(T *testing.T) { }) _, err := ProvideIndex[doc](t.Context(), nil, nil, mp, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) - require.Error(t, err) + must.Error(t, err) }) T.Run("error creating delete counter", func(t *testing.T) { @@ -179,7 +179,7 @@ func TestProvideIndex(T *testing.T) { }) _, err := ProvideIndex[doc](t.Context(), nil, nil, mp, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) - require.Error(t, err) + must.Error(t, err) }) T.Run("error creating wipe counter", func(t *testing.T) { @@ -192,7 +192,7 @@ func TestProvideIndex(T *testing.T) { }) _, err := ProvideIndex[doc](t.Context(), nil, nil, mp, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) - require.Error(t, err) + must.Error(t, err) }) T.Run("error creating query counter", func(t *testing.T) { @@ -206,7 +206,7 @@ func TestProvideIndex(T *testing.T) { }) _, err := ProvideIndex[doc](t.Context(), nil, nil, mp, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) - require.Error(t, err) + must.Error(t, err) }) T.Run("error creating error counter", func(t *testing.T) { @@ -221,7 +221,7 @@ func TestProvideIndex(T *testing.T) { }) _, err := ProvideIndex[doc](t.Context(), nil, nil, mp, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) - require.Error(t, err) + must.Error(t, err) }) T.Run("error creating latency histogram", func(t *testing.T) { @@ -237,7 +237,7 @@ func TestProvideIndex(T *testing.T) { } _, err := ProvideIndex[doc](t.Context(), nil, nil, mp, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) - require.Error(t, err) + must.Error(t, err) }) } @@ -248,34 +248,34 @@ func Test_operatorAndOpClass(T *testing.T) { t.Parallel() op, opsClass, err := operatorAndOpClass(vectorsearch.DistanceCosine) - require.NoError(t, err) - assert.Equal(t, "<=>", op) - assert.Equal(t, "vector_cosine_ops", opsClass) + must.NoError(t, err) + test.EqOp(t, "<=>", op) + test.EqOp(t, "vector_cosine_ops", opsClass) }) T.Run("dot product", func(t *testing.T) { t.Parallel() op, opsClass, err := operatorAndOpClass(vectorsearch.DistanceDotProduct) - require.NoError(t, err) - assert.Equal(t, "<#>", op) - assert.Equal(t, "vector_ip_ops", opsClass) + must.NoError(t, err) + test.EqOp(t, "<#>", op) + test.EqOp(t, "vector_ip_ops", opsClass) }) T.Run("euclidean", func(t *testing.T) { t.Parallel() op, opsClass, err := operatorAndOpClass(vectorsearch.DistanceEuclidean) - require.NoError(t, err) - assert.Equal(t, "<->", op) - assert.Equal(t, "vector_l2_ops", opsClass) + must.NoError(t, err) + test.EqOp(t, "<->", op) + test.EqOp(t, "vector_l2_ops", opsClass) }) T.Run("invalid metric", func(t *testing.T) { t.Parallel() _, _, err := operatorAndOpClass("bogus") - require.ErrorIs(t, err, vectorsearch.ErrInvalidMetric) + must.ErrorIs(t, err, vectorsearch.ErrInvalidMetric) }) } @@ -285,19 +285,19 @@ func TestEncodeVector(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.Equal(t, "[0.1,0.2,0.3]", encodeVector([]float32{0.1, 0.2, 0.3})) + test.EqOp(t, "[0.1,0.2,0.3]", encodeVector([]float32{0.1, 0.2, 0.3})) }) T.Run("empty", func(t *testing.T) { t.Parallel() - assert.Equal(t, "[]", encodeVector(nil)) + test.EqOp(t, "[]", encodeVector(nil)) }) T.Run("integer-valued", func(t *testing.T) { t.Parallel() - assert.Equal(t, "[1,2,3]", encodeVector([]float32{1, 2, 3})) + test.EqOp(t, "[1,2,3]", encodeVector([]float32{1, 2, 3})) }) } @@ -306,12 +306,12 @@ func TestQuoteIdent(T *testing.T) { T.Run("simple", func(t *testing.T) { t.Parallel() - assert.Equal(t, `"users"`, quoteIdent("users")) + test.EqOp(t, `"users"`, quoteIdent("users")) }) T.Run("with embedded quote", func(t *testing.T) { t.Parallel() - assert.Equal(t, `"foo""bar"`, quoteIdent(`foo"bar`)) + test.EqOp(t, `"foo""bar"`, quoteIdent(`foo"bar`)) }) } @@ -320,12 +320,12 @@ func TestPgTextArray(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.Equal(t, `{"a","b","c"}`, pgTextArray([]string{"a", "b", "c"})) + test.EqOp(t, `{"a","b","c"}`, pgTextArray([]string{"a", "b", "c"})) }) T.Run("with quotes", func(t *testing.T) { t.Parallel() - assert.Equal(t, `{"a\"b","c"}`, pgTextArray([]string{`a"b`, "c"})) + test.EqOp(t, `{"a\"b","c"}`, pgTextArray([]string{`a"b`, "c"})) }) } @@ -335,44 +335,44 @@ func TestMarshalUnmarshalMetadata(T *testing.T) { T.Run("nil round-trip", func(t *testing.T) { t.Parallel() raw, err := marshalMetadata[doc](nil) - require.NoError(t, err) - assert.Equal(t, []byte(`{}`), raw) + must.NoError(t, err) + test.Eq(t, []byte(`{}`), raw) out, err := unmarshalMetadata[doc](raw) - require.NoError(t, err) - require.NotNil(t, out) + must.NoError(t, err) + must.NotNil(t, out) }) T.Run("populated round-trip", func(t *testing.T) { t.Parallel() original := &doc{Kind: "doc", Title: "hello"} raw, err := marshalMetadata(original) - require.NoError(t, err) + must.NoError(t, err) out, err := unmarshalMetadata[doc](raw) - require.NoError(t, err) - require.NotNil(t, out) - assert.Equal(t, *original, *out) + must.NoError(t, err) + must.NotNil(t, out) + test.Eq(t, *original, *out) }) T.Run("null is treated as nil", func(t *testing.T) { t.Parallel() out, err := unmarshalMetadata[doc]([]byte("null")) - require.NoError(t, err) - assert.Nil(t, out) + must.NoError(t, err) + test.Nil(t, out) }) T.Run("empty is treated as nil", func(t *testing.T) { t.Parallel() out, err := unmarshalMetadata[doc]([]byte{}) - require.NoError(t, err) - assert.Nil(t, out) + must.NoError(t, err) + test.Nil(t, out) }) T.Run("invalid JSON returns error", func(t *testing.T) { t.Parallel() _, err := unmarshalMetadata[doc]([]byte(`{not json`)) - require.Error(t, err) + must.Error(t, err) }) } @@ -381,22 +381,22 @@ func Test_firstWords(T *testing.T) { T.Run("multi-word statement", func(t *testing.T) { t.Parallel() - assert.Equal(t, "CREATE EXTENSION", firstWords("CREATE EXTENSION IF NOT EXISTS vector")) + test.EqOp(t, "CREATE EXTENSION", firstWords("CREATE EXTENSION IF NOT EXISTS vector")) }) T.Run("single word", func(t *testing.T) { t.Parallel() - assert.Equal(t, "TRUNCATE", firstWords("TRUNCATE")) + test.EqOp(t, "TRUNCATE", firstWords("TRUNCATE")) }) T.Run("two words only", func(t *testing.T) { t.Parallel() - assert.Equal(t, "DROP TABLE", firstWords("DROP TABLE")) + test.EqOp(t, "DROP TABLE", firstWords("DROP TABLE")) }) T.Run("leading whitespace is trimmed", func(t *testing.T) { t.Parallel() - assert.Equal(t, "CREATE TABLE", firstWords(" CREATE TABLE foo")) + test.EqOp(t, "CREATE TABLE", firstWords(" CREATE TABLE foo")) }) } @@ -407,28 +407,28 @@ func TestValidateWithContext(T *testing.T) { t.Parallel() var cfg *Config err := cfg.ValidateWithContext(t.Context()) - require.Error(t, err) + must.Error(t, err) }) T.Run("valid config", func(t *testing.T) { t.Parallel() cfg := &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine} err := cfg.ValidateWithContext(t.Context()) - require.NoError(t, err) + must.NoError(t, err) }) T.Run("missing dimension", func(t *testing.T) { t.Parallel() cfg := &Config{Metric: vectorsearch.DistanceCosine} err := cfg.ValidateWithContext(t.Context()) - require.Error(t, err) + must.Error(t, err) }) T.Run("invalid metric", func(t *testing.T) { t.Parallel() cfg := &Config{Dimension: 3, Metric: "bogus"} err := cfg.ValidateWithContext(t.Context()) - require.Error(t, err) + must.Error(t, err) }) } @@ -449,7 +449,7 @@ func TestPgvectorIndex_Container(T *testing.T) { ctx := t.Context() idx := provideTestIndex(t, client, "rt_"+identifiers.New(), 3, vectorsearch.DistanceCosine) - require.NoError(t, idx.Upsert(ctx, + must.NoError(t, idx.Upsert(ctx, vectorsearch.Vector[doc]{ID: "a", Embedding: []float32{1, 0, 0}, Metadata: &doc{Kind: "doc", Title: "alpha"}}, vectorsearch.Vector[doc]{ID: "b", Embedding: []float32{0, 1, 0}, Metadata: &doc{Kind: "doc", Title: "beta"}}, vectorsearch.Vector[doc]{ID: "c", Embedding: []float32{0, 0, 1}, Metadata: &doc{Kind: "doc", Title: "gamma"}}, @@ -459,12 +459,12 @@ func TestPgvectorIndex_Container(T *testing.T) { Embedding: []float32{1, 0, 0}, TopK: 3, }) - require.NoError(t, err) - require.Len(t, results, 3) - assert.Equal(t, "a", results[0].ID) - require.NotNil(t, results[0].Metadata) - assert.Equal(t, "alpha", results[0].Metadata.Title) - assert.InDelta(t, 0.0, results[0].Distance, 1e-5) + must.NoError(t, err) + must.SliceLen(t, 3, results) + test.EqOp(t, "a", results[0].ID) + must.NotNil(t, results[0].Metadata) + test.EqOp(t, "alpha", results[0].Metadata.Title) + test.InDelta(t, float32(0.0), results[0].Distance, float32(1e-5)) }) T.Run("Upsert updates existing row", func(t *testing.T) { @@ -472,18 +472,18 @@ func TestPgvectorIndex_Container(T *testing.T) { ctx := t.Context() idx := provideTestIndex(t, client, "upd_"+identifiers.New(), 3, vectorsearch.DistanceCosine) - require.NoError(t, idx.Upsert(ctx, vectorsearch.Vector[doc]{ID: "x", Embedding: []float32{1, 0, 0}, Metadata: &doc{Title: "first"}})) - require.NoError(t, idx.Upsert(ctx, vectorsearch.Vector[doc]{ID: "x", Embedding: []float32{0, 1, 0}, Metadata: &doc{Title: "second"}})) + must.NoError(t, idx.Upsert(ctx, vectorsearch.Vector[doc]{ID: "x", Embedding: []float32{1, 0, 0}, Metadata: &doc{Title: "first"}})) + must.NoError(t, idx.Upsert(ctx, vectorsearch.Vector[doc]{ID: "x", Embedding: []float32{0, 1, 0}, Metadata: &doc{Title: "second"}})) results, err := idx.Query(ctx, vectorsearch.QueryRequest{ Embedding: []float32{0, 1, 0}, TopK: 1, }) - require.NoError(t, err) - require.Len(t, results, 1) - assert.Equal(t, "x", results[0].ID) - require.NotNil(t, results[0].Metadata) - assert.Equal(t, "second", results[0].Metadata.Title) + must.NoError(t, err) + must.SliceLen(t, 1, results) + test.EqOp(t, "x", results[0].ID) + must.NotNil(t, results[0].Metadata) + test.EqOp(t, "second", results[0].Metadata.Title) }) T.Run("TopK is respected", func(t *testing.T) { @@ -491,7 +491,7 @@ func TestPgvectorIndex_Container(T *testing.T) { ctx := t.Context() idx := provideTestIndex(t, client, "topk_"+identifiers.New(), 3, vectorsearch.DistanceCosine) - require.NoError(t, idx.Upsert(ctx, + must.NoError(t, idx.Upsert(ctx, vectorsearch.Vector[doc]{ID: "a", Embedding: []float32{1, 0, 0}}, vectorsearch.Vector[doc]{ID: "b", Embedding: []float32{0, 1, 0}}, vectorsearch.Vector[doc]{ID: "c", Embedding: []float32{0, 0, 1}}, @@ -501,8 +501,8 @@ func TestPgvectorIndex_Container(T *testing.T) { Embedding: []float32{1, 0, 0}, TopK: 2, }) - require.NoError(t, err) - assert.Len(t, results, 2) + must.NoError(t, err) + test.SliceLen(t, 2, results) }) T.Run("filter clause is applied", func(t *testing.T) { @@ -510,7 +510,7 @@ func TestPgvectorIndex_Container(T *testing.T) { ctx := t.Context() idx := provideTestIndex(t, client, "filt_"+identifiers.New(), 3, vectorsearch.DistanceCosine) - require.NoError(t, idx.Upsert(ctx, + must.NoError(t, idx.Upsert(ctx, vectorsearch.Vector[doc]{ID: "a", Embedding: []float32{1, 0, 0}, Metadata: &doc{Kind: "doc"}}, vectorsearch.Vector[doc]{ID: "b", Embedding: []float32{1, 0, 0}, Metadata: &doc{Kind: "image"}}, )) @@ -520,9 +520,9 @@ func TestPgvectorIndex_Container(T *testing.T) { TopK: 10, Filter: "metadata->>'kind' = 'doc'", }) - require.NoError(t, err) - require.Len(t, results, 1) - assert.Equal(t, "a", results[0].ID) + must.NoError(t, err) + must.SliceLen(t, 1, results) + test.EqOp(t, "a", results[0].ID) }) T.Run("Query rejects empty embedding", func(t *testing.T) { @@ -531,7 +531,7 @@ func TestPgvectorIndex_Container(T *testing.T) { idx := provideTestIndex(t, client, "emb_"+identifiers.New(), 3, vectorsearch.DistanceCosine) _, err := idx.Query(ctx, vectorsearch.QueryRequest{Embedding: nil, TopK: 5}) - require.ErrorIs(t, err, vectorsearch.ErrEmptyEmbedding) + must.ErrorIs(t, err, vectorsearch.ErrEmptyEmbedding) }) T.Run("Query rejects wrong dimension", func(t *testing.T) { @@ -540,7 +540,7 @@ func TestPgvectorIndex_Container(T *testing.T) { idx := provideTestIndex(t, client, "dim_"+identifiers.New(), 3, vectorsearch.DistanceCosine) _, err := idx.Query(ctx, vectorsearch.QueryRequest{Embedding: []float32{1, 0}, TopK: 5}) - require.ErrorIs(t, err, vectorsearch.ErrDimensionMismatch) + must.ErrorIs(t, err, vectorsearch.ErrDimensionMismatch) }) T.Run("Upsert rejects wrong dimension", func(t *testing.T) { @@ -549,7 +549,7 @@ func TestPgvectorIndex_Container(T *testing.T) { idx := provideTestIndex(t, client, "udim_"+identifiers.New(), 3, vectorsearch.DistanceCosine) err := idx.Upsert(ctx, vectorsearch.Vector[doc]{ID: "a", Embedding: []float32{1, 0}}) - require.ErrorIs(t, err, vectorsearch.ErrDimensionMismatch) + must.ErrorIs(t, err, vectorsearch.ErrDimensionMismatch) }) T.Run("Delete removes specific rows", func(t *testing.T) { @@ -557,17 +557,17 @@ func TestPgvectorIndex_Container(T *testing.T) { ctx := t.Context() idx := provideTestIndex(t, client, "del_"+identifiers.New(), 3, vectorsearch.DistanceCosine) - require.NoError(t, idx.Upsert(ctx, + must.NoError(t, idx.Upsert(ctx, vectorsearch.Vector[doc]{ID: "a", Embedding: []float32{1, 0, 0}}, vectorsearch.Vector[doc]{ID: "b", Embedding: []float32{0, 1, 0}}, vectorsearch.Vector[doc]{ID: "c", Embedding: []float32{0, 0, 1}}, )) - require.NoError(t, idx.Delete(ctx, "a", "c")) + must.NoError(t, idx.Delete(ctx, "a", "c")) results, err := idx.Query(ctx, vectorsearch.QueryRequest{Embedding: []float32{0, 1, 0}, TopK: 10}) - require.NoError(t, err) - require.Len(t, results, 1) - assert.Equal(t, "b", results[0].ID) + must.NoError(t, err) + must.SliceLen(t, 1, results) + test.EqOp(t, "b", results[0].ID) }) T.Run("Wipe empties the index", func(t *testing.T) { @@ -575,15 +575,15 @@ func TestPgvectorIndex_Container(T *testing.T) { ctx := t.Context() idx := provideTestIndex(t, client, "wipe_"+identifiers.New(), 3, vectorsearch.DistanceCosine) - require.NoError(t, idx.Upsert(ctx, + must.NoError(t, idx.Upsert(ctx, vectorsearch.Vector[doc]{ID: "a", Embedding: []float32{1, 0, 0}}, vectorsearch.Vector[doc]{ID: "b", Embedding: []float32{0, 1, 0}}, )) - require.NoError(t, idx.Wipe(ctx)) + must.NoError(t, idx.Wipe(ctx)) results, err := idx.Query(ctx, vectorsearch.QueryRequest{Embedding: []float32{1, 0, 0}, TopK: 10}) - require.NoError(t, err) - assert.Empty(t, results) + must.NoError(t, err) + test.SliceEmpty(t, results) }) T.Run("ProvideIndex is idempotent for the same index", func(t *testing.T) { @@ -592,14 +592,14 @@ func TestPgvectorIndex_Container(T *testing.T) { name := "idem_" + identifiers.New() idx1 := provideTestIndex(t, client, name, 3, vectorsearch.DistanceCosine) idx2 := provideTestIndex(t, client, name, 3, vectorsearch.DistanceCosine) - assert.NotNil(t, idx1) - assert.NotNil(t, idx2) + test.NotNil(t, idx1) + test.NotNil(t, idx2) - require.NoError(t, idx1.Upsert(ctx, vectorsearch.Vector[doc]{ID: "shared", Embedding: []float32{1, 0, 0}})) + must.NoError(t, idx1.Upsert(ctx, vectorsearch.Vector[doc]{ID: "shared", Embedding: []float32{1, 0, 0}})) results, err := idx2.Query(ctx, vectorsearch.QueryRequest{Embedding: []float32{1, 0, 0}, TopK: 1}) - require.NoError(t, err) - require.Len(t, results, 1) - assert.Equal(t, "shared", results[0].ID) + must.NoError(t, err) + must.SliceLen(t, 1, results) + test.EqOp(t, "shared", results[0].ID) }) } diff --git a/search/vector/qdrant/qdrant_test.go b/search/vector/qdrant/qdrant_test.go index b583394..338ef34 100644 --- a/search/vector/qdrant/qdrant_test.go +++ b/search/vector/qdrant/qdrant_test.go @@ -19,8 +19,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/identifiers" vectorsearch "github.com/verygoodsoftwarenotvirus/platform/v5/search/vector" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" ) @@ -136,7 +136,7 @@ func buildStubIndex(t *testing.T, stub *qdrantStub, cb circuitbreaking.CircuitBr "test", cb, ) - require.NoError(t, err) + must.NoError(t, err) return idx.(*indexManager[doc]) } @@ -158,15 +158,15 @@ func TestMetricToDistance(T *testing.T) { T.Run(string(c.metric), func(t *testing.T) { t.Parallel() got, err := metricToDistance(c.metric) - require.NoError(t, err) - assert.Equal(t, c.want, got) + must.NoError(t, err) + test.EqOp(t, c.want, got) }) } T.Run("invalid", func(t *testing.T) { t.Parallel() _, err := metricToDistance("nonsense") - require.ErrorIs(t, err, vectorsearch.ErrInvalidMetric) + must.ErrorIs(t, err, vectorsearch.ErrInvalidMetric) }) } @@ -176,28 +176,28 @@ func TestStringifyID(T *testing.T) { T.Run("string", func(t *testing.T) { t.Parallel() s, err := stringifyID("abc") - require.NoError(t, err) - assert.Equal(t, "abc", s) + must.NoError(t, err) + test.EqOp(t, "abc", s) }) T.Run("float", func(t *testing.T) { t.Parallel() s, err := stringifyID(float64(42)) - require.NoError(t, err) - assert.Equal(t, "42", s) + must.NoError(t, err) + test.EqOp(t, "42", s) }) T.Run("number", func(t *testing.T) { t.Parallel() s, err := stringifyID(json.Number("17")) - require.NoError(t, err) - assert.Equal(t, "17", s) + must.NoError(t, err) + test.EqOp(t, "17", s) }) T.Run("unsupported", func(t *testing.T) { t.Parallel() _, err := stringifyID(true) - require.Error(t, err) + must.Error(t, err) }) } @@ -207,30 +207,30 @@ func TestUnmarshalPayload(T *testing.T) { T.Run("nil round-trip", func(t *testing.T) { t.Parallel() out, err := unmarshalPayload[doc](nil) - require.NoError(t, err) - assert.Nil(t, out) + must.NoError(t, err) + test.Nil(t, out) }) T.Run("populated", func(t *testing.T) { t.Parallel() out, err := unmarshalPayload[doc](json.RawMessage(`{"kind":"doc","title":"hi"}`)) - require.NoError(t, err) - require.NotNil(t, out) - assert.Equal(t, "doc", out.Kind) - assert.Equal(t, "hi", out.Title) + must.NoError(t, err) + must.NotNil(t, out) + test.EqOp(t, "doc", out.Kind) + test.EqOp(t, "hi", out.Title) }) T.Run("null", func(t *testing.T) { t.Parallel() out, err := unmarshalPayload[doc](json.RawMessage("null")) - require.NoError(t, err) - assert.Nil(t, out) + must.NoError(t, err) + test.Nil(t, out) }) T.Run("invalid JSON", func(t *testing.T) { t.Parallel() _, err := unmarshalPayload[doc](json.RawMessage(`{not valid`)) - require.Error(t, err) + must.Error(t, err) }) } @@ -239,15 +239,15 @@ func TestPayloadFromMetadata(T *testing.T) { T.Run("nil metadata", func(t *testing.T) { t.Parallel() - assert.Nil(t, payloadFromMetadata[doc](nil)) + test.Nil(t, payloadFromMetadata[doc](nil)) }) T.Run("non-nil metadata", func(t *testing.T) { t.Parallel() d := &doc{Kind: "k", Title: "t"} result := payloadFromMetadata(d) - require.NotNil(t, result) - assert.Equal(t, d, result) + must.NotNil(t, result) + test.Eq[any](t, d, result) }) } @@ -257,9 +257,9 @@ func TestWrapStatusError(T *testing.T) { T.Run("wraps ErrUnexpectedStatus", func(t *testing.T) { t.Parallel() err := wrapStatusError(500, []byte("internal error")) - require.ErrorIs(t, err, ErrUnexpectedStatus) - assert.Contains(t, err.Error(), "500") - assert.Contains(t, err.Error(), "internal error") + must.ErrorIs(t, err, ErrUnexpectedStatus) + test.StrContains(t, err.Error(), "500") + test.StrContains(t, err.Error(), "internal error") }) } @@ -269,19 +269,19 @@ func TestCollectionPath(T *testing.T) { T.Run("no suffix", func(t *testing.T) { t.Parallel() im := &indexManager[doc]{baseURL: "http://localhost:6333", collection: "my_col"} - assert.Equal(t, "http://localhost:6333/collections/my_col", im.collectionPath("")) + test.EqOp(t, "http://localhost:6333/collections/my_col", im.collectionPath("")) }) T.Run("with suffix", func(t *testing.T) { t.Parallel() im := &indexManager[doc]{baseURL: "http://localhost:6333", collection: "my_col"} - assert.Equal(t, "http://localhost:6333/collections/my_col/points?wait=true", im.collectionPath("/points?wait=true")) + test.EqOp(t, "http://localhost:6333/collections/my_col/points?wait=true", im.collectionPath("/points?wait=true")) }) T.Run("collection name is URL-escaped", func(t *testing.T) { t.Parallel() im := &indexManager[doc]{baseURL: "http://localhost:6333", collection: "has space"} - assert.Equal(t, "http://localhost:6333/collections/has%20space", im.collectionPath("")) + test.EqOp(t, "http://localhost:6333/collections/has%20space", im.collectionPath("")) }) } @@ -291,7 +291,7 @@ func TestProvideIndex(T *testing.T) { T.Run("nil config", func(t *testing.T) { t.Parallel() _, err := ProvideIndex[doc](t.Context(), nil, nil, nil, nil, "test", cbnoop.NewCircuitBreaker()) - require.ErrorIs(t, err, vectorsearch.ErrNilConfig) + must.ErrorIs(t, err, vectorsearch.ErrNilConfig) }) T.Run("empty collection", func(t *testing.T) { @@ -301,7 +301,7 @@ func TestProvideIndex(T *testing.T) { Dimension: 3, Metric: vectorsearch.DistanceCosine, }, "", cbnoop.NewCircuitBreaker()) - require.Error(t, err) + must.Error(t, err) }) T.Run("invalid metric", func(t *testing.T) { @@ -311,7 +311,7 @@ func TestProvideIndex(T *testing.T) { Dimension: 3, Metric: "weird", }, "test", cbnoop.NewCircuitBreaker()) - require.Error(t, err) + must.Error(t, err) }) T.Run("invalid dimension", func(t *testing.T) { @@ -321,7 +321,7 @@ func TestProvideIndex(T *testing.T) { Dimension: 0, Metric: vectorsearch.DistanceCosine, }, "test", cbnoop.NewCircuitBreaker()) - require.Error(t, err) + must.Error(t, err) }) T.Run("invalid config missing base URL", func(t *testing.T) { @@ -330,7 +330,7 @@ func TestProvideIndex(T *testing.T) { Dimension: 3, Metric: vectorsearch.DistanceCosine, }, "test", cbnoop.NewCircuitBreaker()) - require.Error(t, err) + must.Error(t, err) }) T.Run("collection already exists", func(t *testing.T) { @@ -346,8 +346,8 @@ func TestProvideIndex(T *testing.T) { "test", cbnoop.NewCircuitBreaker(), ) - require.NoError(t, err) - require.NotNil(t, idx) + must.NoError(t, err) + must.NotNil(t, idx) }) T.Run("ensureCollection GET fails with unexpected status", func(t *testing.T) { @@ -362,7 +362,7 @@ func TestProvideIndex(T *testing.T) { "test", cbnoop.NewCircuitBreaker(), ) - require.Error(t, err) + must.Error(t, err) }) T.Run("ensureCollection PUT fails", func(t *testing.T) { @@ -380,7 +380,7 @@ func TestProvideIndex(T *testing.T) { "test", cbnoop.NewCircuitBreaker(), ) - require.Error(t, err) + must.Error(t, err) }) T.Run("ensureCollection unreachable server", func(t *testing.T) { @@ -391,7 +391,7 @@ func TestProvideIndex(T *testing.T) { "test", cbnoop.NewCircuitBreaker(), ) - require.Error(t, err) + must.Error(t, err) }) T.Run("default timeout when zero", func(t *testing.T) { @@ -406,8 +406,8 @@ func TestProvideIndex(T *testing.T) { "test", cbnoop.NewCircuitBreaker(), ) - require.NoError(t, err) - require.NotNil(t, idx) + must.NoError(t, err) + must.NotNil(t, idx) }) T.Run("sets api key header", func(t *testing.T) { @@ -426,8 +426,8 @@ func TestProvideIndex(T *testing.T) { "test", cbnoop.NewCircuitBreaker(), ) - require.NoError(t, err) - assert.Equal(t, "secret", gotAPIKey) + must.NoError(t, err) + test.EqOp(t, "secret", gotAPIKey) }) } @@ -464,16 +464,16 @@ func TestProvideIndex_StubsCollectionCreation(T *testing.T) { "stub", cbnoop.NewCircuitBreaker(), ) - require.NoError(T, err) - require.NotNil(T, idx) + must.NoError(T, err) + must.NotNil(T, idx) - assert.Equal(T, http.MethodPut, gotMethod) - assert.True(T, strings.HasSuffix(gotPath, "/collections/stub")) - require.NotNil(T, gotBody) + test.EqOp(T, http.MethodPut, gotMethod) + test.True(T, strings.HasSuffix(gotPath, "/collections/stub")) + must.NotNil(T, gotBody) vectors, ok := gotBody["vectors"].(map[string]any) - require.True(T, ok) - assert.Equal(T, float64(3), vectors["size"]) - assert.Equal(T, "Cosine", vectors["distance"]) + must.True(T, ok) + test.Eq[any](T, float64(3), vectors["size"]) + test.Eq[any](T, "Cosine", vectors["distance"]) } func TestUpsert(T *testing.T) { @@ -482,7 +482,7 @@ func TestUpsert(T *testing.T) { T.Run("empty vectors is a no-op", func(t *testing.T) { t.Parallel() idx := buildStubIndex(t, &qdrantStub{}, nil) - require.NoError(t, idx.Upsert(t.Context())) + must.NoError(t, idx.Upsert(t.Context())) }) T.Run("circuit breaker broken", func(t *testing.T) { @@ -494,30 +494,30 @@ func TestUpsert(T *testing.T) { idx := buildStubIndex(t, &qdrantStub{}, cb) err := idx.Upsert(t.Context(), vectorsearch.Vector[doc]{ID: "a", Embedding: []float32{1, 0, 0}}) - require.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + must.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - require.Len(t, cb.CannotProceedCalls(), 1) + must.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("rejects empty ID", func(t *testing.T) { t.Parallel() idx := buildStubIndex(t, &qdrantStub{}, nil) err := idx.Upsert(t.Context(), vectorsearch.Vector[doc]{ID: "", Embedding: []float32{1, 0, 0}}) - require.ErrorIs(t, err, platformerrors.ErrInvalidIDProvided) + must.ErrorIs(t, err, platformerrors.ErrInvalidIDProvided) }) T.Run("rejects empty embedding", func(t *testing.T) { t.Parallel() idx := buildStubIndex(t, &qdrantStub{}, nil) err := idx.Upsert(t.Context(), vectorsearch.Vector[doc]{ID: "a", Embedding: nil}) - require.ErrorIs(t, err, vectorsearch.ErrEmptyEmbedding) + must.ErrorIs(t, err, vectorsearch.ErrEmptyEmbedding) }) T.Run("rejects wrong dimension", func(t *testing.T) { t.Parallel() idx := buildStubIndex(t, &qdrantStub{}, nil) err := idx.Upsert(t.Context(), vectorsearch.Vector[doc]{ID: "a", Embedding: []float32{1, 0}}) - require.ErrorIs(t, err, vectorsearch.ErrDimensionMismatch) + must.ErrorIs(t, err, vectorsearch.ErrDimensionMismatch) }) T.Run("successful upsert", func(t *testing.T) { @@ -527,15 +527,15 @@ func TestUpsert(T *testing.T) { vectorsearch.Vector[doc]{ID: "a", Embedding: []float32{1, 0, 0}, Metadata: &doc{Kind: "doc", Title: "alpha"}}, vectorsearch.Vector[doc]{ID: "b", Embedding: []float32{0, 1, 0}}, ) - require.NoError(t, err) + must.NoError(t, err) }) T.Run("server returns error status", func(t *testing.T) { t.Parallel() idx := buildStubIndex(t, &qdrantStub{pointsPutStatus: http.StatusInternalServerError}, nil) err := idx.Upsert(t.Context(), vectorsearch.Vector[doc]{ID: "a", Embedding: []float32{1, 0, 0}}) - require.Error(t, err) - require.ErrorIs(t, err, ErrUnexpectedStatus) + must.Error(t, err) + must.ErrorIs(t, err, ErrUnexpectedStatus) }) T.Run("unreachable server", func(t *testing.T) { @@ -548,12 +548,12 @@ func TestUpsert(T *testing.T) { "test", cbnoop.NewCircuitBreaker(), ) - require.NoError(t, err) + must.NoError(t, err) // Close the server to simulate unreachable srv.Close() err = idx.Upsert(t.Context(), vectorsearch.Vector[doc]{ID: "a", Embedding: []float32{1, 0, 0}}) - require.Error(t, err) + must.Error(t, err) }) } @@ -563,7 +563,7 @@ func TestDelete(T *testing.T) { T.Run("empty ids is a no-op", func(t *testing.T) { t.Parallel() idx := buildStubIndex(t, &qdrantStub{}, nil) - require.NoError(t, idx.Delete(t.Context())) + must.NoError(t, idx.Delete(t.Context())) }) T.Run("circuit breaker broken", func(t *testing.T) { @@ -575,23 +575,23 @@ func TestDelete(T *testing.T) { idx := buildStubIndex(t, &qdrantStub{}, cb) err := idx.Delete(t.Context(), "some-id") - require.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + must.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - require.Len(t, cb.CannotProceedCalls(), 1) + must.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("successful delete", func(t *testing.T) { t.Parallel() idx := buildStubIndex(t, &qdrantStub{}, nil) - require.NoError(t, idx.Delete(t.Context(), "id1", "id2")) + must.NoError(t, idx.Delete(t.Context(), "id1", "id2")) }) T.Run("server returns error status", func(t *testing.T) { t.Parallel() idx := buildStubIndex(t, &qdrantStub{pointsDeleteStatus: http.StatusInternalServerError}, nil) err := idx.Delete(t.Context(), "id1") - require.Error(t, err) - require.ErrorIs(t, err, ErrUnexpectedStatus) + must.Error(t, err) + must.ErrorIs(t, err, ErrUnexpectedStatus) }) T.Run("unreachable server", func(t *testing.T) { @@ -604,11 +604,11 @@ func TestDelete(T *testing.T) { "test", cbnoop.NewCircuitBreaker(), ) - require.NoError(t, err) + must.NoError(t, err) srv.Close() err = idx.Delete(t.Context(), "id1") - require.Error(t, err) + must.Error(t, err) }) } @@ -624,29 +624,29 @@ func TestWipe(T *testing.T) { idx := buildStubIndex(t, &qdrantStub{}, cb) err := idx.Wipe(t.Context()) - require.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + must.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - require.Len(t, cb.CannotProceedCalls(), 1) + must.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("successful wipe", func(t *testing.T) { t.Parallel() idx := buildStubIndex(t, &qdrantStub{}, nil) - require.NoError(t, idx.Wipe(t.Context())) + must.NoError(t, idx.Wipe(t.Context())) }) T.Run("delete collection fails", func(t *testing.T) { t.Parallel() idx := buildStubIndex(t, &qdrantStub{collectionDeleteStatus: http.StatusForbidden}, nil) err := idx.Wipe(t.Context()) - require.Error(t, err) - require.ErrorIs(t, err, ErrUnexpectedStatus) + must.Error(t, err) + must.ErrorIs(t, err, ErrUnexpectedStatus) }) T.Run("delete returns 404 still succeeds", func(t *testing.T) { t.Parallel() idx := buildStubIndex(t, &qdrantStub{collectionDeleteStatus: http.StatusNotFound}, nil) - require.NoError(t, idx.Wipe(t.Context())) + must.NoError(t, idx.Wipe(t.Context())) }) T.Run("recreate collection fails after delete", func(t *testing.T) { @@ -684,10 +684,10 @@ func TestWipe(T *testing.T) { "test", cbnoop.NewCircuitBreaker(), ) - require.NoError(t, err) + must.NoError(t, err) err = idx.Wipe(t.Context()) - require.Error(t, err) + must.Error(t, err) }) T.Run("unreachable server", func(t *testing.T) { @@ -700,11 +700,11 @@ func TestWipe(T *testing.T) { "test", cbnoop.NewCircuitBreaker(), ) - require.NoError(t, err) + must.NoError(t, err) srv.Close() err = idx.Wipe(t.Context()) - require.Error(t, err) + must.Error(t, err) }) } @@ -715,14 +715,14 @@ func TestQuery(T *testing.T) { t.Parallel() idx := buildStubIndex(t, &qdrantStub{}, nil) _, err := idx.Query(t.Context(), vectorsearch.QueryRequest{Embedding: nil, TopK: 5}) - require.ErrorIs(t, err, vectorsearch.ErrEmptyEmbedding) + must.ErrorIs(t, err, vectorsearch.ErrEmptyEmbedding) }) T.Run("rejects wrong dimension", func(t *testing.T) { t.Parallel() idx := buildStubIndex(t, &qdrantStub{}, nil) _, err := idx.Query(t.Context(), vectorsearch.QueryRequest{Embedding: []float32{1, 0}, TopK: 5}) - require.ErrorIs(t, err, vectorsearch.ErrDimensionMismatch) + must.ErrorIs(t, err, vectorsearch.ErrDimensionMismatch) }) T.Run("circuit breaker broken", func(t *testing.T) { @@ -734,9 +734,9 @@ func TestQuery(T *testing.T) { idx := buildStubIndex(t, &qdrantStub{}, cb) _, err := idx.Query(t.Context(), vectorsearch.QueryRequest{Embedding: []float32{1, 0, 0}, TopK: 5}) - require.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + must.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - require.Len(t, cb.CannotProceedCalls(), 1) + must.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("defaults TopK to 10", func(t *testing.T) { @@ -761,12 +761,12 @@ func TestQuery(T *testing.T) { "test", cbnoop.NewCircuitBreaker(), ) - require.NoError(t, err) + must.NoError(t, err) _, err = idx.Query(t.Context(), vectorsearch.QueryRequest{Embedding: []float32{1, 0, 0}, TopK: 0}) - require.NoError(t, err) - require.NotNil(t, gotBody) - assert.Equal(t, float64(10), gotBody["limit"]) + must.NoError(t, err) + must.NotNil(t, gotBody) + test.Eq[any](t, float64(10), gotBody["limit"]) }) T.Run("successful query returns results", func(t *testing.T) { @@ -775,17 +775,17 @@ func TestQuery(T *testing.T) { idx := buildStubIndex(t, &qdrantStub{pointsSearchBody: searchResp}, nil) results, err := idx.Query(t.Context(), vectorsearch.QueryRequest{Embedding: []float32{1, 0, 0}, TopK: 5}) - require.NoError(t, err) - require.Len(t, results, 2) + must.NoError(t, err) + must.SliceLen(t, 2, results) - assert.Equal(t, "abc", results[0].ID) - assert.InDelta(t, 0.95, float64(results[0].Distance), 0.001) - require.NotNil(t, results[0].Metadata) - assert.Equal(t, "doc", results[0].Metadata.Kind) - assert.Equal(t, "hello", results[0].Metadata.Title) + test.EqOp(t, "abc", results[0].ID) + test.InDelta(t, 0.95, float64(results[0].Distance), 0.001) + must.NotNil(t, results[0].Metadata) + test.EqOp(t, "doc", results[0].Metadata.Kind) + test.EqOp(t, "hello", results[0].Metadata.Title) - assert.Equal(t, "def", results[1].ID) - assert.Nil(t, results[1].Metadata) + test.EqOp(t, "def", results[1].ID) + test.Nil(t, results[1].Metadata) }) T.Run("query with numeric ID", func(t *testing.T) { @@ -794,9 +794,9 @@ func TestQuery(T *testing.T) { idx := buildStubIndex(t, &qdrantStub{pointsSearchBody: searchResp}, nil) results, err := idx.Query(t.Context(), vectorsearch.QueryRequest{Embedding: []float32{1, 0, 0}, TopK: 1}) - require.NoError(t, err) - require.Len(t, results, 1) - assert.Equal(t, "42", results[0].ID) + must.NoError(t, err) + must.SliceLen(t, 1, results) + test.EqOp(t, "42", results[0].ID) }) T.Run("query with filter", func(t *testing.T) { @@ -820,28 +820,28 @@ func TestQuery(T *testing.T) { "test", cbnoop.NewCircuitBreaker(), ) - require.NoError(t, err) + must.NoError(t, err) filter := map[string]any{"must": []any{map[string]any{"key": "kind", "match": map[string]any{"value": "doc"}}}} _, err = idx.Query(t.Context(), vectorsearch.QueryRequest{Embedding: []float32{1, 0, 0}, TopK: 5, Filter: filter}) - require.NoError(t, err) - require.NotNil(t, gotBody) - assert.NotNil(t, gotBody["filter"]) + must.NoError(t, err) + must.NotNil(t, gotBody) + test.NotNil(t, gotBody["filter"]) }) T.Run("server returns error status", func(t *testing.T) { t.Parallel() idx := buildStubIndex(t, &qdrantStub{pointsSearchStatus: http.StatusInternalServerError}, nil) _, err := idx.Query(t.Context(), vectorsearch.QueryRequest{Embedding: []float32{1, 0, 0}, TopK: 5}) - require.Error(t, err) - require.ErrorIs(t, err, ErrUnexpectedStatus) + must.Error(t, err) + must.ErrorIs(t, err, ErrUnexpectedStatus) }) T.Run("invalid JSON response", func(t *testing.T) { t.Parallel() idx := buildStubIndex(t, &qdrantStub{pointsSearchBody: `{not json`}, nil) _, err := idx.Query(t.Context(), vectorsearch.QueryRequest{Embedding: []float32{1, 0, 0}, TopK: 5}) - require.Error(t, err) + must.Error(t, err) }) T.Run("invalid payload in response", func(t *testing.T) { @@ -850,7 +850,7 @@ func TestQuery(T *testing.T) { searchResp := `{"result":[{"id":"x","score":0.5,"payload":"not-a-doc"}]}` idx := buildStubIndex(t, &qdrantStub{pointsSearchBody: searchResp}, nil) _, err := idx.Query(t.Context(), vectorsearch.QueryRequest{Embedding: []float32{1, 0, 0}, TopK: 5}) - require.Error(t, err) + must.Error(t, err) }) T.Run("unsupported ID type in response", func(t *testing.T) { @@ -858,7 +858,7 @@ func TestQuery(T *testing.T) { searchResp := `{"result":[{"id":true,"score":0.5,"payload":null}]}` idx := buildStubIndex(t, &qdrantStub{pointsSearchBody: searchResp}, nil) _, err := idx.Query(t.Context(), vectorsearch.QueryRequest{Embedding: []float32{1, 0, 0}, TopK: 5}) - require.Error(t, err) + must.Error(t, err) }) T.Run("unreachable server", func(t *testing.T) { @@ -871,11 +871,11 @@ func TestQuery(T *testing.T) { "test", cbnoop.NewCircuitBreaker(), ) - require.NoError(t, err) + must.NoError(t, err) srv.Close() _, err = idx.Query(t.Context(), vectorsearch.QueryRequest{Embedding: []float32{1, 0, 0}, TopK: 5}) - require.Error(t, err) + must.Error(t, err) }) } @@ -886,7 +886,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { t.Parallel() var cfg *Config err := cfg.ValidateWithContext(t.Context()) - require.ErrorIs(t, err, platformerrors.ErrNilInputParameter) + must.ErrorIs(t, err, platformerrors.ErrNilInputParameter) }) T.Run("valid config", func(t *testing.T) { @@ -896,7 +896,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Dimension: 128, Metric: vectorsearch.DistanceCosine, } - require.NoError(t, cfg.ValidateWithContext(t.Context())) + must.NoError(t, cfg.ValidateWithContext(t.Context())) }) T.Run("missing base URL", func(t *testing.T) { @@ -905,7 +905,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Dimension: 128, Metric: vectorsearch.DistanceCosine, } - require.Error(t, cfg.ValidateWithContext(t.Context())) + must.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("missing dimension", func(t *testing.T) { @@ -914,7 +914,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { BaseURL: "http://localhost:6333", Metric: vectorsearch.DistanceCosine, } - require.Error(t, cfg.ValidateWithContext(t.Context())) + must.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("missing metric", func(t *testing.T) { @@ -923,7 +923,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { BaseURL: "http://localhost:6333", Dimension: 128, } - require.Error(t, cfg.ValidateWithContext(t.Context())) + must.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("invalid metric value", func(t *testing.T) { @@ -933,7 +933,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Dimension: 128, Metric: "invalid", } - require.Error(t, cfg.ValidateWithContext(t.Context())) + must.Error(t, cfg.ValidateWithContext(t.Context())) }) T.Run("all valid metrics pass", func(t *testing.T) { @@ -948,7 +948,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Dimension: 128, Metric: m, } - require.NoError(t, cfg.ValidateWithContext(t.Context()), "metric %q should be valid", m) + must.NoError(t, cfg.ValidateWithContext(t.Context()), must.Sprintf("metric %q should be valid", m)) } }) } @@ -968,13 +968,13 @@ func buildContainerBackedQdrant(t *testing.T) (cfg *Config, shutdown func(contex Started: true, } container, err := testcontainers.GenericContainer(ctx, req) - require.NoError(t, err) - require.NotNil(t, container) + must.NoError(t, err) + must.NotNil(t, container) host, err := container.Host(ctx) - require.NoError(t, err) + must.NoError(t, err) port, err := container.MappedPort(ctx, "6333/tcp") - require.NoError(t, err) + must.NoError(t, err) cfg = &Config{ BaseURL: "http://" + net.JoinHostPort(host, port.Port()), @@ -998,7 +998,7 @@ func TestQdrantIndex_Container(T *testing.T) { provide := func(t *testing.T, name string) vectorsearch.Index[doc] { t.Helper() idx, err := ProvideIndex[doc](t.Context(), nil, nil, nil, cfg, name, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) + must.NoError(t, err) return idx } @@ -1007,18 +1007,18 @@ func TestQdrantIndex_Container(T *testing.T) { ctx := t.Context() idx := provide(t, "rt_"+identifiers.New()) - require.NoError(t, idx.Upsert(ctx, + must.NoError(t, idx.Upsert(ctx, vectorsearch.Vector[doc]{ID: "11111111-1111-1111-1111-111111111111", Embedding: []float32{1, 0, 0}, Metadata: &doc{Kind: "doc", Title: "alpha"}}, vectorsearch.Vector[doc]{ID: "22222222-2222-2222-2222-222222222222", Embedding: []float32{0, 1, 0}, Metadata: &doc{Kind: "doc", Title: "beta"}}, vectorsearch.Vector[doc]{ID: "33333333-3333-3333-3333-333333333333", Embedding: []float32{0, 0, 1}, Metadata: &doc{Kind: "doc", Title: "gamma"}}, )) results, err := idx.Query(ctx, vectorsearch.QueryRequest{Embedding: []float32{1, 0, 0}, TopK: 3}) - require.NoError(t, err) - require.Len(t, results, 3) - assert.Equal(t, "11111111-1111-1111-1111-111111111111", results[0].ID) - require.NotNil(t, results[0].Metadata) - assert.Equal(t, "alpha", results[0].Metadata.Title) + must.NoError(t, err) + must.SliceLen(t, 3, results) + test.EqOp(t, "11111111-1111-1111-1111-111111111111", results[0].ID) + must.NotNil(t, results[0].Metadata) + test.EqOp(t, "alpha", results[0].Metadata.Title) }) T.Run("TopK is respected", func(t *testing.T) { @@ -1026,15 +1026,15 @@ func TestQdrantIndex_Container(T *testing.T) { ctx := t.Context() idx := provide(t, "topk_"+identifiers.New()) - require.NoError(t, idx.Upsert(ctx, + must.NoError(t, idx.Upsert(ctx, vectorsearch.Vector[doc]{ID: "11111111-aaaa-aaaa-aaaa-111111111111", Embedding: []float32{1, 0, 0}}, vectorsearch.Vector[doc]{ID: "22222222-aaaa-aaaa-aaaa-222222222222", Embedding: []float32{0, 1, 0}}, vectorsearch.Vector[doc]{ID: "33333333-aaaa-aaaa-aaaa-333333333333", Embedding: []float32{0, 0, 1}}, )) results, err := idx.Query(ctx, vectorsearch.QueryRequest{Embedding: []float32{1, 0, 0}, TopK: 2}) - require.NoError(t, err) - assert.Len(t, results, 2) + must.NoError(t, err) + test.SliceLen(t, 2, results) }) T.Run("filter is applied", func(t *testing.T) { @@ -1042,7 +1042,7 @@ func TestQdrantIndex_Container(T *testing.T) { ctx := t.Context() idx := provide(t, "filt_"+identifiers.New()) - require.NoError(t, idx.Upsert(ctx, + must.NoError(t, idx.Upsert(ctx, vectorsearch.Vector[doc]{ID: "11111111-bbbb-bbbb-bbbb-111111111111", Embedding: []float32{1, 0, 0}, Metadata: &doc{Kind: "doc"}}, vectorsearch.Vector[doc]{ID: "22222222-bbbb-bbbb-bbbb-222222222222", Embedding: []float32{1, 0, 0}, Metadata: &doc{Kind: "image"}}, )) @@ -1056,10 +1056,10 @@ func TestQdrantIndex_Container(T *testing.T) { }, } results, err := idx.Query(ctx, vectorsearch.QueryRequest{Embedding: []float32{1, 0, 0}, TopK: 10, Filter: filter}) - require.NoError(t, err) - require.Len(t, results, 1) - require.NotNil(t, results[0].Metadata) - assert.Equal(t, "doc", results[0].Metadata.Kind) + must.NoError(t, err) + must.SliceLen(t, 1, results) + must.NotNil(t, results[0].Metadata) + test.EqOp(t, "doc", results[0].Metadata.Kind) }) T.Run("Query rejects empty embedding", func(t *testing.T) { @@ -1068,7 +1068,7 @@ func TestQdrantIndex_Container(T *testing.T) { idx := provide(t, "emb_"+identifiers.New()) _, err := idx.Query(ctx, vectorsearch.QueryRequest{Embedding: nil, TopK: 5}) - require.ErrorIs(t, err, vectorsearch.ErrEmptyEmbedding) + must.ErrorIs(t, err, vectorsearch.ErrEmptyEmbedding) }) T.Run("Query rejects wrong dimension", func(t *testing.T) { @@ -1077,7 +1077,7 @@ func TestQdrantIndex_Container(T *testing.T) { idx := provide(t, "dim_"+identifiers.New()) _, err := idx.Query(ctx, vectorsearch.QueryRequest{Embedding: []float32{1, 0}, TopK: 5}) - require.ErrorIs(t, err, vectorsearch.ErrDimensionMismatch) + must.ErrorIs(t, err, vectorsearch.ErrDimensionMismatch) }) T.Run("Delete removes specific points", func(t *testing.T) { @@ -1085,17 +1085,17 @@ func TestQdrantIndex_Container(T *testing.T) { ctx := t.Context() idx := provide(t, "del_"+identifiers.New()) - require.NoError(t, idx.Upsert(ctx, + must.NoError(t, idx.Upsert(ctx, vectorsearch.Vector[doc]{ID: "11111111-cccc-cccc-cccc-111111111111", Embedding: []float32{1, 0, 0}}, vectorsearch.Vector[doc]{ID: "22222222-cccc-cccc-cccc-222222222222", Embedding: []float32{0, 1, 0}}, )) - require.NoError(t, idx.Delete(ctx, "11111111-cccc-cccc-cccc-111111111111")) + must.NoError(t, idx.Delete(ctx, "11111111-cccc-cccc-cccc-111111111111")) results, err := idx.Query(ctx, vectorsearch.QueryRequest{Embedding: []float32{0, 1, 0}, TopK: 10}) - require.NoError(t, err) - require.Len(t, results, 1) - assert.Equal(t, "22222222-cccc-cccc-cccc-222222222222", results[0].ID) + must.NoError(t, err) + must.SliceLen(t, 1, results) + test.EqOp(t, "22222222-cccc-cccc-cccc-222222222222", results[0].ID) }) T.Run("Wipe drops and recreates", func(t *testing.T) { @@ -1103,17 +1103,17 @@ func TestQdrantIndex_Container(T *testing.T) { ctx := t.Context() idx := provide(t, "wipe_"+identifiers.New()) - require.NoError(t, idx.Upsert(ctx, + must.NoError(t, idx.Upsert(ctx, vectorsearch.Vector[doc]{ID: "11111111-dddd-dddd-dddd-111111111111", Embedding: []float32{1, 0, 0}}, )) - require.NoError(t, idx.Wipe(ctx)) + must.NoError(t, idx.Wipe(ctx)) results, err := idx.Query(ctx, vectorsearch.QueryRequest{Embedding: []float32{1, 0, 0}, TopK: 10}) - require.NoError(t, err) - assert.Empty(t, results) + must.NoError(t, err) + test.SliceEmpty(t, results) // Confirm the collection still accepts writes after wipe. - require.NoError(t, idx.Upsert(ctx, + must.NoError(t, idx.Upsert(ctx, vectorsearch.Vector[doc]{ID: "22222222-dddd-dddd-dddd-222222222222", Embedding: []float32{1, 0, 0}}, )) }) @@ -1123,14 +1123,14 @@ func TestQdrantIndex_Container(T *testing.T) { ctx := t.Context() name := "idem_" + identifiers.New() idx1, err := ProvideIndex[doc](ctx, nil, nil, nil, cfg, name, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) + must.NoError(t, err) idx2, err := ProvideIndex[doc](ctx, nil, nil, nil, cfg, name, cbnoop.NewCircuitBreaker()) - require.NoError(t, err) + must.NoError(t, err) - require.NoError(t, idx1.Upsert(ctx, vectorsearch.Vector[doc]{ID: "11111111-eeee-eeee-eeee-111111111111", Embedding: []float32{1, 0, 0}})) + must.NoError(t, idx1.Upsert(ctx, vectorsearch.Vector[doc]{ID: "11111111-eeee-eeee-eeee-111111111111", Embedding: []float32{1, 0, 0}})) results, err := idx2.Query(ctx, vectorsearch.QueryRequest{Embedding: []float32{1, 0, 0}, TopK: 1}) - require.NoError(t, err) - require.Len(t, results, 1) + must.NoError(t, err) + must.SliceLen(t, 1, results) }) } diff --git a/secrets/config/config_test.go b/secrets/config/config_test.go index eae819f..e28d078 100644 --- a/secrets/config/config_test.go +++ b/secrets/config/config_test.go @@ -17,8 +17,7 @@ import ( awsssm "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/aws/aws-sdk-go-v2/service/ssm/types" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -62,55 +61,55 @@ func TestConfig_ValidateWithContext(T *testing.T) { T.Run("valid env provider", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderEnv} - require.NoError(t, cfg.ValidateWithContext(context.Background())) + must.NoError(t, cfg.ValidateWithContext(context.Background())) }) T.Run("valid noop provider", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderNoop} - require.NoError(t, cfg.ValidateWithContext(context.Background())) + must.NoError(t, cfg.ValidateWithContext(context.Background())) }) T.Run("valid gcp provider", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderGCP, GCP: &gcp.Config{ProjectID: "my-project"}} - require.NoError(t, cfg.ValidateWithContext(context.Background())) + must.NoError(t, cfg.ValidateWithContext(context.Background())) }) T.Run("invalid gcp provider missing config", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderGCP} - require.Error(t, cfg.ValidateWithContext(context.Background())) + must.Error(t, cfg.ValidateWithContext(context.Background())) }) T.Run("valid ssm provider", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderSSM, SSM: &ssm.Config{Region: "us-east-1"}} - require.NoError(t, cfg.ValidateWithContext(context.Background())) + must.NoError(t, cfg.ValidateWithContext(context.Background())) }) T.Run("invalid ssm provider missing config", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderSSM} - require.Error(t, cfg.ValidateWithContext(context.Background())) + must.Error(t, cfg.ValidateWithContext(context.Background())) }) T.Run("valid kubectl provider", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderKubectl, Kubectl: &kubectl.Config{Namespace: "default"}} - require.NoError(t, cfg.ValidateWithContext(context.Background())) + must.NoError(t, cfg.ValidateWithContext(context.Background())) }) T.Run("invalid kubectl provider missing config", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: ProviderKubectl} - require.Error(t, cfg.ValidateWithContext(context.Background())) + must.Error(t, cfg.ValidateWithContext(context.Background())) }) T.Run("unknown provider", func(t *testing.T) { t.Parallel() cfg := &Config{Provider: "vault"} - require.Error(t, cfg.ValidateWithContext(context.Background())) + must.Error(t, cfg.ValidateWithContext(context.Background())) }) } @@ -122,17 +121,17 @@ func TestConfig_ProvideSecretSource(T *testing.T) { var cfg *Config source, err := cfg.ProvideSecretSource(context.Background(), nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, source) + must.NoError(t, err) + must.NotNil(t, source) key := "TEST_NIL_CONFIG_" + t.Name() value := "from-env" - require.NoError(t, os.Setenv(key, value)) + must.NoError(t, os.Setenv(key, value)) t.Cleanup(func() { _ = os.Unsetenv(key) }) got, err := source.GetSecret(context.Background(), key) - require.NoError(t, err) - assert.Equal(t, value, got) + must.NoError(t, err) + test.EqOp(t, value, got) }) T.Run("empty provider returns env source", func(t *testing.T) { @@ -140,17 +139,17 @@ func TestConfig_ProvideSecretSource(T *testing.T) { cfg := &Config{Provider: ""} source, err := cfg.ProvideSecretSource(context.Background(), nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, source) + must.NoError(t, err) + must.NotNil(t, source) key := "TEST_EMPTY_PROVIDER_" + t.Name() value := "from-env" - require.NoError(t, os.Setenv(key, value)) + must.NoError(t, os.Setenv(key, value)) t.Cleanup(func() { _ = os.Unsetenv(key) }) got, err := source.GetSecret(context.Background(), key) - require.NoError(t, err) - assert.Equal(t, value, got) + must.NoError(t, err) + test.EqOp(t, value, got) }) T.Run("env provider returns env source", func(t *testing.T) { @@ -158,17 +157,17 @@ func TestConfig_ProvideSecretSource(T *testing.T) { cfg := &Config{Provider: ProviderEnv} source, err := cfg.ProvideSecretSource(context.Background(), nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, source) + must.NoError(t, err) + must.NotNil(t, source) key := "TEST_ENV_PROVIDER_" + t.Name() value := "from-env" - require.NoError(t, os.Setenv(key, value)) + must.NoError(t, os.Setenv(key, value)) t.Cleanup(func() { _ = os.Unsetenv(key) }) got, err := source.GetSecret(context.Background(), key) - require.NoError(t, err) - assert.Equal(t, value, got) + must.NoError(t, err) + test.EqOp(t, value, got) }) T.Run("noop provider returns noop source", func(t *testing.T) { @@ -176,12 +175,12 @@ func TestConfig_ProvideSecretSource(T *testing.T) { cfg := &Config{Provider: ProviderNoop} source, err := cfg.ProvideSecretSource(context.Background(), nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, source) + must.NoError(t, err) + must.NotNil(t, source) got, err := source.GetSecret(context.Background(), "any") - require.NoError(t, err) - assert.Empty(t, got) + must.NoError(t, err) + test.EqOp(t, "", got) }) T.Run("gcp provider with mock client", func(t *testing.T) { @@ -193,12 +192,12 @@ func TestConfig_ProvideSecretSource(T *testing.T) { GCPClient: &mockGCPClient{value: "gcp-secret-value"}, } source, err := cfg.ProvideSecretSource(context.Background(), nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, source) + must.NoError(t, err) + must.NotNil(t, source) got, err := source.GetSecret(context.Background(), "MY_SECRET") - require.NoError(t, err) - assert.Equal(t, "gcp-secret-value", got) + must.NoError(t, err) + test.EqOp(t, "gcp-secret-value", got) }) T.Run("ssm provider with mock client", func(t *testing.T) { @@ -210,12 +209,12 @@ func TestConfig_ProvideSecretSource(T *testing.T) { SSMClient: &mockSSMClient{value: "ssm-param-value"}, } source, err := cfg.ProvideSecretSource(context.Background(), nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, source) + must.NoError(t, err) + must.NotNil(t, source) got, err := source.GetSecret(context.Background(), "MY_PARAM") - require.NoError(t, err) - assert.Equal(t, "ssm-param-value", got) + must.NoError(t, err) + test.EqOp(t, "ssm-param-value", got) }) T.Run("kubectl provider with mock client", func(t *testing.T) { @@ -233,12 +232,12 @@ func TestConfig_ProvideSecretSource(T *testing.T) { }, } source, err := cfg.ProvideSecretSource(context.Background(), nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, source) + must.NoError(t, err) + must.NotNil(t, source) got, err := source.GetSecret(context.Background(), "my-secret/password") - require.NoError(t, err) - assert.Equal(t, "k8s-secret-value", got) + must.NoError(t, err) + test.EqOp(t, "k8s-secret-value", got) }) T.Run("unknown provider returns error", func(t *testing.T) { @@ -246,9 +245,9 @@ func TestConfig_ProvideSecretSource(T *testing.T) { cfg := &Config{Provider: "vault"} source, err := cfg.ProvideSecretSource(context.Background(), nil, nil, nil) - require.Error(t, err) - assert.Nil(t, source) - assert.Contains(t, err.Error(), "unknown") + must.Error(t, err) + test.Nil(t, source) + test.StrContains(t, err.Error(), "unknown") }) T.Run("gcp provider with nil gcp config returns error", func(t *testing.T) { @@ -256,9 +255,9 @@ func TestConfig_ProvideSecretSource(T *testing.T) { cfg := &Config{Provider: ProviderGCP} source, err := cfg.ProvideSecretSource(context.Background(), nil, nil, nil) - require.Error(t, err) - assert.Nil(t, source) - assert.Contains(t, err.Error(), "gcp") + must.Error(t, err) + test.Nil(t, source) + test.StrContains(t, err.Error(), "gcp") }) T.Run("ssm provider with nil ssm config returns error", func(t *testing.T) { @@ -266,9 +265,9 @@ func TestConfig_ProvideSecretSource(T *testing.T) { cfg := &Config{Provider: ProviderSSM} source, err := cfg.ProvideSecretSource(context.Background(), nil, nil, nil) - require.Error(t, err) - assert.Nil(t, source) - assert.Contains(t, err.Error(), "ssm") + must.Error(t, err) + test.Nil(t, source) + test.StrContains(t, err.Error(), "ssm") }) T.Run("kubectl provider with nil kubectl config returns error", func(t *testing.T) { @@ -276,9 +275,9 @@ func TestConfig_ProvideSecretSource(T *testing.T) { cfg := &Config{Provider: ProviderKubectl} source, err := cfg.ProvideSecretSource(context.Background(), nil, nil, nil) - require.Error(t, err) - assert.Nil(t, source) - assert.Contains(t, err.Error(), "kubectl") + must.Error(t, err) + test.Nil(t, source) + test.StrContains(t, err.Error(), "kubectl") }) T.Run("nil config with metrics error", func(t *testing.T) { @@ -292,8 +291,8 @@ func TestConfig_ProvideSecretSource(T *testing.T) { var cfg *Config source, err := cfg.ProvideSecretSource(context.Background(), nil, nil, mp) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -309,8 +308,8 @@ func TestConfig_ProvideSecretSource(T *testing.T) { cfg := &Config{Provider: ProviderEnv} source, err := cfg.ProvideSecretSource(context.Background(), nil, nil, mp) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -330,8 +329,8 @@ func TestConfig_ProvideSecretSource(T *testing.T) { GCPClient: &mockGCPClient{value: "x"}, } source, err := cfg.ProvideSecretSource(context.Background(), nil, nil, mp) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -351,8 +350,8 @@ func TestConfig_ProvideSecretSource(T *testing.T) { SSMClient: &mockSSMClient{value: "x"}, } source, err := cfg.ProvideSecretSource(context.Background(), nil, nil, mp) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -372,8 +371,8 @@ func TestConfig_ProvideSecretSource(T *testing.T) { KubectlClient: &mockKubectlClient{secret: &corev1.Secret{}}, } source, err := cfg.ProvideSecretSource(context.Background(), nil, nil, mp) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) diff --git a/secrets/config/do_test.go b/secrets/config/do_test.go index 1a083b1..167e020 100644 --- a/secrets/config/do_test.go +++ b/secrets/config/do_test.go @@ -9,8 +9,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/secrets" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterSecretSource(T *testing.T) { @@ -29,7 +29,7 @@ func TestRegisterSecretSource(T *testing.T) { RegisterSecretSource(i) source, err := do.Invoke[secrets.SecretSource](i) - require.NoError(t, err) - assert.NotNil(t, source) + must.NoError(t, err) + test.NotNil(t, source) }) } diff --git a/secrets/config/wire_test.go b/secrets/config/wire_test.go index ead392c..5f0848b 100644 --- a/secrets/config/wire_test.go +++ b/secrets/config/wire_test.go @@ -5,7 +5,7 @@ import ( "os" "testing" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" ) func TestProvideSecretSourceFromConfig(T *testing.T) { @@ -16,17 +16,17 @@ func TestProvideSecretSourceFromConfig(T *testing.T) { var cfg *Config source, err := ProvideSecretSourceFromConfig(context.Background(), cfg, nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, source) + must.NoError(t, err) + must.NotNil(t, source) key := "TEST_WIRE_NIL_" + t.Name() value := "from-env" - require.NoError(t, os.Setenv(key, value)) + must.NoError(t, os.Setenv(key, value)) t.Cleanup(func() { _ = os.Unsetenv(key) }) got, err := source.GetSecret(context.Background(), key) - require.NoError(t, err) - require.Equal(t, value, got) + must.NoError(t, err) + must.EqOp(t, value, got) }) T.Run("empty provider returns env source", func(t *testing.T) { @@ -34,17 +34,17 @@ func TestProvideSecretSourceFromConfig(T *testing.T) { cfg := &Config{Provider: ""} source, err := ProvideSecretSourceFromConfig(context.Background(), cfg, nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, source) + must.NoError(t, err) + must.NotNil(t, source) key := "TEST_WIRE_EMPTY_" + t.Name() value := "from-env" - require.NoError(t, os.Setenv(key, value)) + must.NoError(t, os.Setenv(key, value)) t.Cleanup(func() { _ = os.Unsetenv(key) }) got, err := source.GetSecret(context.Background(), key) - require.NoError(t, err) - require.Equal(t, value, got) + must.NoError(t, err) + must.EqOp(t, value, got) }) T.Run("noop provider returns noop source", func(t *testing.T) { @@ -52,12 +52,12 @@ func TestProvideSecretSourceFromConfig(T *testing.T) { cfg := &Config{Provider: ProviderNoop} source, err := ProvideSecretSourceFromConfig(context.Background(), cfg, nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, source) + must.NoError(t, err) + must.NotNil(t, source) got, err := source.GetSecret(context.Background(), "any") - require.NoError(t, err) - require.Empty(t, got) + must.NoError(t, err) + must.EqOp(t, "", got) }) T.Run("env provider returns env source", func(t *testing.T) { @@ -65,17 +65,17 @@ func TestProvideSecretSourceFromConfig(T *testing.T) { cfg := &Config{Provider: ProviderEnv} source, err := ProvideSecretSourceFromConfig(context.Background(), cfg, nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, source) + must.NoError(t, err) + must.NotNil(t, source) key := "TEST_WIRE_ENV_" + t.Name() value := "from-env" - require.NoError(t, os.Setenv(key, value)) + must.NoError(t, os.Setenv(key, value)) t.Cleanup(func() { _ = os.Unsetenv(key) }) got, err := source.GetSecret(context.Background(), key) - require.NoError(t, err) - require.Equal(t, value, got) + must.NoError(t, err) + must.EqOp(t, value, got) }) T.Run("provider error is wrapped", func(t *testing.T) { @@ -83,8 +83,8 @@ func TestProvideSecretSourceFromConfig(T *testing.T) { cfg := &Config{Provider: "vault"} source, err := ProvideSecretSourceFromConfig(context.Background(), cfg, nil, nil, nil) - require.Error(t, err) - require.Nil(t, source) - require.Contains(t, err.Error(), "provide secret source") + must.Error(t, err) + must.Nil(t, source) + must.StrContains(t, err.Error(), "provide secret source") }) } diff --git a/secrets/env/env_test.go b/secrets/env/env_test.go index d7cbdee..969dfcf 100644 --- a/secrets/env/env_test.go +++ b/secrets/env/env_test.go @@ -11,8 +11,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/secrets" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -32,8 +31,8 @@ func TestNewEnvSecretSource(T *testing.T) { } source, err := NewEnvSecretSource(nil, nil, mp) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -43,7 +42,7 @@ func TestNewEnvSecretSource(T *testing.T) { noopMP := metrics.NewNoopMetricsProvider() h, histErr := noopMP.NewFloat64Histogram("test") - require.NoError(t, histErr) + must.NoError(t, histErr) mp := &mockmetrics.ProviderMock{ NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { @@ -56,8 +55,8 @@ func TestNewEnvSecretSource(T *testing.T) { } source, err := NewEnvSecretSource(nil, nil, mp) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) @@ -72,31 +71,31 @@ func TestEnvSecretSource_GetSecret(T *testing.T) { key := "TEST_SECRET_" + t.Name() value := "secret-value" - require.NoError(t, os.Setenv(key, value)) + must.NoError(t, os.Setenv(key, value)) t.Cleanup(func() { _ = os.Unsetenv(key) }) source, err := NewEnvSecretSource(nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) ctx := context.Background() got, err := source.GetSecret(ctx, key) - require.NoError(t, err) - assert.Equal(t, value, got) + must.NoError(t, err) + test.EqOp(t, value, got) }) T.Run("returns empty for unset env var", func(t *testing.T) { t.Parallel() key := "TEST_SECRET_UNSET_" + t.Name() - require.NoError(t, os.Unsetenv(key)) + must.NoError(t, os.Unsetenv(key)) source, err := NewEnvSecretSource(nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) ctx := context.Background() got, err := source.GetSecret(ctx, key) - require.NoError(t, err) - assert.Empty(t, got) + must.NoError(t, err) + test.EqOp(t, "", got) }) } @@ -107,9 +106,9 @@ func TestEnvSecretSource_Close(T *testing.T) { t.Parallel() source, err := NewEnvSecretSource(nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) err = source.Close() - require.NoError(t, err) + must.NoError(t, err) }) } diff --git a/secrets/gcp/config_test.go b/secrets/gcp/config_test.go index 37f612b..564c6a0 100644 --- a/secrets/gcp/config_test.go +++ b/secrets/gcp/config_test.go @@ -4,7 +4,7 @@ import ( "context" "testing" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -13,12 +13,12 @@ func TestConfig_ValidateWithContext(T *testing.T) { T.Run("valid", func(t *testing.T) { t.Parallel() cfg := &Config{ProjectID: "my-project"} - require.NoError(t, cfg.ValidateWithContext(context.Background())) + must.NoError(t, cfg.ValidateWithContext(context.Background())) }) T.Run("invalid missing ProjectID", func(t *testing.T) { t.Parallel() cfg := &Config{ProjectID: ""} - require.Error(t, cfg.ValidateWithContext(context.Background())) + must.Error(t, cfg.ValidateWithContext(context.Background())) }) } diff --git a/secrets/gcp/gcp_test.go b/secrets/gcp/gcp_test.go index a2f0be0..348d5dd 100644 --- a/secrets/gcp/gcp_test.go +++ b/secrets/gcp/gcp_test.go @@ -10,8 +10,7 @@ import ( "cloud.google.com/go/secretmanager/apiv1/secretmanagerpb" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -21,17 +20,17 @@ func TestNewGCPSecretSource(T *testing.T) { T.Run("nil config returns error", func(t *testing.T) { t.Parallel() source, err := NewGCPSecretSource(context.Background(), nil, nil, nil, nil, nil) - require.Error(t, err) - assert.Nil(t, source) - assert.Contains(t, err.Error(), "config is required") + must.Error(t, err) + test.Nil(t, source) + test.StrContains(t, err.Error(), "config is required") }) T.Run("missing ProjectID returns error", func(t *testing.T) { t.Parallel() cfg := &Config{ProjectID: ""} source, err := NewGCPSecretSource(context.Background(), cfg, nil, nil, nil, nil) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) }) T.Run("with mock client succeeds", func(t *testing.T) { @@ -39,8 +38,8 @@ func TestNewGCPSecretSource(T *testing.T) { cfg := &Config{ProjectID: "test-project"} mc := &mockGCPClient{value: "secret-value"} source, err := NewGCPSecretSource(context.Background(), cfg, mc, nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, source) + must.NoError(t, err) + must.NotNil(t, source) defer source.Close() }) @@ -56,8 +55,8 @@ func TestNewGCPSecretSource(T *testing.T) { cfg := &Config{ProjectID: "test-project"} source, err := NewGCPSecretSource(context.Background(), cfg, &mockGCPClient{}, nil, nil, mp) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -80,8 +79,8 @@ func TestNewGCPSecretSource(T *testing.T) { cfg := &Config{ProjectID: "test-project"} source, err := NewGCPSecretSource(context.Background(), cfg, &mockGCPClient{}, nil, nil, mp) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) @@ -91,7 +90,7 @@ func TestNewGCPSecretSource(T *testing.T) { noopMP := metrics.NewNoopMetricsProvider() h, histErr := noopMP.NewFloat64Histogram("test") - require.NoError(t, histErr) + must.NoError(t, histErr) mp := &mockmetrics.ProviderMock{ NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { @@ -105,8 +104,8 @@ func TestNewGCPSecretSource(T *testing.T) { cfg := &Config{ProjectID: "test-project"} source, err := NewGCPSecretSource(context.Background(), cfg, &mockGCPClient{}, nil, nil, mp) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) test.SliceLen(t, 2, mp.NewInt64CounterCalls()) test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) @@ -121,12 +120,12 @@ func TestGCPSecretSource_GetSecret(T *testing.T) { cfg := &Config{ProjectID: "test-project"} mc := &mockGCPClient{value: "my-secret-value"} source, err := NewGCPSecretSource(context.Background(), cfg, mc, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) defer source.Close() got, err := source.GetSecret(context.Background(), "MY_SECRET") - require.NoError(t, err) - assert.Equal(t, "my-secret-value", got) + must.NoError(t, err) + test.EqOp(t, "my-secret-value", got) }) T.Run("error from client", func(t *testing.T) { @@ -134,12 +133,12 @@ func TestGCPSecretSource_GetSecret(T *testing.T) { cfg := &Config{ProjectID: "test-project"} mc := &mockGCPClient{err: errors.New("gcp error")} source, err := NewGCPSecretSource(context.Background(), cfg, mc, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) defer source.Close() _, err = source.GetSecret(context.Background(), "MY_SECRET") - require.Error(t, err) - assert.Contains(t, err.Error(), "gcp error") + must.Error(t, err) + test.StrContains(t, err.Error(), "gcp error") }) T.Run("full resource name passed through", func(t *testing.T) { @@ -147,12 +146,12 @@ func TestGCPSecretSource_GetSecret(T *testing.T) { cfg := &Config{ProjectID: "test-project"} mc := &mockGCPClient{value: "full-name-secret"} source, err := NewGCPSecretSource(context.Background(), cfg, mc, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) defer source.Close() got, err := source.GetSecret(context.Background(), "projects/other-project/secrets/foo/versions/latest") - require.NoError(t, err) - assert.Equal(t, "full-name-secret", got) + must.NoError(t, err) + test.EqOp(t, "full-name-secret", got) }) } @@ -165,11 +164,11 @@ func TestGCPSecretSource_Close(T *testing.T) { cfg := &Config{ProjectID: "test-project"} mc := &mockGCPClient{} source, err := NewGCPSecretSource(context.Background(), cfg, mc, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) err = source.Close() - require.NoError(t, err) - assert.True(t, mc.closed) + must.NoError(t, err) + test.True(t, mc.closed) }) } diff --git a/secrets/kubectl/config_test.go b/secrets/kubectl/config_test.go index 0340b83..c29d0af 100644 --- a/secrets/kubectl/config_test.go +++ b/secrets/kubectl/config_test.go @@ -4,8 +4,8 @@ import ( "context" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -14,20 +14,20 @@ func TestConfig_ValidateWithContext(T *testing.T) { T.Run("valid config", func(t *testing.T) { t.Parallel() cfg := &Config{Namespace: "default"} - require.NoError(t, cfg.ValidateWithContext(context.Background())) + must.NoError(t, cfg.ValidateWithContext(context.Background())) }) T.Run("valid config with kubeconfig", func(t *testing.T) { t.Parallel() cfg := &Config{Namespace: "production", Kubeconfig: "/home/user/.kube/config"} - require.NoError(t, cfg.ValidateWithContext(context.Background())) + must.NoError(t, cfg.ValidateWithContext(context.Background())) }) T.Run("missing namespace", func(t *testing.T) { t.Parallel() cfg := &Config{} err := cfg.ValidateWithContext(context.Background()) - require.Error(t, err) - assert.Contains(t, err.Error(), "namespace") + must.Error(t, err) + test.StrContains(t, err.Error(), "namespace") }) } diff --git a/secrets/kubectl/kubectl_test.go b/secrets/kubectl/kubectl_test.go index 91c6ef3..202d5a3 100644 --- a/secrets/kubectl/kubectl_test.go +++ b/secrets/kubectl/kubectl_test.go @@ -9,8 +9,7 @@ import ( mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -22,17 +21,17 @@ func TestNewKubectlSecretSource(T *testing.T) { T.Run("nil config returns error", func(t *testing.T) { t.Parallel() source, err := NewKubectlSecretSource(context.Background(), nil, nil, nil, nil, nil) - require.Error(t, err) - assert.Nil(t, source) - assert.Contains(t, err.Error(), "config is required") + must.Error(t, err) + test.Nil(t, source) + test.StrContains(t, err.Error(), "config is required") }) T.Run("missing namespace returns error", func(t *testing.T) { t.Parallel() cfg := &Config{} source, err := NewKubectlSecretSource(context.Background(), cfg, nil, nil, nil, nil) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) }) T.Run("with mock client succeeds", func(t *testing.T) { @@ -40,8 +39,8 @@ func TestNewKubectlSecretSource(T *testing.T) { cfg := &Config{Namespace: "default"} mc := &mockSecretGetter{} source, err := NewKubectlSecretSource(context.Background(), cfg, mc, nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, source) + must.NoError(t, err) + must.NotNil(t, source) }) T.Run("with error creating lookup counter", func(t *testing.T) { @@ -56,8 +55,8 @@ func TestNewKubectlSecretSource(T *testing.T) { cfg := &Config{Namespace: "default"} source, err := NewKubectlSecretSource(context.Background(), cfg, &mockSecretGetter{}, nil, nil, mp) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -80,8 +79,8 @@ func TestNewKubectlSecretSource(T *testing.T) { cfg := &Config{Namespace: "default"} source, err := NewKubectlSecretSource(context.Background(), cfg, &mockSecretGetter{}, nil, nil, mp) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) @@ -91,7 +90,7 @@ func TestNewKubectlSecretSource(T *testing.T) { noopMP := metrics.NewNoopMetricsProvider() h, histErr := noopMP.NewFloat64Histogram("test") - require.NoError(t, histErr) + must.NoError(t, histErr) mp := &mockmetrics.ProviderMock{ NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { @@ -105,8 +104,8 @@ func TestNewKubectlSecretSource(T *testing.T) { cfg := &Config{Namespace: "default"} source, err := NewKubectlSecretSource(context.Background(), cfg, &mockSecretGetter{}, nil, nil, mp) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) test.SliceLen(t, 2, mp.NewInt64CounterCalls()) test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) @@ -127,12 +126,12 @@ func TestKubectlSecretSource_GetSecret(T *testing.T) { }, } source, err := NewKubectlSecretSource(context.Background(), cfg, mc, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) got, err := source.GetSecret(context.Background(), "db-creds/password") - require.NoError(t, err) - assert.Equal(t, "s3cret", got) - assert.Equal(t, "db-creds", mc.lastName) + must.NoError(t, err) + test.EqOp(t, "s3cret", got) + test.EqOp(t, "db-creds", mc.lastName) }) T.Run("missing slash in name", func(t *testing.T) { @@ -140,11 +139,11 @@ func TestKubectlSecretSource_GetSecret(T *testing.T) { cfg := &Config{Namespace: "default"} mc := &mockSecretGetter{} source, err := NewKubectlSecretSource(context.Background(), cfg, mc, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) _, err = source.GetSecret(context.Background(), "no-slash") - require.Error(t, err) - assert.Contains(t, err.Error(), "expected format") + must.Error(t, err) + test.StrContains(t, err.Error(), "expected format") }) T.Run("key not found", func(t *testing.T) { @@ -158,11 +157,11 @@ func TestKubectlSecretSource_GetSecret(T *testing.T) { }, } source, err := NewKubectlSecretSource(context.Background(), cfg, mc, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) _, err = source.GetSecret(context.Background(), "db-creds/password") - require.Error(t, err) - assert.Contains(t, err.Error(), "key \"password\" not found") + must.Error(t, err) + test.StrContains(t, err.Error(), "key \"password\" not found") }) T.Run("client error", func(t *testing.T) { @@ -170,11 +169,11 @@ func TestKubectlSecretSource_GetSecret(T *testing.T) { cfg := &Config{Namespace: "default"} mc := &mockSecretGetter{err: errors.New("k8s api error")} source, err := NewKubectlSecretSource(context.Background(), cfg, mc, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) _, err = source.GetSecret(context.Background(), "db-creds/password") - require.Error(t, err) - assert.Contains(t, err.Error(), "k8s api error") + must.Error(t, err) + test.StrContains(t, err.Error(), "k8s api error") }) } @@ -186,10 +185,10 @@ func TestKubectlSecretSource_Close(T *testing.T) { cfg := &Config{Namespace: "default"} mc := &mockSecretGetter{} source, err := NewKubectlSecretSource(context.Background(), cfg, mc, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) err = source.Close() - require.NoError(t, err) + must.NoError(t, err) }) } @@ -199,24 +198,24 @@ func TestResolveName(T *testing.T) { T.Run("valid name", func(t *testing.T) { t.Parallel() secretName, key, err := resolveName("my-secret/my-key") - require.NoError(t, err) - assert.Equal(t, "my-secret", secretName) - assert.Equal(t, "my-key", key) + must.NoError(t, err) + test.EqOp(t, "my-secret", secretName) + test.EqOp(t, "my-key", key) }) T.Run("name with multiple slashes", func(t *testing.T) { t.Parallel() secretName, key, err := resolveName("my-secret/nested/key") - require.NoError(t, err) - assert.Equal(t, "my-secret", secretName) - assert.Equal(t, "nested/key", key) + must.NoError(t, err) + test.EqOp(t, "my-secret", secretName) + test.EqOp(t, "nested/key", key) }) T.Run("no slash", func(t *testing.T) { t.Parallel() _, _, err := resolveName("no-slash") - require.Error(t, err) - assert.Contains(t, err.Error(), "expected format") + must.Error(t, err) + test.StrContains(t, err.Error(), "expected format") }) } diff --git a/secrets/noop/noop_test.go b/secrets/noop/noop_test.go index eb2f2c9..c793d81 100644 --- a/secrets/noop/noop_test.go +++ b/secrets/noop/noop_test.go @@ -4,8 +4,8 @@ import ( "context" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestSecretSource_GetSecret(T *testing.T) { @@ -18,8 +18,8 @@ func TestSecretSource_GetSecret(T *testing.T) { ctx := context.Background() got, err := source.GetSecret(ctx, "any-key") - require.NoError(t, err) - assert.Empty(t, got) + must.NoError(t, err) + test.EqOp(t, "", got) }) } @@ -31,6 +31,6 @@ func TestSecretSource_Close(T *testing.T) { source := NewSecretSource() err := source.Close() - require.NoError(t, err) + must.NoError(t, err) }) } diff --git a/secrets/ssm/config_test.go b/secrets/ssm/config_test.go index 178fc6b..c40c67c 100644 --- a/secrets/ssm/config_test.go +++ b/secrets/ssm/config_test.go @@ -4,7 +4,7 @@ import ( "context" "testing" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -13,12 +13,12 @@ func TestConfig_ValidateWithContext(T *testing.T) { T.Run("valid", func(t *testing.T) { t.Parallel() cfg := &Config{Region: "us-east-1"} - require.NoError(t, cfg.ValidateWithContext(context.Background())) + must.NoError(t, cfg.ValidateWithContext(context.Background())) }) T.Run("invalid missing Region", func(t *testing.T) { t.Parallel() cfg := &Config{Region: ""} - require.Error(t, cfg.ValidateWithContext(context.Background())) + must.Error(t, cfg.ValidateWithContext(context.Background())) }) } diff --git a/secrets/ssm/ssm_test.go b/secrets/ssm/ssm_test.go index 1fb16e2..5e1033d 100644 --- a/secrets/ssm/ssm_test.go +++ b/secrets/ssm/ssm_test.go @@ -12,8 +12,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/aws/aws-sdk-go-v2/service/ssm/types" "github.com/shoenig/test" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" ) @@ -23,17 +22,17 @@ func TestNewSSMSecretSource(T *testing.T) { T.Run("nil config returns error", func(t *testing.T) { t.Parallel() source, err := NewSSMSecretSource(context.Background(), nil, nil, nil, nil, nil) - require.Error(t, err) - assert.Nil(t, source) - assert.Contains(t, err.Error(), "config is required") + must.Error(t, err) + test.Nil(t, source) + test.StrContains(t, err.Error(), "config is required") }) T.Run("missing Region returns error", func(t *testing.T) { t.Parallel() cfg := &Config{Region: ""} source, err := NewSSMSecretSource(context.Background(), cfg, nil, nil, nil, nil) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) }) T.Run("with mock client succeeds", func(t *testing.T) { @@ -41,8 +40,8 @@ func TestNewSSMSecretSource(T *testing.T) { cfg := &Config{Region: "us-east-1"} mc := &mockSSMClient{value: "param-value"} source, err := NewSSMSecretSource(context.Background(), cfg, mc, nil, nil, nil) - require.NoError(t, err) - require.NotNil(t, source) + must.NoError(t, err) + must.NotNil(t, source) }) T.Run("with error creating lookup counter", func(t *testing.T) { @@ -57,8 +56,8 @@ func TestNewSSMSecretSource(T *testing.T) { cfg := &Config{Region: "us-east-1"} source, err := NewSSMSecretSource(context.Background(), cfg, &mockSSMClient{}, nil, nil, mp) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) @@ -81,8 +80,8 @@ func TestNewSSMSecretSource(T *testing.T) { cfg := &Config{Region: "us-east-1"} source, err := NewSSMSecretSource(context.Background(), cfg, &mockSSMClient{}, nil, nil, mp) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) @@ -92,7 +91,7 @@ func TestNewSSMSecretSource(T *testing.T) { noopMP := metrics.NewNoopMetricsProvider() h, histErr := noopMP.NewFloat64Histogram("test") - require.NoError(t, histErr) + must.NoError(t, histErr) mp := &mockmetrics.ProviderMock{ NewInt64CounterFunc: func(_ string, _ ...metric.Int64CounterOption) (metrics.Int64Counter, error) { @@ -106,8 +105,8 @@ func TestNewSSMSecretSource(T *testing.T) { cfg := &Config{Region: "us-east-1"} source, err := NewSSMSecretSource(context.Background(), cfg, &mockSSMClient{}, nil, nil, mp) - require.Error(t, err) - assert.Nil(t, source) + must.Error(t, err) + test.Nil(t, source) test.SliceLen(t, 2, mp.NewInt64CounterCalls()) test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) @@ -122,11 +121,11 @@ func TestSSMSecretSource_GetSecret(T *testing.T) { cfg := &Config{Region: "us-east-1"} mc := &mockSSMClient{value: "my-param-value"} source, err := NewSSMSecretSource(context.Background(), cfg, mc, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) got, err := source.GetSecret(context.Background(), "MY_PARAM") - require.NoError(t, err) - assert.Equal(t, "my-param-value", got) + must.NoError(t, err) + test.EqOp(t, "my-param-value", got) }) T.Run("error from client", func(t *testing.T) { @@ -134,11 +133,11 @@ func TestSSMSecretSource_GetSecret(T *testing.T) { cfg := &Config{Region: "us-east-1"} mc := &mockSSMClient{err: errors.New("ssm error")} source, err := NewSSMSecretSource(context.Background(), cfg, mc, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) _, err = source.GetSecret(context.Background(), "MY_PARAM") - require.Error(t, err) - assert.Contains(t, err.Error(), "ssm error") + must.Error(t, err) + test.StrContains(t, err.Error(), "ssm error") }) T.Run("name with prefix", func(t *testing.T) { @@ -146,12 +145,12 @@ func TestSSMSecretSource_GetSecret(T *testing.T) { cfg := &Config{Region: "us-east-1", Prefix: "/myapp/"} mc := &mockSSMClient{value: "prefixed-value"} source, err := NewSSMSecretSource(context.Background(), cfg, mc, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) got, err := source.GetSecret(context.Background(), "MY_PARAM") - require.NoError(t, err) - assert.Equal(t, "prefixed-value", got) - assert.Equal(t, "/myapp/MY_PARAM", mc.lastName) + must.NoError(t, err) + test.EqOp(t, "prefixed-value", got) + test.EqOp(t, "/myapp/MY_PARAM", mc.lastName) }) T.Run("name already path", func(t *testing.T) { @@ -159,12 +158,12 @@ func TestSSMSecretSource_GetSecret(T *testing.T) { cfg := &Config{Region: "us-east-1", Prefix: "/myapp/"} mc := &mockSSMClient{value: "path-value"} source, err := NewSSMSecretSource(context.Background(), cfg, mc, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) got, err := source.GetSecret(context.Background(), "/existing/path/param") - require.NoError(t, err) - assert.Equal(t, "path-value", got) - assert.Equal(t, "/existing/path/param", mc.lastName) + must.NoError(t, err) + test.EqOp(t, "path-value", got) + test.EqOp(t, "/existing/path/param", mc.lastName) }) } @@ -177,10 +176,10 @@ func TestSSMSecretSource_Close(T *testing.T) { cfg := &Config{Region: "us-east-1"} mc := &mockSSMClient{} source, err := NewSSMSecretSource(context.Background(), cfg, mc, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) err = source.Close() - require.NoError(t, err) + must.NoError(t, err) }) } diff --git a/server/grpc/do_test.go b/server/grpc/do_test.go index 96e01d4..d673e1a 100644 --- a/server/grpc/do_test.go +++ b/server/grpc/do_test.go @@ -7,8 +7,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "google.golang.org/grpc" ) @@ -29,7 +29,7 @@ func TestRegisterGRPCServer(T *testing.T) { RegisterGRPCServer(i) srv, err := do.Invoke[*Server](i) - require.NoError(t, err) - assert.NotNil(t, srv) + must.NoError(t, err) + test.NotNil(t, srv) }) } diff --git a/server/grpc/server_test.go b/server/grpc/server_test.go index bcf4a5b..86b41ef 100644 --- a/server/grpc/server_test.go +++ b/server/grpc/server_test.go @@ -18,13 +18,15 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace/noop" "google.golang.org/grpc" ) +var errStub = errors.New("stub error") + type mockTracerProvider struct { noop.TracerProvider forceFlushFunc func(ctx context.Context) error @@ -47,7 +49,7 @@ func generateTestTLSCerts(t *testing.T) (certFile, keyFile string) { t.Helper() key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) + must.NoError(t, err) template := x509.Certificate{ SerialNumber: big.NewInt(1), @@ -59,23 +61,23 @@ func generateTestTLSCerts(t *testing.T) (certFile, keyFile string) { } certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) - require.NoError(t, err) + must.NoError(t, err) dir := t.TempDir() certPath := filepath.Join(dir, "cert.pem") certOut, err := os.Create(certPath) - require.NoError(t, err) - require.NoError(t, pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})) - require.NoError(t, certOut.Close()) + must.NoError(t, err) + must.NoError(t, pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})) + must.NoError(t, certOut.Close()) keyDER, err := x509.MarshalECPrivateKey(key) - require.NoError(t, err) + must.NoError(t, err) keyPath := filepath.Join(dir, "key.pem") keyOut, err := os.Create(keyPath) - require.NoError(t, err) - require.NoError(t, pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})) - require.NoError(t, keyOut.Close()) + must.NoError(t, err) + must.NoError(t, pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})) + must.NoError(t, keyOut.Close()) return certPath, keyPath } @@ -88,8 +90,8 @@ func TestNewGRPCServer(T *testing.T) { server, err := NewGRPCServer(nil, nil, nil, nil, nil) - assert.Nil(t, server) - assert.Error(t, err) + test.Nil(t, server) + test.Error(t, err) }) T.Run("succeeds with valid config", func(t *testing.T) { @@ -98,8 +100,8 @@ func TestNewGRPCServer(T *testing.T) { cfg := &Config{Port: 0} server, err := NewGRPCServer(cfg, nil, nil, nil, nil) - require.NoError(t, err) - assert.NotNil(t, server) + must.NoError(t, err) + test.NotNil(t, server) }) T.Run("succeeds with registration functions", func(t *testing.T) { @@ -113,9 +115,9 @@ func TestNewGRPCServer(T *testing.T) { cfg := &Config{Port: 0} server, err := NewGRPCServer(cfg, logging.NewNoopLogger(), nil, nil, nil, rf) - require.NoError(t, err) - assert.NotNil(t, server) - assert.True(t, called) + must.NoError(t, err) + test.NotNil(t, server) + test.True(t, called) }) T.Run("returns error with invalid TLS files", func(t *testing.T) { @@ -129,8 +131,8 @@ func TestNewGRPCServer(T *testing.T) { server, err := NewGRPCServer(cfg, nil, nil, nil, nil) - assert.Nil(t, server) - assert.Error(t, err) + test.Nil(t, server) + test.Error(t, err) }) T.Run("succeeds with valid TLS files", func(t *testing.T) { @@ -146,8 +148,8 @@ func TestNewGRPCServer(T *testing.T) { server, err := NewGRPCServer(cfg, logging.NewNoopLogger(), nil, nil, nil) - require.NoError(t, err) - assert.NotNil(t, server) + must.NoError(t, err) + test.NotNil(t, server) }) } @@ -158,7 +160,7 @@ func TestLoggingInterceptor(T *testing.T) { t.Parallel() interceptor := LoggingInterceptor(nil) - assert.NotNil(t, interceptor) + test.NotNil(t, interceptor) handlerCalled := false handler := func(ctx context.Context, req any) (any, error) { @@ -169,18 +171,18 @@ func TestLoggingInterceptor(T *testing.T) { info := &grpc.UnaryServerInfo{FullMethod: "/test/Method"} result, err := interceptor(context.Background(), "request", info, handler) - assert.NoError(t, err) - assert.Equal(t, "result", result) - assert.True(t, handlerCalled) + test.NoError(t, err) + test.Eq(t, "result", result) + test.True(t, handlerCalled) }) T.Run("logs error when handler fails", func(t *testing.T) { t.Parallel() interceptor := LoggingInterceptor(logging.NewNoopLogger()) - assert.NotNil(t, interceptor) + test.NotNil(t, interceptor) - expectedErr := assert.AnError + expectedErr := errStub handler := func(ctx context.Context, req any) (any, error) { return nil, expectedErr } @@ -188,8 +190,8 @@ func TestLoggingInterceptor(T *testing.T) { info := &grpc.UnaryServerInfo{FullMethod: "/test/Method"} result, err := interceptor(context.Background(), "request", info, handler) - assert.ErrorIs(t, err, expectedErr) - assert.Nil(t, result) + test.ErrorIs(t, err, expectedErr) + test.Nil(t, result) }) } @@ -201,7 +203,7 @@ func TestServer_Shutdown(T *testing.T) { cfg := &Config{Port: 0} server, err := NewGRPCServer(cfg, nil, nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) server.Shutdown(context.Background()) }) @@ -211,7 +213,7 @@ func TestServer_Shutdown(T *testing.T) { cfg := &Config{Port: 0} server, err := NewGRPCServer(cfg, logging.NewNoopLogger(), nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) server.Shutdown(context.Background()) }) @@ -225,11 +227,11 @@ func TestServer_Shutdown(T *testing.T) { cfg := &Config{Port: 0} srv, err := NewGRPCServer(cfg, logging.NewNoopLogger(), mtp, nil, nil) - require.NoError(t, err) + must.NoError(t, err) srv.Shutdown(context.Background()) - assert.Equal(t, 1, mtp.forceFlushCalls) + test.EqOp(t, 1, mtp.forceFlushCalls) }) } @@ -246,8 +248,8 @@ func TestNewGRPCServer_withInterceptors(T *testing.T) { cfg := &Config{Port: 0} server, err := NewGRPCServer(cfg, logging.NewNoopLogger(), nil, []grpc.UnaryServerInterceptor{unaryInterceptor}, nil) - require.NoError(t, err) - assert.NotNil(t, server) + must.NoError(t, err) + test.NotNil(t, server) }) T.Run("with stream interceptors", func(t *testing.T) { @@ -260,8 +262,8 @@ func TestNewGRPCServer_withInterceptors(T *testing.T) { cfg := &Config{Port: 0} server, err := NewGRPCServer(cfg, logging.NewNoopLogger(), nil, nil, []grpc.StreamServerInterceptor{streamInterceptor}) - require.NoError(t, err) - assert.NotNil(t, server) + must.NoError(t, err) + test.NotNil(t, server) }) T.Run("with multiple registration functions", func(t *testing.T) { @@ -274,9 +276,9 @@ func TestNewGRPCServer_withInterceptors(T *testing.T) { cfg := &Config{Port: 0} server, err := NewGRPCServer(cfg, nil, nil, nil, nil, rf1, rf2) - require.NoError(t, err) - assert.NotNil(t, server) - assert.Equal(t, 2, callCount) + must.NoError(t, err) + test.NotNil(t, server) + test.EqOp(t, 2, callCount) }) } @@ -288,7 +290,7 @@ func TestServer_Serve(T *testing.T) { cfg := &Config{Port: 0} srv, err := NewGRPCServer(cfg, logging.NewNoopLogger(), nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) ctx, cancel := context.WithCancel(t.Context()) @@ -310,14 +312,14 @@ func TestServer_Serve(T *testing.T) { // Occupy a port so the server's Listen call fails with "address already in use". lis, err := new(net.ListenConfig).Listen(t.Context(), "tcp", ":0") - require.NoError(t, err) + must.NoError(t, err) defer lis.Close() port := lis.Addr().(*net.TCPAddr).Port cfg := &Config{Port: uint16(port)} srv, err := NewGRPCServer(cfg, logging.NewNoopLogger(), nil, nil, nil) - require.NoError(t, err) + must.NoError(t, err) // Should return immediately because the port is already in use. srv.Serve(t.Context()) diff --git a/server/http/config_test.go b/server/http/config_test.go index d3e51b9..98ca15c 100644 --- a/server/http/config_test.go +++ b/server/http/config_test.go @@ -4,7 +4,7 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_Validate(T *testing.T) { @@ -20,7 +20,7 @@ func TestConfig_Validate(T *testing.T) { Debug: true, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("returns error with missing port", func(t *testing.T) { @@ -31,7 +31,7 @@ func TestConfig_Validate(T *testing.T) { StartupDeadline: time.Second, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("returns error with missing startup deadline", func(t *testing.T) { @@ -42,7 +42,7 @@ func TestConfig_Validate(T *testing.T) { Port: 8080, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("returns error with empty config", func(t *testing.T) { @@ -51,6 +51,6 @@ func TestConfig_Validate(T *testing.T) { ctx := t.Context() cfg := &Config{} - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } diff --git a/server/http/do_test.go b/server/http/do_test.go index 46f4968..42db750 100644 --- a/server/http/do_test.go +++ b/server/http/do_test.go @@ -9,8 +9,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/routing" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterHTTPServer(T *testing.T) { @@ -28,7 +28,7 @@ func TestRegisterHTTPServer(T *testing.T) { RegisterHTTPServer(i, "test_service") srv, err := do.Invoke[Server](i) - require.NoError(t, err) - assert.NotNil(t, srv) + must.NoError(t, err) + test.NotNil(t, srv) }) } diff --git a/server/http/http_server_test.go b/server/http/http_server_test.go index 82e31f6..7dc76b6 100644 --- a/server/http/http_server_test.go +++ b/server/http/http_server_test.go @@ -24,8 +24,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/panicking" "github.com/verygoodsoftwarenotvirus/platform/v5/routing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace/noop" ) @@ -74,7 +74,7 @@ func generateTestTLSCerts(t *testing.T) (certFile, keyFile string) { t.Helper() key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - require.NoError(t, err) + must.NoError(t, err) template := x509.Certificate{ SerialNumber: big.NewInt(1), @@ -86,23 +86,23 @@ func generateTestTLSCerts(t *testing.T) (certFile, keyFile string) { } certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) - require.NoError(t, err) + must.NoError(t, err) dir := t.TempDir() certPath := filepath.Join(dir, "cert.pem") certOut, err := os.Create(certPath) - require.NoError(t, err) - require.NoError(t, pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})) - require.NoError(t, certOut.Close()) + must.NoError(t, err) + must.NoError(t, pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})) + must.NoError(t, certOut.Close()) keyDER, err := x509.MarshalECPrivateKey(key) - require.NoError(t, err) + must.NoError(t, err) keyPath := filepath.Join(dir, "key.pem") keyOut, err := os.Create(keyPath) - require.NoError(t, err) - require.NoError(t, pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})) - require.NoError(t, keyOut.Close()) + must.NoError(t, err) + must.NoError(t, pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})) + must.NoError(t, keyOut.Close()) return certPath, keyPath } @@ -127,8 +127,8 @@ func TestProvideHTTPServer(T *testing.T) { "", ) - assert.NotNil(t, x) - assert.NoError(t, err) + test.NotNil(t, x) + test.NoError(t, err) }) T.Run("with custom service name", func(t *testing.T) { @@ -142,8 +142,8 @@ func TestProvideHTTPServer(T *testing.T) { "custom_service", ) - assert.NotNil(t, x) - assert.NoError(t, err) + test.NotNil(t, x) + test.NoError(t, err) }) T.Run("with empty service name uses default", func(t *testing.T) { @@ -157,8 +157,8 @@ func TestProvideHTTPServer(T *testing.T) { "", ) - assert.NotNil(t, x) - assert.NoError(t, err) + test.NotNil(t, x) + test.NoError(t, err) }) T.Run("with SSL config", func(t *testing.T) { @@ -176,8 +176,8 @@ func TestProvideHTTPServer(T *testing.T) { "", ) - assert.NotNil(t, x) - assert.NoError(t, err) + test.NotNil(t, x) + test.NoError(t, err) }) } @@ -188,10 +188,10 @@ func TestServer_Router(T *testing.T) { t.Parallel() s, err := ProvideHTTPServer(Config{Port: 0}, nil, nil, nil, "") - require.NoError(t, err) + must.NoError(t, err) // Router returns nil when nil was passed in - assert.Nil(t, s.Router()) + test.Nil(t, s.Router()) }) } @@ -202,12 +202,12 @@ func TestServer_Shutdown(T *testing.T) { t.Parallel() s, err := ProvideHTTPServer(Config{Port: 0}, logging.NewNoopLogger(), nil, nil, "") - require.NoError(t, err) + must.NoError(t, err) ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) defer cancel() - assert.NoError(t, s.Shutdown(ctx)) + test.NoError(t, s.Shutdown(ctx)) }) T.Run("logs error when ForceFlush fails", func(t *testing.T) { @@ -218,14 +218,14 @@ func TestServer_Shutdown(T *testing.T) { } s, err := ProvideHTTPServer(Config{Port: 0}, logging.NewNoopLogger(), nil, mtp, "") - require.NoError(t, err) + must.NoError(t, err) ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) defer cancel() - assert.NoError(t, s.Shutdown(ctx)) + test.NoError(t, s.Shutdown(ctx)) - assert.Equal(t, 1, mtp.forceFlushCalls) + test.EqOp(t, 1, mtp.forceFlushCalls) }) } @@ -252,7 +252,7 @@ func TestServer_Serve(T *testing.T) { // Give the server time to start listening. time.Sleep(50 * time.Millisecond) - require.NoError(t, srv.httpServer.Close()) + must.NoError(t, srv.httpServer.Close()) <-done }) @@ -280,7 +280,7 @@ func TestServer_Serve(T *testing.T) { }() time.Sleep(50 * time.Millisecond) - require.NoError(t, srv.httpServer.Close()) + must.NoError(t, srv.httpServer.Close()) <-done }) @@ -308,7 +308,7 @@ func TestServer_Serve(T *testing.T) { // Occupy a port so ListenAndServe fails with "address already in use". lis, err := new(net.ListenConfig).Listen(t.Context(), "tcp", ":0") - require.NoError(t, err) + must.NoError(t, err) defer lis.Close() port := lis.Addr().(*net.TCPAddr).Port @@ -335,28 +335,28 @@ func Test_skipNoisePaths(T *testing.T) { t.Parallel() req := httptest.NewRequest(http.MethodGet, "/_ops_/health", http.NoBody) - assert.False(t, skipNoisePaths(req)) + test.False(t, skipNoisePaths(req)) }) T.Run("apple app site association path is filtered out", func(t *testing.T) { t.Parallel() req := httptest.NewRequest(http.MethodGet, appleAppSiteAssociationPath, http.NoBody) - assert.False(t, skipNoisePaths(req)) + test.False(t, skipNoisePaths(req)) }) T.Run("normal paths are not filtered", func(t *testing.T) { t.Parallel() req := httptest.NewRequest(http.MethodGet, "/api/v1/things", http.NoBody) - assert.True(t, skipNoisePaths(req)) + test.True(t, skipNoisePaths(req)) }) T.Run("root path is not filtered", func(t *testing.T) { t.Parallel() req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) - assert.True(t, skipNoisePaths(req)) + test.True(t, skipNoisePaths(req)) }) } @@ -368,13 +368,13 @@ func Test_provideStdLibHTTPServer(T *testing.T) { srv := provideStdLibHTTPServer(8080) - assert.NotNil(t, srv) - assert.Equal(t, ":8080", srv.Addr) - assert.Equal(t, readTimeout, srv.ReadTimeout) - assert.Equal(t, writeTimeout, srv.WriteTimeout) - assert.Equal(t, idleTimeout, srv.IdleTimeout) - assert.NotNil(t, srv.TLSConfig) - assert.Equal(t, uint16(tls.VersionTLS12), srv.TLSConfig.MinVersion) + test.NotNil(t, srv) + test.EqOp(t, ":8080", srv.Addr) + test.EqOp(t, readTimeout, srv.ReadTimeout) + test.EqOp(t, writeTimeout, srv.WriteTimeout) + test.EqOp(t, idleTimeout, srv.IdleTimeout) + test.NotNil(t, srv.TLSConfig) + test.EqOp(t, uint16(tls.VersionTLS12), srv.TLSConfig.MinVersion) }) T.Run("with zero port", func(t *testing.T) { @@ -382,7 +382,7 @@ func Test_provideStdLibHTTPServer(T *testing.T) { srv := provideStdLibHTTPServer(0) - assert.NotNil(t, srv) - assert.Equal(t, ":0", srv.Addr) + test.NotNil(t, srv) + test.EqOp(t, ":0", srv.Addr) }) } diff --git a/server/http/static_files_test.go b/server/http/static_files_test.go index 7fa81f2..725a916 100644 --- a/server/http/static_files_test.go +++ b/server/http/static_files_test.go @@ -7,8 +7,8 @@ import ( "path/filepath" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRootLevelAssetsHandler(T *testing.T) { @@ -19,7 +19,7 @@ func TestRootLevelAssetsHandler(T *testing.T) { dir := t.TempDir() err := os.WriteFile(filepath.Join(dir, "robots.txt"), []byte("User-agent: *"), 0o600) - require.NoError(t, err) + must.NoError(t, err) handler := RootLevelAssetsHandler(dir) req := httptest.NewRequest(http.MethodGet, "/robots.txt", http.NoBody) @@ -27,8 +27,8 @@ func TestRootLevelAssetsHandler(T *testing.T) { handler.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "User-agent") + test.EqOp(t, http.StatusOK, w.Code) + test.StrContains(t, w.Body.String(), "User-agent") }) T.Run("returns 404 for subdirectory paths", func(t *testing.T) { @@ -41,7 +41,7 @@ func TestRootLevelAssetsHandler(T *testing.T) { handler.ServeHTTP(w, req) - assert.Equal(t, http.StatusNotFound, w.Code) + test.EqOp(t, http.StatusNotFound, w.Code) }) T.Run("returns 404 for nonexistent file", func(t *testing.T) { @@ -54,7 +54,7 @@ func TestRootLevelAssetsHandler(T *testing.T) { handler.ServeHTTP(w, req) - assert.Equal(t, http.StatusNotFound, w.Code) + test.EqOp(t, http.StatusNotFound, w.Code) }) T.Run("returns 404 for directory", func(t *testing.T) { @@ -62,7 +62,7 @@ func TestRootLevelAssetsHandler(T *testing.T) { dir := t.TempDir() err := os.Mkdir(filepath.Join(dir, "subdir"), 0o755) - require.NoError(t, err) + must.NoError(t, err) handler := RootLevelAssetsHandler(dir) req := httptest.NewRequest(http.MethodGet, "/subdir", http.NoBody) @@ -70,7 +70,7 @@ func TestRootLevelAssetsHandler(T *testing.T) { handler.ServeHTTP(w, req) - assert.Equal(t, http.StatusNotFound, w.Code) + test.EqOp(t, http.StatusNotFound, w.Code) }) T.Run("blocks path traversal attempts", func(t *testing.T) { @@ -83,7 +83,7 @@ func TestRootLevelAssetsHandler(T *testing.T) { handler.ServeHTTP(w, req) - assert.Equal(t, http.StatusNotFound, w.Code) + test.EqOp(t, http.StatusNotFound, w.Code) }) T.Run("blocks single-segment traversal", func(t *testing.T) { @@ -98,6 +98,6 @@ func TestRootLevelAssetsHandler(T *testing.T) { handler.ServeHTTP(w, req) - assert.Equal(t, http.StatusNotFound, w.Code) + test.EqOp(t, http.StatusNotFound, w.Code) }) } diff --git a/testutils/testutil.go b/testutils/testutil.go index 2f4bb6d..209c194 100644 --- a/testutils/testutil.go +++ b/testutils/testutil.go @@ -12,8 +12,8 @@ import ( "time" fake "github.com/brianvoe/gofakeit/v7" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func init() { @@ -65,8 +65,8 @@ func BuildTestRequest(t *testing.T) *http.Request { http.NoBody, ) - require.NotNil(t, req) - assert.NoError(t, err) + must.NotNil(t, req) + test.NoError(t, err) return req } diff --git a/testutils/testutil_test.go b/testutils/testutil_test.go index 3ccc874..fbf86be 100644 --- a/testutils/testutil_test.go +++ b/testutils/testutil_test.go @@ -3,8 +3,8 @@ package testutils import ( "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestBuildArbitraryImage(T *testing.T) { @@ -13,17 +13,17 @@ func TestBuildArbitraryImage(T *testing.T) { T.Run("returns image with correct dimensions", func(t *testing.T) { t.Parallel() img := BuildArbitraryImage(10) - require.NotNil(t, img) + must.NotNil(t, img) bounds := img.Bounds() - assert.Equal(t, 10, bounds.Dx()) - assert.Equal(t, 10, bounds.Dy()) + test.EqOp(t, 10, bounds.Dx()) + test.EqOp(t, 10, bounds.Dy()) }) T.Run("handles size 1", func(t *testing.T) { t.Parallel() img := BuildArbitraryImage(1) - require.NotNil(t, img) - assert.Equal(t, 1, img.Bounds().Dx()) + must.NotNil(t, img) + test.EqOp(t, 1, img.Bounds().Dx()) }) } @@ -33,13 +33,13 @@ func TestBuildArbitraryImagePNGBytes(T *testing.T) { T.Run("returns valid PNG bytes", func(t *testing.T) { t.Parallel() img, data := BuildArbitraryImagePNGBytes(5) - require.NotNil(t, img) - assert.NotEmpty(t, data) + must.NotNil(t, img) + test.SliceNotEmpty(t, data) // PNG magic bytes - assert.Equal(t, byte(0x89), data[0]) - assert.Equal(t, byte('P'), data[1]) - assert.Equal(t, byte('N'), data[2]) - assert.Equal(t, byte('G'), data[3]) + test.EqOp(t, byte(0x89), data[0]) + test.EqOp(t, byte('P'), data[1]) + test.EqOp(t, byte('N'), data[2]) + test.EqOp(t, byte('G'), data[3]) }) } @@ -49,7 +49,7 @@ func TestBuildTestRequest(T *testing.T) { T.Run("returns valid request", func(t *testing.T) { t.Parallel() req := BuildTestRequest(t) - require.NotNil(t, req) - assert.NotNil(t, req.Context()) + must.NotNil(t, req) + test.NotNil(t, req.Context()) }) } diff --git a/types/main_test.go b/types/main_test.go index b072aa6..24116df 100644 --- a/types/main_test.go +++ b/types/main_test.go @@ -7,8 +7,8 @@ import ( "time" fake "github.com/brianvoe/gofakeit/v7" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func init() { @@ -21,7 +21,7 @@ func TestErrorResponse_Error(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - assert.NotEmpty(t, (&APIError{}).Error()) + test.NotEq(t, "", (&APIError{}).Error()) }) } @@ -39,12 +39,12 @@ func TestAPIResponse_EncodeToJSON(T *testing.T) { } encodedBytes, err := json.Marshal(example) - require.NoError(t, err) + must.NoError(t, err) expected := `{"error":{"message":"TestAPIResponse_EncodeToJSON/standard","code":"E104"},"details":{"currentAccountID":"","traceID":""}}` actual := string(encodedBytes) - assert.Equal(t, expected, actual) + test.EqOp(t, expected, actual) }) } @@ -55,7 +55,7 @@ func TestAPIError_AsError(T *testing.T) { t.Parallel() var e *APIError - assert.NoError(t, e.AsError()) + test.NoError(t, e.AsError()) }) T.Run("with non-nil receiver", func(t *testing.T) { @@ -65,7 +65,7 @@ func TestAPIError_AsError(T *testing.T) { Message: "something went wrong", Code: ErrNothingSpecific, } - assert.Error(t, e.AsError()) + test.Error(t, e.AsError()) }) } @@ -82,11 +82,11 @@ func TestNewAPIErrorResponse(T *testing.T) { resp := NewAPIErrorResponse("something broke", ErrTalkingToDatabase, details) - require.NotNil(t, resp) - require.NotNil(t, resp.Error) - assert.Equal(t, "something broke", resp.Error.Message) - assert.Equal(t, ErrTalkingToDatabase, resp.Error.Code) - assert.Equal(t, details, resp.Details) + must.NotNil(t, resp) + must.NotNil(t, resp.Error) + test.EqOp(t, "something broke", resp.Error.Message) + test.EqOp(t, ErrTalkingToDatabase, resp.Error.Code) + test.EqOp(t, details, resp.Details) }) } @@ -97,14 +97,14 @@ func TestFloat32RangeWithOptionalMax_ValidateWithContext(T *testing.T) { t.Parallel() x := &Float32RangeWithOptionalMax{Min: 1.0} - assert.NoError(t, x.ValidateWithContext(context.Background())) + test.NoError(t, x.ValidateWithContext(context.Background())) }) T.Run("invalid", func(t *testing.T) { t.Parallel() x := &Float32RangeWithOptionalMax{} - assert.Error(t, x.ValidateWithContext(context.Background())) + test.Error(t, x.ValidateWithContext(context.Background())) }) } @@ -115,14 +115,14 @@ func TestUint16RangeWithOptionalMax_ValidateWithContext(T *testing.T) { t.Parallel() x := &Uint16RangeWithOptionalMax{Min: 1} - assert.NoError(t, x.ValidateWithContext(context.Background())) + test.NoError(t, x.ValidateWithContext(context.Background())) }) T.Run("invalid", func(t *testing.T) { t.Parallel() x := &Uint16RangeWithOptionalMax{} - assert.Error(t, x.ValidateWithContext(context.Background())) + test.Error(t, x.ValidateWithContext(context.Background())) }) } @@ -133,14 +133,14 @@ func TestUint32RangeWithOptionalMax_ValidateWithContext(T *testing.T) { t.Parallel() x := &Uint32RangeWithOptionalMax{Min: 1} - assert.NoError(t, x.ValidateWithContext(context.Background())) + test.NoError(t, x.ValidateWithContext(context.Background())) }) T.Run("invalid", func(t *testing.T) { t.Parallel() x := &Uint32RangeWithOptionalMax{} - assert.Error(t, x.ValidateWithContext(context.Background())) + test.Error(t, x.ValidateWithContext(context.Background())) }) } @@ -151,13 +151,13 @@ func TestRangeWithOptionalUpperBound_ValidateWithContext(T *testing.T) { t.Parallel() x := &RangeWithOptionalUpperBound[string]{Min: "a"} - assert.NoError(t, x.ValidateWithContext(context.Background())) + test.NoError(t, x.ValidateWithContext(context.Background())) }) T.Run("invalid", func(t *testing.T) { t.Parallel() x := &RangeWithOptionalUpperBound[string]{} - assert.Error(t, x.ValidateWithContext(context.Background())) + test.Error(t, x.ValidateWithContext(context.Background())) }) } diff --git a/uploads/config/config_test.go b/uploads/config/config_test.go index fef690a..c38b451 100644 --- a/uploads/config/config_test.go +++ b/uploads/config/config_test.go @@ -5,7 +5,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/uploads/objectstorage" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -26,7 +26,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Debug: false, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with empty storage", func(t *testing.T) { @@ -35,6 +35,6 @@ func TestConfig_ValidateWithContext(T *testing.T) { ctx := t.Context() cfg := &Config{} - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) } diff --git a/uploads/config/do_test.go b/uploads/config/do_test.go index 4e7970a..d0c95b9 100644 --- a/uploads/config/do_test.go +++ b/uploads/config/do_test.go @@ -6,8 +6,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/uploads/objectstorage" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterStorageConfig(T *testing.T) { @@ -27,8 +27,8 @@ func TestRegisterStorageConfig(T *testing.T) { RegisterStorageConfig(i) storageCfg, err := do.Invoke[*objectstorage.Config](i) - require.NoError(t, err) - assert.NotNil(t, storageCfg) - assert.Equal(t, t.Name(), storageCfg.BucketName) + must.NoError(t, err) + test.NotNil(t, storageCfg) + test.EqOp(t, t.Name(), storageCfg.BucketName) }) } diff --git a/uploads/images/images_test.go b/uploads/images/images_test.go index 58be1ba..d692115 100644 --- a/uploads/images/images_test.go +++ b/uploads/images/images_test.go @@ -19,8 +19,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/verygoodsoftwarenotvirus/platform/v5/testutils" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) // errorWriter is an http.ResponseWriter whose Write always returns an error. @@ -45,15 +45,15 @@ func newAvatarUploadRequest(t *testing.T, filename string, avatar io.Reader) *ht writer := multipart.NewWriter(body) part, err := writer.CreateFormFile("avatar", fmt.Sprintf("avatar.%s", filepath.Ext(filename))) - require.NoError(t, err) + must.NoError(t, err) _, err = io.Copy(part, avatar) - require.NoError(t, err) + must.NoError(t, err) - require.NoError(t, writer.Close()) + must.NoError(t, writer.Close()) req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://whatever.whocares.gov", body) - require.NoError(t, err) + must.NoError(t, err) req.Header.Set(headerContentType, writer.FormDataContentType()) @@ -65,7 +65,7 @@ func buildPNGBytes(t *testing.T) *bytes.Buffer { b := new(bytes.Buffer) exampleImage := testutils.BuildArbitraryImage(256) - require.NoError(t, png.Encode(b, exampleImage)) + must.NoError(t, png.Encode(b, exampleImage)) expected := b.Bytes() return bytes.NewBuffer(expected) @@ -76,7 +76,7 @@ func buildJPEGBytes(t *testing.T) *bytes.Buffer { b := new(bytes.Buffer) exampleImage := testutils.BuildArbitraryImage(256) - require.NoError(t, jpeg.Encode(b, exampleImage, &jpeg.Options{Quality: jpeg.DefaultQuality})) + must.NoError(t, jpeg.Encode(b, exampleImage, &jpeg.Options{Quality: jpeg.DefaultQuality})) expected := b.Bytes() return bytes.NewBuffer(expected) @@ -87,7 +87,7 @@ func buildGIFBytes(t *testing.T) *bytes.Buffer { b := new(bytes.Buffer) exampleImage := testutils.BuildArbitraryImage(256) - require.NoError(t, gif.Encode(b, exampleImage, &gif.Options{NumColors: 256})) + must.NoError(t, gif.Encode(b, exampleImage, &gif.Options{NumColors: 256})) expected := b.Bytes() return bytes.NewBuffer(expected) @@ -102,15 +102,15 @@ func newMultiFileUploadRequest(t *testing.T, files map[string][]byte) *http.Requ for filename, data := range files { part, err := writer.CreateFormFile(filename, filename) - require.NoError(t, err) + must.NoError(t, err) _, err = io.Copy(part, bytes.NewReader(data)) - require.NoError(t, err) + must.NoError(t, err) } - require.NoError(t, writer.Close()) + must.NoError(t, writer.Close()) req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://whatever.whocares.gov", body) - require.NoError(t, err) + must.NoError(t, err) req.Header.Set(headerContentType, writer.FormDataContentType()) @@ -123,33 +123,33 @@ func Test_contentTypeFromFilename(T *testing.T) { T.Run("png", func(t *testing.T) { t.Parallel() - assert.Equal(t, imagePNG, contentTypeFromFilename("photo.png")) + test.EqOp(t, imagePNG, contentTypeFromFilename("photo.png")) }) T.Run("jpeg", func(t *testing.T) { t.Parallel() - assert.Equal(t, imageJPEG, contentTypeFromFilename("photo.jpeg")) + test.EqOp(t, imageJPEG, contentTypeFromFilename("photo.jpeg")) }) T.Run("gif", func(t *testing.T) { t.Parallel() - assert.Equal(t, imageGIF, contentTypeFromFilename("photo.gif")) + test.EqOp(t, imageGIF, contentTypeFromFilename("photo.gif")) }) T.Run("falls back to mime.TypeByExtension", func(t *testing.T) { t.Parallel() actual := contentTypeFromFilename("document.html") - assert.Contains(t, actual, "text/html") + test.StrContains(t, actual, "text/html") }) T.Run("unknown extension", func(t *testing.T) { t.Parallel() actual := contentTypeFromFilename("file.xyznotreal") - assert.Empty(t, actual) + test.EqOp(t, "", actual) }) } @@ -159,31 +159,31 @@ func Test_isImage(T *testing.T) { T.Run("png", func(t *testing.T) { t.Parallel() - assert.True(t, isImage("photo.png")) + test.True(t, isImage("photo.png")) }) T.Run("jpeg", func(t *testing.T) { t.Parallel() - assert.True(t, isImage("photo.jpeg")) + test.True(t, isImage("photo.jpeg")) }) T.Run("gif", func(t *testing.T) { t.Parallel() - assert.True(t, isImage("photo.gif")) + test.True(t, isImage("photo.gif")) }) T.Run("non-image", func(t *testing.T) { t.Parallel() - assert.False(t, isImage("document.html")) + test.False(t, isImage("document.html")) }) T.Run("unknown extension", func(t *testing.T) { t.Parallel() - assert.False(t, isImage("file.xyznotreal")) + test.False(t, isImage("file.xyznotreal")) }) } @@ -194,7 +194,7 @@ func TestNewMediaUploadProcessor(T *testing.T) { t.Parallel() p := NewMediaUploadProcessor(nil, tracing.NewNoopTracerProvider()) - assert.NotNil(t, p) + test.NotNil(t, p) }) } @@ -214,7 +214,7 @@ func TestImage_DataURI(T *testing.T) { expected := "data:things/stuff;base64,VGVzdEltYWdlX0RhdGFVUkkvc3RhbmRhcmQ=" actual := i.DataURI() - assert.Equal(t, expected, actual) + test.EqOp(t, expected, actual) }) } @@ -233,11 +233,11 @@ func TestImage_Write(T *testing.T) { } res := httptest.NewRecorder() - assert.NoError(t, i.Write(res)) + test.NoError(t, i.Write(res)) - assert.Equal(t, "things/stuff", res.Header().Get(headerContentType)) - assert.Equal(t, strconv.Itoa(len(data)), res.Header().Get("RawHTML-Length")) - assert.Equal(t, data, res.Body.Bytes()) + test.EqOp(t, "things/stuff", res.Header().Get(headerContentType)) + test.EqOp(t, strconv.Itoa(len(data)), res.Header().Get("RawHTML-Length")) + test.Eq(t, data, res.Body.Bytes()) }) T.Run("with write error", func(t *testing.T) { @@ -252,7 +252,7 @@ func TestImage_Write(T *testing.T) { res := &errorWriter{} - assert.Error(t, i.Write(res)) + test.Error(t, i.Write(res)) }) } @@ -272,13 +272,13 @@ func TestImage_Thumbnail(T *testing.T) { } tempFile, err := os.CreateTemp("", "") - require.NoError(t, err) + must.NoError(t, err) actual, err := i.Thumbnail(123, 123, tempFile.Name()) - assert.NoError(t, err) - assert.NotNil(t, actual) + test.NoError(t, err) + test.NotNil(t, actual) - require.NoError(t, os.Remove(tempFile.Name())) + must.NoError(t, os.Remove(tempFile.Name())) }) T.Run("with invalid content type", func(t *testing.T) { @@ -289,8 +289,8 @@ func TestImage_Thumbnail(T *testing.T) { } actual, err := i.Thumbnail(123, 123, t.Name()) - assert.Error(t, err) - assert.Nil(t, actual) + test.Error(t, err) + test.Nil(t, actual) }) } @@ -334,10 +334,10 @@ func Test_uploadProcessor_Process(T *testing.T) { req := newAvatarUploadRequest(t, "avatar.png", imgBytes) actual, err := p.ProcessFile(ctx, req, expectedFieldName) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) - assert.Equal(t, expected, actual.Data) + test.Eq(t, expected, actual.Data) }) T.Run("with missing form file", func(t *testing.T) { @@ -348,11 +348,11 @@ func Test_uploadProcessor_Process(T *testing.T) { expectedFieldName := "avatar" req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://tests.verygoodsoftwarenotvirus.ru", http.NoBody) - require.NoError(t, err) + must.NoError(t, err) actual, err := p.ProcessFile(ctx, req, expectedFieldName) - assert.Nil(t, actual) - assert.Error(t, err) + test.Nil(t, actual) + test.Error(t, err) }) T.Run("with error decoding image", func(t *testing.T) { @@ -365,8 +365,8 @@ func Test_uploadProcessor_Process(T *testing.T) { req := newAvatarUploadRequest(t, "avatar.png", bytes.NewBufferString("")) actual, err := p.ProcessFile(ctx, req, expectedFieldName) - assert.Nil(t, actual) - assert.Error(t, err) + test.Nil(t, actual) + test.Error(t, err) }) T.Run("with non-image file", func(t *testing.T) { @@ -379,18 +379,18 @@ func Test_uploadProcessor_Process(T *testing.T) { body := &bytes.Buffer{} writer := multipart.NewWriter(body) part, err := writer.CreateFormFile(expectedFieldName, "notes.txt") - require.NoError(t, err) + must.NoError(t, err) _, err = part.Write([]byte("hello world")) - require.NoError(t, err) - require.NoError(t, writer.Close()) + must.NoError(t, err) + must.NoError(t, writer.Close()) req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://whatever.whocares.gov", body) - require.NoError(t, err) + must.NoError(t, err) req.Header.Set(headerContentType, writer.FormDataContentType()) actual, err := p.ProcessFile(ctx, req, expectedFieldName) - assert.NotNil(t, actual) - assert.NoError(t, err) + test.NotNil(t, actual) + test.NoError(t, err) }) } @@ -409,8 +409,8 @@ func Test_uploadProcessor_ProcessFiles(T *testing.T) { }) actual, err := p.ProcessFiles(ctx, req, "upload") - assert.NoError(t, err) - assert.Len(t, actual, 1) + test.NoError(t, err) + test.SliceLen(t, 1, actual) }) T.Run("standard with multiple files", func(t *testing.T) { @@ -427,8 +427,8 @@ func Test_uploadProcessor_ProcessFiles(T *testing.T) { }) actual, err := p.ProcessFiles(ctx, req, "upload") - assert.NoError(t, err) - assert.Len(t, actual, 2) + test.NoError(t, err) + test.SliceLen(t, 2, actual) }) T.Run("with no multipart form", func(t *testing.T) { @@ -438,11 +438,11 @@ func Test_uploadProcessor_ProcessFiles(T *testing.T) { p := NewMediaUploadProcessor(nil, tracing.NewNoopTracerProvider()) req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://whatever.whocares.gov", http.NoBody) - require.NoError(t, err) + must.NoError(t, err) actual, err := p.ProcessFiles(ctx, req, "upload") - assert.Error(t, err) - assert.Nil(t, actual) + test.Error(t, err) + test.Nil(t, actual) }) T.Run("with invalid image data", func(t *testing.T) { @@ -456,8 +456,8 @@ func Test_uploadProcessor_ProcessFiles(T *testing.T) { }) actual, err := p.ProcessFiles(ctx, req, "upload") - assert.Error(t, err) - assert.Nil(t, actual) + test.Error(t, err) + test.Nil(t, actual) }) T.Run("with non-image files", func(t *testing.T) { @@ -471,8 +471,8 @@ func Test_uploadProcessor_ProcessFiles(T *testing.T) { }) actual, err := p.ProcessFiles(ctx, req, "upload") - assert.NoError(t, err) - assert.Len(t, actual, 1) + test.NoError(t, err) + test.SliceLen(t, 1, actual) }) T.Run("with already parsed multipart form", func(t *testing.T) { @@ -486,10 +486,10 @@ func Test_uploadProcessor_ProcessFiles(T *testing.T) { "photo.png": imgBytes, }) - require.NoError(t, req.ParseMultipartForm(defaultMaxMemory)) + must.NoError(t, req.ParseMultipartForm(defaultMaxMemory)) actual, err := p.ProcessFiles(ctx, req, "upload") - assert.NoError(t, err) - assert.Len(t, actual, 1) + test.NoError(t, err) + test.SliceLen(t, 1, actual) }) } diff --git a/uploads/images/thumbnails_test.go b/uploads/images/thumbnails_test.go index f3e30b9..6632abb 100644 --- a/uploads/images/thumbnails_test.go +++ b/uploads/images/thumbnails_test.go @@ -4,8 +4,8 @@ import ( "os" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func Test_newThumbnailer(T *testing.T) { @@ -16,8 +16,8 @@ func Test_newThumbnailer(T *testing.T) { for _, ct := range []string{imagePNG, imageJPEG, imageGIF} { x, err := newThumbnailer(ct) - assert.NoError(t, err) - assert.NotNil(t, x) + test.NoError(t, err) + test.NotNil(t, x) } }) @@ -25,8 +25,8 @@ func Test_newThumbnailer(T *testing.T) { t.Parallel() x, err := newThumbnailer(t.Name()) - assert.Error(t, err) - assert.Nil(t, x) + test.Error(t, err) + test.Nil(t, x) }) } @@ -45,8 +45,8 @@ func Test_preprocess(T *testing.T) { } img, err := preprocess(i, 128, 128) - assert.NoError(t, err) - assert.NotNil(t, img) + test.NoError(t, err) + test.NotNil(t, img) }) T.Run("with invalid content", func(t *testing.T) { @@ -60,8 +60,8 @@ func Test_preprocess(T *testing.T) { } img, err := preprocess(i, 128, 128) - assert.Error(t, err) - assert.Nil(t, img) + test.Error(t, err) + test.Nil(t, img) }) } @@ -80,13 +80,13 @@ func Test_jpegThumbnailer_Thumbnail(T *testing.T) { } tempFile, err := os.CreateTemp("", "") - require.NoError(t, err) + must.NoError(t, err) actual, err := (&jpegThumbnailer{}).Thumbnail(i, 128, 128, tempFile.Name()) - assert.NoError(t, err) - assert.NotNil(t, actual) + test.NoError(t, err) + test.NotNil(t, actual) - require.NoError(t, os.Remove(tempFile.Name())) + must.NoError(t, os.Remove(tempFile.Name())) }) T.Run("with invalid content", func(t *testing.T) { @@ -100,13 +100,13 @@ func Test_jpegThumbnailer_Thumbnail(T *testing.T) { } tempFile, err := os.CreateTemp("", "") - require.NoError(t, err) + must.NoError(t, err) actual, err := (&jpegThumbnailer{}).Thumbnail(i, 128, 128, tempFile.Name()) - assert.Error(t, err) - assert.Nil(t, actual) + test.Error(t, err) + test.Nil(t, actual) - require.NoError(t, os.Remove(tempFile.Name())) + must.NoError(t, os.Remove(tempFile.Name())) }) } @@ -125,13 +125,13 @@ func Test_pngThumbnailer_Thumbnail(T *testing.T) { } tempFile, err := os.CreateTemp("", "") - require.NoError(t, err) + must.NoError(t, err) actual, err := (&pngThumbnailer{}).Thumbnail(i, 128, 128, tempFile.Name()) - assert.NoError(t, err) - assert.NotNil(t, actual) + test.NoError(t, err) + test.NotNil(t, actual) - require.NoError(t, os.Remove(tempFile.Name())) + must.NoError(t, os.Remove(tempFile.Name())) }) T.Run("with invalid content", func(t *testing.T) { @@ -145,13 +145,13 @@ func Test_pngThumbnailer_Thumbnail(T *testing.T) { } tempFile, err := os.CreateTemp("", "") - require.NoError(t, err) + must.NoError(t, err) actual, err := (&pngThumbnailer{}).Thumbnail(i, 128, 128, tempFile.Name()) - assert.Error(t, err) - assert.Nil(t, actual) + test.Error(t, err) + test.Nil(t, actual) - require.NoError(t, os.Remove(tempFile.Name())) + must.NoError(t, os.Remove(tempFile.Name())) }) } @@ -170,13 +170,13 @@ func Test_gifThumbnailer_Thumbnail(T *testing.T) { } tempFile, err := os.CreateTemp("", "") - require.NoError(t, err) + must.NoError(t, err) actual, err := (&gifThumbnailer{}).Thumbnail(i, 128, 128, tempFile.Name()) - assert.NoError(t, err) - assert.NotNil(t, actual) + test.NoError(t, err) + test.NotNil(t, actual) - require.NoError(t, os.Remove(tempFile.Name())) + must.NoError(t, os.Remove(tempFile.Name())) }) T.Run("with invalid content", func(t *testing.T) { @@ -190,12 +190,12 @@ func Test_gifThumbnailer_Thumbnail(T *testing.T) { } tempFile, err := os.CreateTemp("", "") - require.NoError(t, err) + must.NoError(t, err) actual, err := (&gifThumbnailer{}).Thumbnail(i, 128, 128, tempFile.Name()) - assert.Error(t, err) - assert.Nil(t, actual) + test.Error(t, err) + test.Nil(t, actual) - require.NoError(t, os.Remove(tempFile.Name())) + must.NoError(t, os.Remove(tempFile.Name())) }) } diff --git a/uploads/objectstorage/bucket_backblaze_test.go b/uploads/objectstorage/bucket_backblaze_test.go index 4c67d15..1350e0b 100644 --- a/uploads/objectstorage/bucket_backblaze_test.go +++ b/uploads/objectstorage/bucket_backblaze_test.go @@ -3,7 +3,7 @@ package objectstorage import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestBackblazeB2Config_ValidateWithContext(T *testing.T) { @@ -20,7 +20,7 @@ func TestBackblazeB2Config_ValidateWithContext(T *testing.T) { Region: t.Name(), } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with missing application key ID", func(t *testing.T) { @@ -33,7 +33,7 @@ func TestBackblazeB2Config_ValidateWithContext(T *testing.T) { Region: t.Name(), } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with missing application key", func(t *testing.T) { @@ -46,7 +46,7 @@ func TestBackblazeB2Config_ValidateWithContext(T *testing.T) { Region: t.Name(), } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with missing bucket name", func(t *testing.T) { @@ -59,7 +59,7 @@ func TestBackblazeB2Config_ValidateWithContext(T *testing.T) { Region: t.Name(), } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with missing region", func(t *testing.T) { @@ -72,6 +72,6 @@ func TestBackblazeB2Config_ValidateWithContext(T *testing.T) { BucketName: t.Name(), } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } diff --git a/uploads/objectstorage/bucket_filesystem_test.go b/uploads/objectstorage/bucket_filesystem_test.go index 20721e0..8a60bf7 100644 --- a/uploads/objectstorage/bucket_filesystem_test.go +++ b/uploads/objectstorage/bucket_filesystem_test.go @@ -3,7 +3,7 @@ package objectstorage import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestFilesystemConfig_ValidateWithContext(T *testing.T) { @@ -17,7 +17,7 @@ func TestFilesystemConfig_ValidateWithContext(T *testing.T) { RootDirectory: t.Name(), } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with missing root directory", func(t *testing.T) { @@ -26,6 +26,6 @@ func TestFilesystemConfig_ValidateWithContext(T *testing.T) { ctx := t.Context() cfg := &FilesystemConfig{} - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } diff --git a/uploads/objectstorage/bucket_gcp_test.go b/uploads/objectstorage/bucket_gcp_test.go index 9f2446d..573069a 100644 --- a/uploads/objectstorage/bucket_gcp_test.go +++ b/uploads/objectstorage/bucket_gcp_test.go @@ -3,7 +3,7 @@ package objectstorage import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestGCPConfig_ValidateWithContext(T *testing.T) { @@ -17,7 +17,7 @@ func TestGCPConfig_ValidateWithContext(T *testing.T) { BucketName: t.Name(), } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with missing bucket name", func(t *testing.T) { @@ -26,6 +26,6 @@ func TestGCPConfig_ValidateWithContext(T *testing.T) { ctx := t.Context() cfg := &GCPConfig{} - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } diff --git a/uploads/objectstorage/bucket_r2_test.go b/uploads/objectstorage/bucket_r2_test.go index e80ffbd..5691318 100644 --- a/uploads/objectstorage/bucket_r2_test.go +++ b/uploads/objectstorage/bucket_r2_test.go @@ -3,7 +3,7 @@ package objectstorage import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestR2Config_ValidateWithContext(T *testing.T) { @@ -20,7 +20,7 @@ func TestR2Config_ValidateWithContext(T *testing.T) { SecretAccessKey: t.Name(), } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with missing account ID", func(t *testing.T) { @@ -33,7 +33,7 @@ func TestR2Config_ValidateWithContext(T *testing.T) { SecretAccessKey: t.Name(), } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with missing bucket name", func(t *testing.T) { @@ -46,7 +46,7 @@ func TestR2Config_ValidateWithContext(T *testing.T) { SecretAccessKey: t.Name(), } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with missing access key ID", func(t *testing.T) { @@ -59,7 +59,7 @@ func TestR2Config_ValidateWithContext(T *testing.T) { SecretAccessKey: t.Name(), } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with missing secret access key", func(t *testing.T) { @@ -72,6 +72,6 @@ func TestR2Config_ValidateWithContext(T *testing.T) { AccessKeyID: t.Name(), } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } diff --git a/uploads/objectstorage/bucket_s3_test.go b/uploads/objectstorage/bucket_s3_test.go index 2260a4d..1b3a751 100644 --- a/uploads/objectstorage/bucket_s3_test.go +++ b/uploads/objectstorage/bucket_s3_test.go @@ -3,7 +3,7 @@ package objectstorage import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestS3Config_ValidateWithContext(T *testing.T) { @@ -17,7 +17,7 @@ func TestS3Config_ValidateWithContext(T *testing.T) { BucketName: t.Name(), } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with missing bucket name", func(t *testing.T) { @@ -26,6 +26,6 @@ func TestS3Config_ValidateWithContext(T *testing.T) { ctx := t.Context() cfg := &S3Config{} - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } diff --git a/uploads/objectstorage/do_test.go b/uploads/objectstorage/do_test.go index 316a093..0f7e640 100644 --- a/uploads/objectstorage/do_test.go +++ b/uploads/objectstorage/do_test.go @@ -9,8 +9,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/uploads" "github.com/samber/do/v2" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" ) func TestRegisterUploadManager(T *testing.T) { @@ -32,11 +32,11 @@ func TestRegisterUploadManager(T *testing.T) { RegisterUploadManager(i) uploader, err := do.Invoke[*Uploader](i) - require.NoError(t, err) - assert.NotNil(t, uploader) + must.NoError(t, err) + test.NotNil(t, uploader) uploadManager, err := do.Invoke[uploads.UploadManager](i) - require.NoError(t, err) - assert.NotNil(t, uploadManager) + must.NoError(t, err) + test.NotNil(t, uploadManager) }) } diff --git a/uploads/objectstorage/files_test.go b/uploads/objectstorage/files_test.go index 60edeb2..7181abc 100644 --- a/uploads/objectstorage/files_test.go +++ b/uploads/objectstorage/files_test.go @@ -10,8 +10,8 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/shoenig/test" + "github.com/shoenig/test/must" "gocloud.dev/blob/memblob" ) @@ -20,19 +20,19 @@ func noopUploaderMetrics(t *testing.T) (saveCounter, readCounter, saveErrCounter mp := metrics.NewNoopMetricsProvider() saveCounter, err := mp.NewInt64Counter("test_saves") - require.NoError(t, err) + must.NoError(t, err) readCounter, err = mp.NewInt64Counter("test_reads") - require.NoError(t, err) + must.NoError(t, err) saveErrCounter, err = mp.NewInt64Counter("test_save_errors") - require.NoError(t, err) + must.NoError(t, err) readErrCounter, err = mp.NewInt64Counter("test_read_errors") - require.NoError(t, err) + must.NoError(t, err) latencyHist, err = mp.NewFloat64Histogram("test_latency") - require.NoError(t, err) + must.NoError(t, err) return saveCounter, readCounter, saveErrCounter, readErrCounter, latencyHist } @@ -48,7 +48,7 @@ func TestUploader_ReadFile(T *testing.T) { expectedContent := []byte(t.Name()) b := memblob.OpenBucket(&memblob.Options{}) - require.NoError(t, b.WriteAll(ctx, exampleFilename, expectedContent, nil)) + must.NoError(t, b.WriteAll(ctx, exampleFilename, expectedContent, nil)) saveCounter, readCounter, saveErrCounter, readErrCounter, latencyHist := noopUploaderMetrics(t) u := &Uploader{ @@ -64,8 +64,8 @@ func TestUploader_ReadFile(T *testing.T) { } x, err := u.ReadFile(ctx, exampleFilename) - assert.NoError(t, err) - assert.Equal(t, expectedContent, x) + test.NoError(t, err) + test.Eq(t, expectedContent, x) }) T.Run("with invalid file", func(t *testing.T) { @@ -88,8 +88,8 @@ func TestUploader_ReadFile(T *testing.T) { } x, err := u.ReadFile(ctx, exampleFilename) - assert.Error(t, err) - assert.Nil(t, x) + test.Error(t, err) + test.Nil(t, x) }) T.Run("with broken circuit breaker", func(t *testing.T) { @@ -115,8 +115,8 @@ func TestUploader_ReadFile(T *testing.T) { } x, err := u.ReadFile(ctx, "anything.txt") - assert.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) - assert.Nil(t, x) + test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.Nil(t, x) }) T.Run("with mock circuit breaker on successful read", func(t *testing.T) { @@ -127,7 +127,7 @@ func TestUploader_ReadFile(T *testing.T) { expectedContent := []byte(t.Name()) b := memblob.OpenBucket(&memblob.Options{}) - require.NoError(t, b.WriteAll(ctx, exampleFilename, expectedContent, nil)) + must.NoError(t, b.WriteAll(ctx, exampleFilename, expectedContent, nil)) cb := &cbmock.CircuitBreakerMock{ CannotProceedFunc: func() bool { return false }, @@ -148,8 +148,8 @@ func TestUploader_ReadFile(T *testing.T) { } x, err := u.ReadFile(ctx, exampleFilename) - assert.NoError(t, err) - assert.Equal(t, expectedContent, x) + test.NoError(t, err) + test.Eq(t, expectedContent, x) }) } @@ -173,7 +173,7 @@ func TestUploader_SaveFile(T *testing.T) { latencyHist: latencyHist, } - assert.NoError(t, u.SaveFile(ctx, "test_file.txt", []byte(t.Name()))) + test.NoError(t, u.SaveFile(ctx, "test_file.txt", []byte(t.Name()))) }) T.Run("with broken circuit breaker", func(t *testing.T) { @@ -198,7 +198,7 @@ func TestUploader_SaveFile(T *testing.T) { latencyHist: latencyHist, } - assert.ErrorIs(t, u.SaveFile(ctx, "test_file.txt", []byte(t.Name())), circuitbreaking.ErrCircuitBroken) + test.ErrorIs(t, u.SaveFile(ctx, "test_file.txt", []byte(t.Name())), circuitbreaking.ErrCircuitBroken) }) T.Run("with write error", func(t *testing.T) { @@ -212,7 +212,7 @@ func TestUploader_SaveFile(T *testing.T) { } b := memblob.OpenBucket(&memblob.Options{}) - require.NoError(t, b.Close()) + must.NoError(t, b.Close()) saveCounter, readCounter, saveErrCounter, readErrCounter, latencyHist := noopUploaderMetrics(t) u := &Uploader{ @@ -227,7 +227,7 @@ func TestUploader_SaveFile(T *testing.T) { latencyHist: latencyHist, } - assert.Error(t, u.SaveFile(ctx, "test_file.txt", []byte(t.Name()))) + test.Error(t, u.SaveFile(ctx, "test_file.txt", []byte(t.Name()))) }) T.Run("can be read back after save", func(t *testing.T) { @@ -249,10 +249,10 @@ func TestUploader_SaveFile(T *testing.T) { latencyHist: latencyHist, } - require.NoError(t, u.SaveFile(ctx, "roundtrip.txt", content)) + must.NoError(t, u.SaveFile(ctx, "roundtrip.txt", content)) actual, err := u.ReadFile(ctx, "roundtrip.txt") - assert.NoError(t, err) - assert.Equal(t, content, actual) + test.NoError(t, err) + test.Eq(t, content, actual) }) } diff --git a/uploads/objectstorage/providers_test.go b/uploads/objectstorage/providers_test.go index 2e04b8d..81e1494 100644 --- a/uploads/objectstorage/providers_test.go +++ b/uploads/objectstorage/providers_test.go @@ -3,7 +3,7 @@ package objectstorage import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" "gocloud.dev/blob/memblob" ) @@ -18,7 +18,7 @@ func TestProvideUploadManager(T *testing.T) { } result := ProvideUploadManager(u) - assert.NotNil(t, result) - assert.Equal(t, u, result) + test.NotNil(t, result) + test.True(t, u == result) }) } diff --git a/uploads/objectstorage/uploader_test.go b/uploads/objectstorage/uploader_test.go index 0a4e447..373b8d6 100644 --- a/uploads/objectstorage/uploader_test.go +++ b/uploads/objectstorage/uploader_test.go @@ -8,7 +8,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestConfig_ValidateWithContext(T *testing.T) { @@ -24,7 +24,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { FilesystemConfig: &FilesystemConfig{RootDirectory: t.Name()}, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with missing bucket name", func(t *testing.T) { @@ -35,7 +35,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: MemoryProvider, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with invalid provider", func(t *testing.T) { @@ -47,7 +47,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: "invalid_provider", } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with s3 provider", func(t *testing.T) { @@ -60,7 +60,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { S3Config: &S3Config{BucketName: t.Name()}, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with s3 provider missing config", func(t *testing.T) { @@ -72,7 +72,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: S3Provider, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with gcp provider", func(t *testing.T) { @@ -85,7 +85,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { GCP: &GCPConfig{BucketName: t.Name()}, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with gcp provider missing config", func(t *testing.T) { @@ -97,7 +97,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: GCPCloudStorageProvider, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with r2 provider", func(t *testing.T) { @@ -115,7 +115,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { }, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with r2 provider missing config", func(t *testing.T) { @@ -127,7 +127,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: R2Provider, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with backblaze b2 provider", func(t *testing.T) { @@ -145,7 +145,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { }, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with backblaze b2 provider missing config", func(t *testing.T) { @@ -157,7 +157,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: BackblazeB2Provider, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with memory provider", func(t *testing.T) { @@ -169,7 +169,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: MemoryProvider, } - assert.NoError(t, cfg.ValidateWithContext(ctx)) + test.NoError(t, cfg.ValidateWithContext(ctx)) }) T.Run("with filesystem provider missing config", func(t *testing.T) { @@ -181,7 +181,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { Provider: FilesystemProvider, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) T.Run("with non-s3 provider having s3 config is invalid", func(t *testing.T) { @@ -194,7 +194,7 @@ func TestConfig_ValidateWithContext(T *testing.T) { S3Config: &S3Config{BucketName: t.Name()}, } - assert.Error(t, cfg.ValidateWithContext(ctx)) + test.Error(t, cfg.ValidateWithContext(ctx)) }) } @@ -212,8 +212,8 @@ func TestNewUploadManager(T *testing.T) { } x, err := NewUploadManager(ctx, l, tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider(), cfg) - assert.NotNil(t, x) - assert.NoError(t, err) + test.NotNil(t, x) + test.NoError(t, err) }) T.Run("with nil config", func(t *testing.T) { @@ -223,8 +223,8 @@ func TestNewUploadManager(T *testing.T) { l := logging.NewNoopLogger() x, err := NewUploadManager(ctx, l, tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider(), nil) - assert.Nil(t, x) - assert.Error(t, err) + test.Nil(t, x) + test.Error(t, err) }) T.Run("with invalid config", func(t *testing.T) { @@ -235,8 +235,8 @@ func TestNewUploadManager(T *testing.T) { cfg := &Config{} x, err := NewUploadManager(ctx, l, tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider(), cfg) - assert.Nil(t, x) - assert.Error(t, err) + test.Nil(t, x) + test.Error(t, err) }) T.Run("with filesystem provider", func(t *testing.T) { @@ -253,8 +253,8 @@ func TestNewUploadManager(T *testing.T) { } x, err := NewUploadManager(ctx, l, tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider(), cfg) - assert.NotNil(t, x) - assert.NoError(t, err) + test.NotNil(t, x) + test.NoError(t, err) }) T.Run("with bucket prefix", func(t *testing.T) { @@ -269,8 +269,8 @@ func TestNewUploadManager(T *testing.T) { } x, err := NewUploadManager(ctx, l, tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider(), cfg) - assert.NotNil(t, x) - assert.NoError(t, err) + test.NotNil(t, x) + test.NoError(t, err) }) T.Run("with selectBucket error", func(t *testing.T) { @@ -285,8 +285,8 @@ func TestNewUploadManager(T *testing.T) { } x, err := NewUploadManager(ctx, l, tracing.NewNoopTracerProvider(), metrics.NewNoopMetricsProvider(), cfg) - assert.Nil(t, x) - assert.Error(t, err) + test.Nil(t, x) + test.Error(t, err) }) } @@ -305,7 +305,7 @@ func TestUploader_selectBucket(T *testing.T) { }, } - assert.NoError(t, u.selectBucket(ctx, cfg)) + test.NoError(t, u.selectBucket(ctx, cfg)) }) T.Run("s3 with nil config", func(t *testing.T) { @@ -318,7 +318,7 @@ func TestUploader_selectBucket(T *testing.T) { S3Config: nil, } - assert.Error(t, u.selectBucket(ctx, cfg)) + test.Error(t, u.selectBucket(ctx, cfg)) }) T.Run("memory provider", func(t *testing.T) { @@ -330,7 +330,7 @@ func TestUploader_selectBucket(T *testing.T) { Provider: MemoryProvider, } - assert.NoError(t, u.selectBucket(ctx, cfg)) + test.NoError(t, u.selectBucket(ctx, cfg)) }) T.Run("r2 happy path", func(t *testing.T) { @@ -348,7 +348,7 @@ func TestUploader_selectBucket(T *testing.T) { }, } - assert.NoError(t, u.selectBucket(ctx, cfg)) + test.NoError(t, u.selectBucket(ctx, cfg)) }) T.Run("r2 with nil config", func(t *testing.T) { @@ -361,7 +361,7 @@ func TestUploader_selectBucket(T *testing.T) { R2Config: nil, } - assert.Error(t, u.selectBucket(ctx, cfg)) + test.Error(t, u.selectBucket(ctx, cfg)) }) T.Run("backblaze b2 happy path", func(t *testing.T) { @@ -379,7 +379,7 @@ func TestUploader_selectBucket(T *testing.T) { }, } - assert.NoError(t, u.selectBucket(ctx, cfg)) + test.NoError(t, u.selectBucket(ctx, cfg)) }) T.Run("backblaze b2 with nil config", func(t *testing.T) { @@ -392,7 +392,7 @@ func TestUploader_selectBucket(T *testing.T) { BackblazeB2Config: nil, } - assert.Error(t, u.selectBucket(ctx, cfg)) + test.Error(t, u.selectBucket(ctx, cfg)) }) T.Run("filesystem happy path", func(t *testing.T) { @@ -409,7 +409,7 @@ func TestUploader_selectBucket(T *testing.T) { }, } - assert.NoError(t, u.selectBucket(ctx, cfg)) + test.NoError(t, u.selectBucket(ctx, cfg)) }) T.Run("filesystem with nil config", func(t *testing.T) { @@ -422,7 +422,7 @@ func TestUploader_selectBucket(T *testing.T) { FilesystemConfig: nil, } - assert.Error(t, u.selectBucket(ctx, cfg)) + test.Error(t, u.selectBucket(ctx, cfg)) }) T.Run("memory provider with bucket prefix", func(t *testing.T) { @@ -435,8 +435,8 @@ func TestUploader_selectBucket(T *testing.T) { BucketPrefix: "my-prefix/", } - assert.NoError(t, u.selectBucket(ctx, cfg)) - assert.NotNil(t, u.bucket) + test.NoError(t, u.selectBucket(ctx, cfg)) + test.NotNil(t, u.bucket) }) T.Run("unknown provider falls through to filesystem default", func(t *testing.T) { @@ -450,7 +450,7 @@ func TestUploader_selectBucket(T *testing.T) { FilesystemConfig: &FilesystemConfig{RootDirectory: tempDir}, } - assert.NoError(t, u.selectBucket(ctx, cfg)) + test.NoError(t, u.selectBucket(ctx, cfg)) }) T.Run("gcp provider fails without credentials", func(t *testing.T) { @@ -465,7 +465,7 @@ func TestUploader_selectBucket(T *testing.T) { }, } - assert.Error(t, u.selectBucket(ctx, cfg)) + test.Error(t, u.selectBucket(ctx, cfg)) }) T.Run("filesystem with invalid root directory", func(t *testing.T) { @@ -478,6 +478,6 @@ func TestUploader_selectBucket(T *testing.T) { FilesystemConfig: &FilesystemConfig{RootDirectory: string([]byte{0x00})}, } - assert.Error(t, u.selectBucket(ctx, cfg)) + test.Error(t, u.selectBucket(ctx, cfg)) }) } diff --git a/version/version_test.go b/version/version_test.go index d2f8999..750d1ca 100644 --- a/version/version_test.go +++ b/version/version_test.go @@ -3,7 +3,7 @@ package version import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/shoenig/test" ) func TestGet(T *testing.T) { //nolint:paralleltest // mutates package-level vars; subtests must run sequentially @@ -15,10 +15,10 @@ func TestGet(T *testing.T) { //nolint:paralleltest // mutates package-level vars }) info := Get() - assert.Equal(t, "unknown", info.Version) - assert.Equal(t, "unknown", info.CommitHash) - assert.Equal(t, "unknown", info.CommitTime) - assert.Equal(t, "unknown", info.BuildTime) + test.EqOp(t, "unknown", info.Version) + test.EqOp(t, "unknown", info.CommitHash) + test.EqOp(t, "unknown", info.CommitTime) + test.EqOp(t, "unknown", info.BuildTime) }) T.Run("returns set values when vars are populated", func(t *testing.T) { //nolint:paralleltest // mutates package-level vars; subtests must run sequentially @@ -32,9 +32,9 @@ func TestGet(T *testing.T) { //nolint:paralleltest // mutates package-level vars }) info := Get() - assert.Equal(t, "v1.2.3", info.Version) - assert.Equal(t, "abc123", info.CommitHash) - assert.Equal(t, "2026-01-01T00:00:00Z", info.CommitTime) - assert.Equal(t, "2026-01-02T00:00:00Z", info.BuildTime) + test.EqOp(t, "v1.2.3", info.Version) + test.EqOp(t, "abc123", info.CommitHash) + test.EqOp(t, "2026-01-01T00:00:00Z", info.CommitTime) + test.EqOp(t, "2026-01-02T00:00:00Z", info.BuildTime) }) } From 8833c66c5be632da13ddf85d6c6a7d8c9c5c8827 Mon Sep 17 00:00:00 2001 From: Jeffrey D <1289344+verygoodsoftwarenotvirus@users.noreply.github.com> Date: Sat, 11 Apr 2026 17:26:19 -0500 Subject: [PATCH 08/12] ci: fix generated check --- .github/workflows/generated_files.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/generated_files.yaml b/.github/workflows/generated_files.yaml index e8e532b..7cc3fb8 100644 --- a/.github/workflows/generated_files.yaml +++ b/.github/workflows/generated_files.yaml @@ -31,7 +31,7 @@ jobs: cache-dependency-path: go.sum - name: Regenerate - run: make generate + run: make generate format - name: Verify no drift run: | From 917cead8ed6b5fce7376123770ab604b9998efea Mon Sep 17 00:00:00 2001 From: Jeffrey D <1289344+verygoodsoftwarenotvirus@users.noreply.github.com> Date: Sat, 11 Apr 2026 21:15:40 -0500 Subject: [PATCH 09/12] fix(messagequeue/redis): wait for SUBSCRIBE confirmation before returning consumer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `(*redis.Client).Subscribe(ctx, topic)` in go-redis v8 is asynchronous — it returns a *PubSub before the SUBSCRIBE command has reached the server. The go-redis source documents this explicitly in a doc comment at redis.go line 660: callers are supposed to call Receive() on the returned *PubSub to block until the server confirms the subscription, otherwise a publisher racing the subscriber will silently drop the first message (Redis pub/sub does not buffer for late subscribers). provideRedisConsumer was skipping that step, so Test_redisConsumer_Consume/with_error_handling_message would intermittently hang for the full 10-minute test deadline: SUBSCRIBE and PUBLISH raced, PUBLISH won, the message was dropped server-side, the handler never fired, and the bare <-errorsChan receive blocked forever. The "redis: discarding bad PubSub connection: EOF" log lines everyone was blaming on testcontainers are the internal health-check goroutine seeing its socket torn out when the container was killed at the deadline — symptom, not cause. Fix: - provideRedisConsumer now calls subscription.Receive(ctx) right after Subscribe and propagates any error. Its signature becomes (*redisConsumer, error). - The log line "subscribed to topic!" now fires *after* the server confirms, so it reflects reality. - (*consumerProvider).ProvideConsumer propagates the error and constructs the consumer outside the cache mutex — we don't want to serialize the SUBSCRIBE round-trip behind the lock. A re-check under the write lock protects against duplicate construction on concurrent callers. Behavior change worth calling out for bisect: ProvideConsumer is now fail-fast. If Redis is unreachable at provisioning time, the caller gets an error immediately instead of a silently-broken consumer whose messages would have been dropped anyway. Services that call ProvideConsumer during startup wiring will now fail startup on queue unavailability, which is strictly better than booting into a broken state. Test changes: - Test_consumerProvider_ProvideConsumer/standard and /hitting_cache were passing fake addresses (QueueAddresses: []string{t.Name()}) and only worked because Subscribe never dialed. Replaced with container-backed equivalents that actually exercise the SUBSCRIBE round-trip and verify the cache returns the same instance on a second call. - Test_consumerProvider_ProvideConsumer/with_empty_topic kept as-is — it returns before Subscribe is reached, fake address never dialed. - Test_redisConsumer_Consume's channels are now buffered and the error-handler subtest uses select+timeout on errorsChan. Strictly defense-in-depth now that the root race is fixed, but it means a future regression fails in seconds instead of ten minutes. Stress-tested: 20x Test_redisConsumer_Consume runs in 31s, 10x full-package runs in 17s, zero flakes. Co-Authored-By: Claude Opus 4.6 (1M context) --- messagequeue/redis/consumer.go | 30 ++++++++++-- messagequeue/redis/consumer_test.go | 73 +++++++++++++++++++---------- 2 files changed, 75 insertions(+), 28 deletions(-) diff --git a/messagequeue/redis/consumer.go b/messagequeue/redis/consumer.go index 4324c1e..13cc280 100644 --- a/messagequeue/redis/consumer.go +++ b/messagequeue/redis/consumer.go @@ -33,7 +33,7 @@ type ( } ) -func provideRedisConsumer(ctx context.Context, logger logging.Logger, tracerProvider tracing.TracerProvider, metricsProvider metrics.Provider, redisClient subscriptionProvider, topic string, handlerFunc func(context.Context, []byte) error) *redisConsumer { +func provideRedisConsumer(ctx context.Context, logger logging.Logger, tracerProvider tracing.TracerProvider, metricsProvider metrics.Provider, redisClient subscriptionProvider, topic string, handlerFunc func(context.Context, []byte) error) (*redisConsumer, error) { mp := metrics.EnsureMetricsProvider(metricsProvider) consumedCounter, err := mp.NewInt64Counter(fmt.Sprintf("%s_consumed", topic)) @@ -43,6 +43,14 @@ func provideRedisConsumer(ctx context.Context, logger logging.Logger, tracerProv subscription := redisClient.Subscribe(ctx, topic) + // Block until Redis confirms the SUBSCRIBE has been registered on the + // server. Without this, a publisher racing us would silently drop the + // first message — Redis pub/sub does not buffer for late subscribers. + // See go-redis's own Subscribe doc comment for the rationale. + if _, err = subscription.Receive(ctx); err != nil { + return nil, fmt.Errorf("confirming redis subscription to %q: %w", topic, err) + } + logger.Debug("subscribed to topic!") return &redisConsumer{ @@ -51,7 +59,7 @@ func provideRedisConsumer(ctx context.Context, logger logging.Logger, tracerProv logger: logging.EnsureLogger(logger), tracer: tracing.NewNamedTracer(tracerProvider, fmt.Sprintf("%s_consumer", topic)), consumedCounter: consumedCounter, - } + }, nil } // Consume reads messages and applies the handler to their payloads. @@ -129,13 +137,27 @@ func (p *consumerProvider) ProvideConsumer(ctx context.Context, topic string, ha return nil, ErrEmptyInputProvided } + p.consumerCacheMu.RLock() + if cachedPub, ok := p.consumerCache[topic]; ok { + p.consumerCacheMu.RUnlock() + return cachedPub, nil + } + p.consumerCacheMu.RUnlock() + + // Build the consumer outside the cache lock — provideRedisConsumer now + // does a network RTT waiting for SUBSCRIBE confirmation, and we don't + // want to serialize that behind the mutex. + c, err := provideRedisConsumer(ctx, logger, p.tracerProvider, p.metricsProvider, p.redisClient, topic, handlerFunc) + if err != nil { + return nil, err + } + p.consumerCacheMu.Lock() defer p.consumerCacheMu.Unlock() + // Re-check in case a concurrent caller beat us to it. if cachedPub, ok := p.consumerCache[topic]; ok { return cachedPub, nil } - - c := provideRedisConsumer(ctx, logger, p.tracerProvider, p.metricsProvider, p.redisClient, topic, handlerFunc) p.consumerCache[topic] = c return c, nil diff --git a/messagequeue/redis/consumer_test.go b/messagequeue/redis/consumer_test.go index 7c51021..7c702e6 100644 --- a/messagequeue/redis/consumer_test.go +++ b/messagequeue/redis/consumer_test.go @@ -60,8 +60,8 @@ func Test_redisConsumer_Consume(T *testing.T) { consumer := buildRedisBackedConsumer(t, cfg, t.Name(), hf) must.NotNil(t, consumer) - stopChan := make(chan bool) - errorsChan := make(chan error) + stopChan := make(chan bool, 1) + errorsChan := make(chan error, 1) go consumer.Consume(ctx, stopChan, errorsChan) publisher := buildRedisBackedPublisher(t, cfg, t.Name()) @@ -94,18 +94,25 @@ func Test_redisConsumer_Consume(T *testing.T) { consumer := buildRedisBackedConsumer(t, cfg, t.Name(), hf) must.NotNil(t, consumer) - stopChan := make(chan bool) - errorsChan := make(chan error) + stopChan := make(chan bool, 1) + errorsChan := make(chan error, 1) go consumer.Consume(ctx, stopChan, errorsChan) publisher := buildRedisBackedPublisher(t, cfg, t.Name()) must.NoError(t, publisher.Publish(ctx, []byte("blah"))) - receivedErr := <-errorsChan - test.Error(t, receivedErr) - test.ErrorIs(t, receivedErr, anticipatedError) + select { + case receivedErr := <-errorsChan: + test.Error(t, receivedErr) + test.ErrorIs(t, receivedErr, anticipatedError) + case <-time.After(10 * time.Second): + t.Fatal("timed out waiting for handler error on errorsChan") + } - stopChan <- true + select { + case stopChan <- true: + case <-time.After(time.Second): + } }) } @@ -115,17 +122,23 @@ func Test_consumerProvider_ProvideConsumer(T *testing.T) { T.Run("standard", func(t *testing.T) { t.Parallel() - logger := logging.NewNoopLogger() - cfg := Config{ - QueueAddresses: []string{t.Name()}, + ctx := t.Context() + + cfg, containerShutdown, err := BuildContainerBackedRedisConfigForTest(t) + if err != nil { + t.Skipf("Skipping test due to container setup failure: %v", err) } + defer func() { + if containerShutdown != nil { + test.NoError(t, containerShutdown(ctx)) + } + }() - conPro := ProvideRedisConsumerProvider(logger, tracing.NewNoopTracerProvider(), nil, cfg) + conPro := ProvideRedisConsumerProvider(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, *cfg) must.NotNil(t, conPro) - ctx := t.Context() - - actual, err := conPro.ProvideConsumer(ctx, t.Name(), nil) + hf := func(context.Context, []byte) error { return nil } + actual, err := conPro.ProvideConsumer(ctx, t.Name(), hf) test.NoError(t, err) test.NotNil(t, actual) }) @@ -133,23 +146,34 @@ func Test_consumerProvider_ProvideConsumer(T *testing.T) { T.Run("hitting cache", func(t *testing.T) { t.Parallel() - logger := logging.NewNoopLogger() - cfg := Config{ - QueueAddresses: []string{t.Name()}, + ctx := t.Context() + + cfg, containerShutdown, err := BuildContainerBackedRedisConfigForTest(t) + if err != nil { + t.Skipf("Skipping test due to container setup failure: %v", err) } + defer func() { + if containerShutdown != nil { + test.NoError(t, containerShutdown(ctx)) + } + }() - conPro := ProvideRedisConsumerProvider(logger, tracing.NewNoopTracerProvider(), nil, cfg) + conPro := ProvideRedisConsumerProvider(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), nil, *cfg) must.NotNil(t, conPro) - ctx := t.Context() + hf := func(context.Context, []byte) error { return nil } - actual, err := conPro.ProvideConsumer(ctx, t.Name(), nil) + first, err := conPro.ProvideConsumer(ctx, t.Name(), hf) test.NoError(t, err) - test.NotNil(t, actual) + must.NotNil(t, first) - actual, err = conPro.ProvideConsumer(ctx, t.Name(), nil) + second, err := conPro.ProvideConsumer(ctx, t.Name(), hf) test.NoError(t, err) - test.NotNil(t, actual) + must.NotNil(t, second) + + // Second call for the same topic must return the exact same instance + // from the cache — no second SUBSCRIBE round-trip. + test.True(t, first == second) }) T.Run("with empty topic", func(t *testing.T) { @@ -184,5 +208,6 @@ func Test_provideRedisConsumer(T *testing.T) { test.Panic(t, func() { provideRedisConsumer(t.Context(), logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, nil, "t", nil) }) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) } From ab78fbd0019c156ea30f993a6926992b45a29a43 Mon Sep 17 00:00:00 2001 From: Jeffrey D <1289344+verygoodsoftwarenotvirus@users.noreply.github.com> Date: Sat, 11 Apr 2026 21:19:52 -0500 Subject: [PATCH 10/12] deps: migrate go-redis from v8 to v9 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Swap `github.com/go-redis/redis/v8` for `github.com/redis/go-redis/v9` (the module moved from the go-redis org to the redis org). v8 has been unmaintained since early 2023; v9 is the current maintained line. Pure mechanical port — this codebase uses none of the APIs v9 broke: - No custom `redis.Hook` implementations (v9 replaced `BeforeProcess`/ `AfterProcess` with `DialHook`/`ProcessHook`/`ProcessPipelineHook`). - No `redisotel` instrumentation (all tracing/metrics is wrapped at the platform level via our own `tracing.Tracer` and `metrics.Provider`). - No `WithContext()` calls (we already pass ctx as the first arg everywhere, which is the v9-mandatory style). - No hook-based command interception. Every go-redis API we actually touch is stable between v8.11.5 and v9.18.0: `Options`/`ClusterOptions` fields, `NewClient`/`NewClusterClient`, `Get`/`Set`/`Del`/`Ping`/`SetNX`/`Eval`/`Publish`/`Subscribe`/`Close`, `*Cmd`/`*BoolCmd`/`*IntCmd`/`*StatusCmd`/`*StringCmd` return types and their `.Result()`/`.Int64()`/`.Err()` accessors, `*PubSub.Receive`/ `.Channel`, `ChannelOption`, `redis.Nil` sentinel, `redis.ErrClosed`, and the mock command constructors `NewCmd`/`NewBoolCmd`/`NewIntCmd`/ `NewStatusCmd`/`NewStringCmd` plus their `SetVal`/`SetErr` methods. Touches four production packages — `messagequeue/redis`, `cache/redis`, `ratelimiting/redis`, `distributedlock/redis` — plus their tests. Also picks up a handful of small in-flight test assertions added to `publisher_test.go` as part of the ongoing testing-refactor work on this branch; they were mixed in the working tree. Verified: - `go build ./...` clean. - `go vet ./messagequeue/redis/... ./cache/redis/... ./ratelimiting/redis/... ./distributedlock/redis/...` clean. - Full test suite green (race, shuffle, failfast). - Stress test: 5x Test_redisConsumer_Consume passes on v9. Co-Authored-By: Claude Opus 4.6 (1M context) --- cache/redis/redis.go | 2 +- cache/redis/redis_test.go | 2 +- cache/redis/redisclient_mock_test.go | 2 +- distributedlock/redis/redis.go | 2 +- distributedlock/redis/redis_test.go | 2 +- go.mod | 3 ++- go.sum | 24 ++++++++++++------------ messagequeue/redis/consumer.go | 2 +- messagequeue/redis/publisher.go | 2 +- messagequeue/redis/publisher_test.go | 6 +++++- ratelimiting/redis/redis.go | 2 +- ratelimiting/redis/redis_test.go | 2 +- 12 files changed, 28 insertions(+), 23 deletions(-) diff --git a/cache/redis/redis.go b/cache/redis/redis.go index e498259..b6c9c31 100644 --- a/cache/redis/redis.go +++ b/cache/redis/redis.go @@ -16,7 +16,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/go-redis/redis/v8" + "github.com/redis/go-redis/v9" ) const name = "redis_cache" diff --git a/cache/redis/redis_test.go b/cache/redis/redis_test.go index d06f6b8..971f037 100644 --- a/cache/redis/redis_test.go +++ b/cache/redis/redis_test.go @@ -16,7 +16,7 @@ import ( mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/go-redis/redis/v8" + "github.com/redis/go-redis/v9" "github.com/shoenig/test" "github.com/shoenig/test/must" rediscontainers "github.com/testcontainers/testcontainers-go/modules/redis" diff --git a/cache/redis/redisclient_mock_test.go b/cache/redis/redisclient_mock_test.go index 005a6d8..2572469 100644 --- a/cache/redis/redisclient_mock_test.go +++ b/cache/redis/redisclient_mock_test.go @@ -8,7 +8,7 @@ import ( "sync" "time" - "github.com/go-redis/redis/v8" + "github.com/redis/go-redis/v9" ) // Ensure, that redisClientMock does implement redisClient. diff --git a/distributedlock/redis/redis.go b/distributedlock/redis/redis.go index 2add284..d7e93a1 100644 --- a/distributedlock/redis/redis.go +++ b/distributedlock/redis/redis.go @@ -15,7 +15,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/go-redis/redis/v8" + "github.com/redis/go-redis/v9" ) const serviceName = "redis_distributed_lock" diff --git a/distributedlock/redis/redis_test.go b/distributedlock/redis/redis_test.go index f04c728..5007607 100644 --- a/distributedlock/redis/redis_test.go +++ b/distributedlock/redis/redis_test.go @@ -17,7 +17,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/go-redis/redis/v8" + "github.com/redis/go-redis/v9" "github.com/shoenig/test" "github.com/shoenig/test/must" rediscontainers "github.com/testcontainers/testcontainers-go/modules/redis" diff --git a/go.mod b/go.mod index 2798d7c..24d9723 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,6 @@ require ( github.com/go-chi/cors v1.2.2 github.com/go-faker/faker/v4 v4.7.0 github.com/go-ozzo/ozzo-validation/v4 v4.3.0 - github.com/go-redis/redis/v8 v8.11.5 github.com/go-sql-driver/mysql v1.9.3 github.com/gogo/protobuf v1.3.2 github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 @@ -50,6 +49,7 @@ require ( github.com/open-feature/go-sdk-contrib/providers/launchdarkly v0.1.6 github.com/posthog/posthog-go v1.11.1 github.com/pusher/pusher-http-go/v5 v5.1.1 + github.com/redis/go-redis/v9 v9.18.0 github.com/resend/resend-go/v3 v3.2.0 github.com/riandyrn/otelchi v0.12.2 github.com/rs/xid v1.6.0 @@ -120,6 +120,7 @@ require ( github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/stretchr/testify v1.11.1 // indirect github.com/x448/float16 v0.8.4 // indirect + go.uber.org/atomic v1.11.0 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect golang.org/x/term v0.41.0 // indirect gopkg.in/evanphx/json-patch.v4 v4.13.0 // indirect diff --git a/go.sum b/go.sum index 8eef85c..2f88dcc 100644 --- a/go.sum +++ b/go.sum @@ -145,6 +145,10 @@ github.com/boombuler/barcode v1.1.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl github.com/bradfitz/gomemcache v0.0.0-20170208213004-1952afaa557d/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= github.com/brianvoe/gofakeit/v7 v7.14.1 h1:a7fe3fonbj0cW3wgl5VwIKfZtiH9C3cLnwcIXWT7sow= github.com/brianvoe/gofakeit/v7 v7.14.1/go.mod h1:QXuPeBw164PJCzCUZVmgpgHJ3Llj49jSLVkKPMtxtxA= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cenk/backoff v2.2.1+incompatible h1:djdFT7f4gF2ttuzRKPbMOWgZajgesItGLwG5FTQKmmE= github.com/cenk/backoff v2.2.1+incompatible/go.mod h1:7FtoeaSnHoZnmZzz47cM35Y9nSW7tNyaidugnHTaFDE= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= @@ -233,8 +237,6 @@ github.com/felixge/httpsnoop v1.0.0/go.mod h1:3+D9sFq0ahK/JeJPhCBUV1xlf4/eIYrUQa github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.4.3-0.20170329110642-4da3e2cfbabc/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= -github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM= github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= github.com/garyburd/redigo v1.1.1-0.20170914051019-70e1b1943d4f/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY= @@ -280,8 +282,6 @@ github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD87 github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-playground/validator/v10 v10.2.0 h1:KgJ0snyC2R9VXYN2rneOtQcw5aHQB1Vv0sFl1UcHBOY= github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= -github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= -github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-stack/stack v1.6.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= @@ -431,6 +431,8 @@ github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCy github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= +github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg= +github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= @@ -545,8 +547,6 @@ github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOF github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= -github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= -github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 h1:zrbMGy9YXpIeTnGj4EljqMiZsIcE09mmF8XsD5AYOJc= github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6/go.mod h1:rEKTHC9roVVicUIfZK7DYrdIoM0EOr8mK1Hj5s3JjH0= github.com/olekukonko/errors v1.2.0 h1:10Zcn4GeV59t/EGqJc8fUjtFT/FuUh5bTMzZ1XwmCRo= @@ -556,8 +556,6 @@ github.com/olekukonko/ll v0.1.7/go.mod h1:RPRC6UcscfFZgjo1nulkfMH5IM0QAYim0LfnMv github.com/olekukonko/tablewriter v0.0.1/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= github.com/olekukonko/tablewriter v1.1.4 h1:ORUMI3dXbMnRlRggJX3+q7OzQFDdvgbN9nVWj1drm6I= github.com/olekukonko/tablewriter v1.1.4/go.mod h1:+kedxuyTtgoZLwif3P1Em4hARJs+mVnzKxmsCL/C5RY= -github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= -github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/ginkgo/v2 v2.27.2 h1:LzwLj0b89qtIy6SSASkzlNvX6WktqurSHwkk2ipF/Ns= github.com/onsi/ginkgo/v2 v2.27.2/go.mod h1:ArE1D/XhNXBXCBkKOLkbsb2c81dQHCRcF5zwn/ykDRo= github.com/onsi/gomega v1.38.2 h1:eZCjf2xjZAqe+LeWvKb5weQ+NcPwX84kqJ0cZNxok2A= @@ -596,8 +594,8 @@ github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:Om github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/pusher/pusher-http-go/v5 v5.1.1 h1:ZLUGdLA8yXMvByafIkS47nvuXOHrYmlh4bsQvuZnYVQ= github.com/pusher/pusher-http-go/v5 v5.1.1/go.mod h1:Ibji4SGoUDtOy7CVRhCiEpgy+n5Xv6hSL/QqYOhmWW8= -github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= -github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= +github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs= +github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/resend/resend-go/v3 v3.2.0 h1:jChLDFSLxKewNf6JEkxUyp/sJbaHBqd/NQfxCdXuVJk= @@ -732,6 +730,8 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= +github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= go.einride.tech/aip v0.83.0 h1:TI21IdeOnLTwZEJ3BxtImIZk6bsN2Q+sd0x99SLiQ+M= go.einride.tech/aip v0.83.0/go.mod h1:E8+wdTApA70odnpFzJgsGogHozC2JCIhFJBKPr8bVig= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= @@ -780,6 +780,8 @@ go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4Len go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc= go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g= go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= @@ -998,8 +1000,6 @@ gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/stretchr/testify.v1 v1.2.2 h1:yhQC6Uy5CqibAIlk1wlusa/MJ3iAN49/BsR/dCCKz3M= gopkg.in/stretchr/testify.v1 v1.2.2/go.mod h1:QI5V/q6UbPmuhtm10CaFZxED9NreB8PnFYN9JcR6TxU= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/messagequeue/redis/consumer.go b/messagequeue/redis/consumer.go index 13cc280..b2e655b 100644 --- a/messagequeue/redis/consumer.go +++ b/messagequeue/redis/consumer.go @@ -12,7 +12,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/go-redis/redis/v8" + "github.com/redis/go-redis/v9" ) type ( diff --git a/messagequeue/redis/publisher.go b/messagequeue/redis/publisher.go index 52b4fed..2cd7a0f 100644 --- a/messagequeue/redis/publisher.go +++ b/messagequeue/redis/publisher.go @@ -17,7 +17,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/go-redis/redis/v8" + "github.com/redis/go-redis/v9" ) var ( diff --git a/messagequeue/redis/publisher_test.go b/messagequeue/redis/publisher_test.go index 5608323..3dbbac7 100644 --- a/messagequeue/redis/publisher_test.go +++ b/messagequeue/redis/publisher_test.go @@ -13,7 +13,7 @@ import ( mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" - "github.com/go-redis/redis/v8" + "github.com/redis/go-redis/v9" "github.com/shoenig/test" "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" @@ -306,6 +306,7 @@ func Test_provideRedisPublisher(T *testing.T) { test.Panic(t, func() { provideRedisPublisher(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, nil, "t") }) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("panics when second NewInt64Counter fails", func(t *testing.T) { @@ -327,6 +328,7 @@ func Test_provideRedisPublisher(T *testing.T) { test.Panic(t, func() { provideRedisPublisher(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, nil, "t") }) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) T.Run("panics when NewFloat64Histogram fails", func(t *testing.T) { @@ -344,5 +346,7 @@ func Test_provideRedisPublisher(T *testing.T) { test.Panic(t, func() { provideRedisPublisher(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, nil, "t") }) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) + test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) }) } diff --git a/ratelimiting/redis/redis.go b/ratelimiting/redis/redis.go index cd64c67..2e123bd 100644 --- a/ratelimiting/redis/redis.go +++ b/ratelimiting/redis/redis.go @@ -10,7 +10,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/ratelimiting" validation "github.com/go-ozzo/ozzo-validation/v4" - "github.com/go-redis/redis/v8" + "github.com/redis/go-redis/v9" ) // Config configures a Redis-backed rate limiter. diff --git a/ratelimiting/redis/redis_test.go b/ratelimiting/redis/redis_test.go index e4755ba..fa5d99a 100644 --- a/ratelimiting/redis/redis_test.go +++ b/ratelimiting/redis/redis_test.go @@ -8,7 +8,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" - "github.com/go-redis/redis/v8" + "github.com/redis/go-redis/v9" "github.com/shoenig/test" "github.com/shoenig/test/must" "go.opentelemetry.io/otel/metric" From 2d5c171dbc141311535b121b4b3cd358ef8bfe5b Mon Sep 17 00:00:00 2001 From: Jeffrey D <1289344+verygoodsoftwarenotvirus@users.noreply.github.com> Date: Sat, 11 Apr 2026 22:25:28 -0500 Subject: [PATCH 11/12] fix: myriad test issues --- .scripts/coverage.sh | 3 + .scripts/pull_test_containers.sh | 76 +++++++++++++++++++ .scripts/test.sh | 6 +- .../tableaccess/access_manager_test.go | 2 +- distributedlock/postgres/postgres_test.go | 2 +- encoding/client_encoder_test.go | 3 +- encoding/io_writer_mock_test.go | 74 ++++++++++++++++++ encoding/mock_io_writer_test.go | 17 ----- encoding/writer_mock_gen.go | 11 +++ .../launchdarkly/feature_flag_manager_test.go | 20 +++++ .../posthog/feature_flag_manager_test.go | 5 ++ messagequeue/pubsub/consumer_test.go | 1 + messagequeue/pubsub/publisher_test.go | 4 + messagequeue/sqs/consumer_test.go | 1 + messagequeue/sqs/publisher_test.go | 4 + search/text/algolia/index_test.go | 25 ++++++ .../text/elasticsearch/elasticsearch_test.go | 62 +++++++++++++++ search/text/elasticsearch/index_test.go | 27 +++++++ search/vector/pgvector/pgvector_test.go | 9 ++- uploads/objectstorage/files_test.go | 6 ++ 20 files changed, 336 insertions(+), 22 deletions(-) create mode 100755 .scripts/pull_test_containers.sh create mode 100644 encoding/io_writer_mock_test.go delete mode 100644 encoding/mock_io_writer_test.go create mode 100644 encoding/writer_mock_gen.go diff --git a/.scripts/coverage.sh b/.scripts/coverage.sh index 0e89bd7..e29f31a 100755 --- a/.scripts/coverage.sh +++ b/.scripts/coverage.sh @@ -6,5 +6,8 @@ set -euo pipefail OUTPUT_FILE="${1:-coverage.out}" +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +RUN_CONTAINER_TESTS="${RUN_CONTAINER_TESTS:-true}" "${SCRIPT_DIR}/pull_test_containers.sh" + # shellcheck disable=SC2086,SC2046 CGO_ENABLED=1 go test -shuffle=on -race -vet=all -failfast -covermode=atomic -coverprofile="${OUTPUT_FILE}" $(go list github.com/verygoodsoftwarenotvirus/platform/... | grep -Ev '(mock|testutils)') diff --git a/.scripts/pull_test_containers.sh b/.scripts/pull_test_containers.sh new file mode 100755 index 0000000..812788b --- /dev/null +++ b/.scripts/pull_test_containers.sh @@ -0,0 +1,76 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Pre-pull every Docker image used by testcontainers-go across the test suite. +# Pulling up front (in parallel) keeps first-run image-pull time from eating +# into per-test wait-strategy deadlines and lets subsequent runs hit a warm +# cache. +# +# Skipped when RUN_CONTAINER_TESTS is not "true" (nothing to pre-pull if the +# container tests aren't going to run) or when docker isn't on PATH. +# +# Usage: pull_test_containers.sh + +RUN_CONTAINER_TESTS="${RUN_CONTAINER_TESTS:-true}" +RUN_CONTAINER_TESTS_LC="$(printf '%s' "${RUN_CONTAINER_TESTS}" | tr '[:upper:]' '[:lower:]')" + +if [[ "${RUN_CONTAINER_TESTS_LC}" != "true" ]]; then + echo "pull_test_containers: RUN_CONTAINER_TESTS != true, skipping" + exit 0 +fi + +if ! command -v docker &>/dev/null; then + echo "pull_test_containers: docker not found on PATH, skipping" + exit 0 +fi + +# Keep this list in sync with image literals passed to testcontainers-go +# *.Run / ContainerRequest{Image:...} in *_test.go files. Pulling +# "redis:7-bullseye" also warms "docker.io/redis:7-bullseye" since they +# resolve to the same manifest. +IMAGES=( + "postgres:17-alpine" + "mariadb:11" + "redis:7-bullseye" + "pgvector/pgvector:pg17" + "qdrant/qdrant:v1.13.0" + "gcr.io/google.com/cloudsdktool/cloud-sdk:emulators" +) + +# elasticsearch:8.x crashes with SIGILL inside its bundled JDK on linux/arm64 +# under Docker Desktop, so TestElasticsearch_Container skips itself on arm64. +# Only pre-pull the image on hosts that will actually run the test. +ARCH="$(uname -m)" +if [[ "${ARCH}" != "arm64" && "${ARCH}" != "aarch64" ]]; then + IMAGES+=("elasticsearch:8.10.2") +fi + +echo "pull_test_containers: pulling ${#IMAGES[@]} images in parallel" + +pids=() +for img in "${IMAGES[@]}"; do + ( + if out=$(docker pull --quiet "$img" 2>&1); then + echo " ok $img" + else + echo " err $img" + echo " ${out//$'\n'/$'\n '}" + exit 1 + fi + ) & + pids+=($!) +done + +failed=0 +for pid in "${pids[@]}"; do + if ! wait "$pid"; then + failed=$((failed + 1)) + fi +done + +if (( failed > 0 )); then + echo "pull_test_containers: $failed image(s) failed to pull" >&2 + exit 1 +fi + +echo "pull_test_containers: done" diff --git a/.scripts/test.sh b/.scripts/test.sh index 05028ff..e40e84d 100755 --- a/.scripts/test.sh +++ b/.scripts/test.sh @@ -3,6 +3,10 @@ set -euo pipefail # Run tests # Usage: test.sh +RUN_CONTAINER_TESTS="${1:-true}" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +RUN_CONTAINER_TESTS="${RUN_CONTAINER_TESTS}" "${SCRIPT_DIR}/pull_test_containers.sh" # shellcheck disable=SC2086,SC2046 -CGO_ENABLED=1 go test -shuffle=on -race -vet=all -failfast $(go list github.com/verygoodsoftwarenotvirus/platform/... | grep -Ev '(cmd|integration|mock|fakes|converters|utils|generated)') +CGO_ENABLED=1 RUN_CONTAINER_TESTS=${RUN_CONTAINER_TESTS} go test -shuffle=on -race -vet=all -failfast $(go list github.com/verygoodsoftwarenotvirus/platform/... | grep -Ev '(cmd|integration|mock|fakes|converters|utils|generated)') diff --git a/database/postgres/tableaccess/access_manager_test.go b/database/postgres/tableaccess/access_manager_test.go index 80f7b25..5af3d4f 100644 --- a/database/postgres/tableaccess/access_manager_test.go +++ b/database/postgres/tableaccess/access_manager_test.go @@ -23,7 +23,7 @@ import ( // TODO: lots of duplication with the upper postgres package const ( - defaultPostgresImage = "postgres:17" + defaultPostgresImage = "postgres:17-alpine" ) func reverseString(input string) string { diff --git a/distributedlock/postgres/postgres_test.go b/distributedlock/postgres/postgres_test.go index 1d90a3e..4e5eb38 100644 --- a/distributedlock/postgres/postgres_test.go +++ b/distributedlock/postgres/postgres_test.go @@ -27,7 +27,7 @@ import ( "go.opentelemetry.io/otel/metric" ) -const postgresImage = "postgres:16-alpine" +const postgresImage = "postgres:17-alpine" var runningContainerTests = strings.ToLower(os.Getenv("RUN_CONTAINER_TESTS")) == "true" diff --git a/encoding/client_encoder_test.go b/encoding/client_encoder_test.go index 6759c2a..09264a8 100644 --- a/encoding/client_encoder_test.go +++ b/encoding/client_encoder_test.go @@ -106,13 +106,14 @@ func Test_clientEncoder_Encode(T *testing.T) { ctx := t.Context() e := ProvideClientEncoder(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), ct) - mw := &mockWriter{ + mw := &ioWriterMock{ WriteFunc: func(_ []byte) (int, error) { return 0, errors.New("blah") }, } test.Error(t, e.Encode(ctx, mw, &example{Name: t.Name()})) + test.SliceLen(t, 1, mw.WriteCalls()) }) } diff --git a/encoding/io_writer_mock_test.go b/encoding/io_writer_mock_test.go new file mode 100644 index 0000000..9e5f4c9 --- /dev/null +++ b/encoding/io_writer_mock_test.go @@ -0,0 +1,74 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package encoding + +import ( + "sync" +) + +// Ensure, that ioWriterMock does implement ioWriter. +// If this is not the case, regenerate this file with moq. +var _ ioWriter = &ioWriterMock{} + +// ioWriterMock is a mock implementation of ioWriter. +// +// func TestSomethingThatUsesioWriter(t *testing.T) { +// +// // make and configure a mocked ioWriter +// mockedioWriter := &ioWriterMock{ +// WriteFunc: func(p []byte) (int, error) { +// panic("mock out the Write method") +// }, +// } +// +// // use mockedioWriter in code that requires ioWriter +// // and then make assertions. +// +// } +type ioWriterMock struct { + // WriteFunc mocks the Write method. + WriteFunc func(p []byte) (int, error) + + // calls tracks calls to the methods. + calls struct { + // Write holds details about calls to the Write method. + Write []struct { + // P is the p argument value. + P []byte + } + } + lockWrite sync.RWMutex +} + +// Write calls WriteFunc. +func (mock *ioWriterMock) Write(p []byte) (int, error) { + if mock.WriteFunc == nil { + panic("ioWriterMock.WriteFunc: method is nil but ioWriter.Write was just called") + } + callInfo := struct { + P []byte + }{ + P: p, + } + mock.lockWrite.Lock() + mock.calls.Write = append(mock.calls.Write, callInfo) + mock.lockWrite.Unlock() + return mock.WriteFunc(p) +} + +// WriteCalls gets all the calls that were made to Write. +// Check the length with: +// +// len(mockedioWriter.WriteCalls()) +func (mock *ioWriterMock) WriteCalls() []struct { + P []byte +} { + var calls []struct { + P []byte + } + mock.lockWrite.RLock() + calls = mock.calls.Write + mock.lockWrite.RUnlock() + return calls +} diff --git a/encoding/mock_io_writer_test.go b/encoding/mock_io_writer_test.go deleted file mode 100644 index 8fc6d39..0000000 --- a/encoding/mock_io_writer_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package encoding - -import ( - "io" -) - -var _ io.Writer = (*mockWriter)(nil) - -// mockWriter mocks an io.Writer. -type mockWriter struct { - WriteFunc func(p []byte) (int, error) -} - -// Write implements the io.Writer interface. -func (m *mockWriter) Write(p []byte) (int, error) { - return m.WriteFunc(p) -} diff --git a/encoding/writer_mock_gen.go b/encoding/writer_mock_gen.go new file mode 100644 index 0000000..3b30e01 --- /dev/null +++ b/encoding/writer_mock_gen.go @@ -0,0 +1,11 @@ +package encoding + +// ioWriter is a moq-friendly mirror of io.Writer. moq cannot generate mocks +// from stdlib interfaces directly, so we define a structurally-identical +// interface here purely so tests can mock Write calls. Any io.Writer satisfies +// this interface (and vice versa) via Go's structural typing. +type ioWriter interface { + Write(p []byte) (int, error) +} + +//go:generate go tool github.com/matryer/moq -out io_writer_mock_test.go -pkg encoding -rm -fmt goimports . ioWriter diff --git a/featureflags/launchdarkly/feature_flag_manager_test.go b/featureflags/launchdarkly/feature_flag_manager_test.go index 14a5862..79c1712 100644 --- a/featureflags/launchdarkly/feature_flag_manager_test.go +++ b/featureflags/launchdarkly/feature_flag_manager_test.go @@ -279,6 +279,9 @@ func TestFeatureFlagManager_CanUseFeature(T *testing.T) { result, err := ffm.CanUseFeature(ctx, "nonexistent-flag", evalCtx("user123")) test.Error(t, err) test.False(t, result) + test.SliceLen(t, 1, cb.CanProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) + test.SliceLen(t, 0, cb.SucceededCalls()) }) T.Run("with broken circuit", func(t *testing.T) { @@ -294,6 +297,7 @@ func TestFeatureFlagManager_CanUseFeature(T *testing.T) { result, err := ffm.CanUseFeature(ctx, "some-flag", evalCtx("user123")) test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) test.False(t, result) + test.SliceLen(t, 1, cb.CanProceedCalls()) }) } @@ -326,6 +330,9 @@ func TestFeatureFlagManager_GetStringValue(T *testing.T) { result, err := ffm.GetStringValue(ctx, "nonexistent-flag", "fallback", evalCtx("user123")) test.Error(t, err) test.EqOp(t, "fallback", result) + test.SliceLen(t, 1, cb.CanProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) + test.SliceLen(t, 0, cb.SucceededCalls()) }) T.Run("with broken circuit", func(t *testing.T) { @@ -341,6 +348,7 @@ func TestFeatureFlagManager_GetStringValue(T *testing.T) { result, err := ffm.GetStringValue(ctx, "some-flag", "fallback", evalCtx("user123")) test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) test.EqOp(t, "fallback", result) + test.SliceLen(t, 1, cb.CanProceedCalls()) }) } @@ -373,6 +381,9 @@ func TestFeatureFlagManager_GetInt64Value(T *testing.T) { result, err := ffm.GetInt64Value(ctx, "nonexistent-flag", int64(42), evalCtx("user123")) test.Error(t, err) test.EqOp(t, int64(42), result) + test.SliceLen(t, 1, cb.CanProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) + test.SliceLen(t, 0, cb.SucceededCalls()) }) T.Run("with broken circuit", func(t *testing.T) { @@ -388,6 +399,7 @@ func TestFeatureFlagManager_GetInt64Value(T *testing.T) { result, err := ffm.GetInt64Value(ctx, "some-flag", int64(42), evalCtx("user123")) test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) test.EqOp(t, int64(42), result) + test.SliceLen(t, 1, cb.CanProceedCalls()) }) } @@ -420,6 +432,9 @@ func TestFeatureFlagManager_GetFloat64Value(T *testing.T) { result, err := ffm.GetFloat64Value(ctx, "nonexistent-flag", 3.14, evalCtx("user123")) test.Error(t, err) test.InDelta(t, 3.14, result, 1e-9) + test.SliceLen(t, 1, cb.CanProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) + test.SliceLen(t, 0, cb.SucceededCalls()) }) T.Run("with broken circuit", func(t *testing.T) { @@ -435,6 +450,7 @@ func TestFeatureFlagManager_GetFloat64Value(T *testing.T) { result, err := ffm.GetFloat64Value(ctx, "some-flag", 3.14, evalCtx("user123")) test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) test.InDelta(t, 3.14, result, 1e-9) + test.SliceLen(t, 1, cb.CanProceedCalls()) }) } @@ -469,6 +485,9 @@ func TestFeatureFlagManager_GetObjectValue(T *testing.T) { result, err := ffm.GetObjectValue(ctx, "nonexistent-flag", def, evalCtx("user123")) test.Error(t, err) test.Eq[any](t, def, result) + test.SliceLen(t, 1, cb.CanProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) + test.SliceLen(t, 0, cb.SucceededCalls()) }) T.Run("with broken circuit", func(t *testing.T) { @@ -485,6 +504,7 @@ func TestFeatureFlagManager_GetObjectValue(T *testing.T) { result, err := ffm.GetObjectValue(ctx, "some-flag", def, evalCtx("user123")) test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) test.Eq[any](t, def, result) + test.SliceLen(t, 1, cb.CanProceedCalls()) }) } diff --git a/featureflags/posthog/feature_flag_manager_test.go b/featureflags/posthog/feature_flag_manager_test.go index 0bc272c..d66a769 100644 --- a/featureflags/posthog/feature_flag_manager_test.go +++ b/featureflags/posthog/feature_flag_manager_test.go @@ -265,6 +265,7 @@ func TestFeatureFlagManager_CanUseFeature(T *testing.T) { result, err := ffm.CanUseFeature(ctx, "some-flag", evalCtx("user123")) test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) test.False(t, result) + test.SliceLen(t, 1, cb.CanProceedCalls()) }) } @@ -306,6 +307,7 @@ func TestFeatureFlagManager_GetStringValue(T *testing.T) { result, err := ffm.GetStringValue(ctx, "some-flag", "fallback", evalCtx("user123")) test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) test.EqOp(t, "fallback", result) + test.SliceLen(t, 1, cb.CanProceedCalls()) }) } @@ -347,6 +349,7 @@ func TestFeatureFlagManager_GetInt64Value(T *testing.T) { result, err := ffm.GetInt64Value(ctx, "some-flag", int64(42), evalCtx("user123")) test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) test.EqOp(t, int64(42), result) + test.SliceLen(t, 1, cb.CanProceedCalls()) }) } @@ -388,6 +391,7 @@ func TestFeatureFlagManager_GetFloat64Value(T *testing.T) { result, err := ffm.GetFloat64Value(ctx, "some-flag", 3.14, evalCtx("user123")) test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) test.InDelta(t, 3.14, result, 1e-9) + test.SliceLen(t, 1, cb.CanProceedCalls()) }) } @@ -432,6 +436,7 @@ func TestFeatureFlagManager_GetObjectValue(T *testing.T) { result, err := ffm.GetObjectValue(ctx, "some-flag", def, evalCtx("user123")) test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) test.Eq[any](t, def, result) + test.SliceLen(t, 1, cb.CanProceedCalls()) }) } diff --git a/messagequeue/pubsub/consumer_test.go b/messagequeue/pubsub/consumer_test.go index c2ad974..899f01a 100644 --- a/messagequeue/pubsub/consumer_test.go +++ b/messagequeue/pubsub/consumer_test.go @@ -142,6 +142,7 @@ func TestBuildPubSubConsumer(T *testing.T) { test.Panic(t, func() { buildPubSubConsumer(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, nil, "t", nil) }) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) } diff --git a/messagequeue/pubsub/publisher_test.go b/messagequeue/pubsub/publisher_test.go index 1dabcef..454d230 100644 --- a/messagequeue/pubsub/publisher_test.go +++ b/messagequeue/pubsub/publisher_test.go @@ -42,6 +42,7 @@ func TestBuildPubSubPublisher(T *testing.T) { test.Panic(t, func() { buildPubSubPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), mp, "t") }) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("panics when second NewInt64Counter fails", func(t *testing.T) { @@ -63,6 +64,7 @@ func TestBuildPubSubPublisher(T *testing.T) { test.Panic(t, func() { buildPubSubPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), mp, "t") }) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) T.Run("panics when NewFloat64Histogram fails", func(t *testing.T) { @@ -80,6 +82,8 @@ func TestBuildPubSubPublisher(T *testing.T) { test.Panic(t, func() { buildPubSubPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), mp, "t") }) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) + test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) }) } diff --git a/messagequeue/sqs/consumer_test.go b/messagequeue/sqs/consumer_test.go index 9cd2522..0f8bb0d 100644 --- a/messagequeue/sqs/consumer_test.go +++ b/messagequeue/sqs/consumer_test.go @@ -233,5 +233,6 @@ func Test_provideSQSConsumer(T *testing.T) { test.Panic(t, func() { provideSQSConsumer(logging.NewNoopLogger(), tracing.NewNoopTracerProvider(), mp, nil, "t", nil) }) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) } diff --git a/messagequeue/sqs/publisher_test.go b/messagequeue/sqs/publisher_test.go index 169975a..cf500f2 100644 --- a/messagequeue/sqs/publisher_test.go +++ b/messagequeue/sqs/publisher_test.go @@ -281,6 +281,7 @@ func Test_provideSQSPublisher(T *testing.T) { test.Panic(t, func() { provideSQSPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), mp, "t") }) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("panics when second NewInt64Counter fails", func(t *testing.T) { @@ -302,6 +303,7 @@ func Test_provideSQSPublisher(T *testing.T) { test.Panic(t, func() { provideSQSPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), mp, "t") }) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) T.Run("panics when NewFloat64Histogram fails", func(t *testing.T) { @@ -319,5 +321,7 @@ func Test_provideSQSPublisher(T *testing.T) { test.Panic(t, func() { provideSQSPublisher(logging.NewNoopLogger(), nil, tracing.NewNoopTracerProvider(), mp, "t") }) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) + test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) }) } diff --git a/search/text/algolia/index_test.go b/search/text/algolia/index_test.go index 1fb35d1..15ec988 100644 --- a/search/text/algolia/index_test.go +++ b/search/text/algolia/index_test.go @@ -124,6 +124,7 @@ func TestIndexManager_Index(T *testing.T) { err := im.Index(context.Background(), "id", map[string]string{"id": "test"}) test.Error(t, err) test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("with unmarshalable value", func(t *testing.T) { @@ -169,6 +170,7 @@ func TestIndexManager_Index(T *testing.T) { err := im.Index(context.Background(), "123", map[string]string{"id": "123", "name": "example"}) test.NoError(t, err) + test.SliceLen(t, 1, cb.CannotProceedCalls()) }) } @@ -188,6 +190,7 @@ func TestIndexManager_Search(T *testing.T) { test.Error(t, err) test.Nil(t, results) test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("with empty query", func(t *testing.T) { @@ -203,6 +206,7 @@ func TestIndexManager_Search(T *testing.T) { test.Error(t, err) test.Nil(t, results) test.ErrorIs(t, err, ErrEmptyQueryProvided) + test.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("with valid query but invalid credentials", func(t *testing.T) { @@ -218,6 +222,8 @@ func TestIndexManager_Search(T *testing.T) { results, err := im.Search(context.Background(), "test query") test.Error(t, err) test.Nil(t, results) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) }) T.Run("with successful search results", func(t *testing.T) { @@ -239,6 +245,8 @@ func TestIndexManager_Search(T *testing.T) { test.NoError(t, err) test.NotNil(t, results) test.SliceLen(t, 1, results) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) }) T.Run("with empty search results", func(t *testing.T) { @@ -260,6 +268,8 @@ func TestIndexManager_Search(T *testing.T) { test.NoError(t, err) test.NotNil(t, results) test.SliceEmpty(t, results) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) }) T.Run("with multiple search results", func(t *testing.T) { @@ -286,6 +296,8 @@ func TestIndexManager_Search(T *testing.T) { test.EqOp(t, "second", results[1].Name) test.EqOp(t, "ghi", results[2].ID) test.EqOp(t, "third", results[2].Name) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) }) T.Run("when unmarshalling search result fails", func(t *testing.T) { @@ -305,6 +317,7 @@ func TestIndexManager_Search(T *testing.T) { results, err := im.Search(context.Background(), "test query") test.Error(t, err) test.Nil(t, results) + test.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("with successful search results without objectID", func(t *testing.T) { @@ -326,6 +339,8 @@ func TestIndexManager_Search(T *testing.T) { test.NoError(t, err) test.NotNil(t, results) test.SliceLen(t, 1, results) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) }) } @@ -344,6 +359,7 @@ func TestIndexManager_Delete(T *testing.T) { err := im.Delete(context.Background(), "id") test.Error(t, err) test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("with invalid credentials", func(t *testing.T) { @@ -358,6 +374,8 @@ func TestIndexManager_Delete(T *testing.T) { err := im.Delete(context.Background(), "some-id") test.Error(t, err) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) }) T.Run("with successful deletion", func(t *testing.T) { @@ -377,6 +395,8 @@ func TestIndexManager_Delete(T *testing.T) { err := im.Delete(context.Background(), "some-id") test.NoError(t, err) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) }) } @@ -395,6 +415,7 @@ func TestIndexManager_Wipe(T *testing.T) { err := im.Wipe(context.Background()) test.Error(t, err) test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("with invalid credentials", func(t *testing.T) { @@ -409,6 +430,8 @@ func TestIndexManager_Wipe(T *testing.T) { err := im.Wipe(context.Background()) test.Error(t, err) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) }) T.Run("with successful wipe", func(t *testing.T) { @@ -428,5 +451,7 @@ func TestIndexManager_Wipe(T *testing.T) { err := im.Wipe(context.Background()) test.NoError(t, err) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) }) } diff --git a/search/text/elasticsearch/elasticsearch_test.go b/search/text/elasticsearch/elasticsearch_test.go index 7eb88ce..958e2e0 100644 --- a/search/text/elasticsearch/elasticsearch_test.go +++ b/search/text/elasticsearch/elasticsearch_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httptest" "os" + "runtime" "strings" "testing" "time" @@ -19,7 +20,9 @@ import ( "github.com/shoenig/test" "github.com/shoenig/test/must" + "github.com/testcontainers/testcontainers-go" elasticsearchcontainers "github.com/testcontainers/testcontainers-go/modules/elasticsearch" + "github.com/testcontainers/testcontainers-go/wait" ) var runningContainerTests = strings.ToLower(os.Getenv("RUN_CONTAINER_TESTS")) == "true" @@ -33,6 +36,39 @@ type esTestInfra struct { shutdown func(context.Context) error } +// extendWaitStrategyTimeout returns a PostCreates lifecycle hook that extends +// the timeouts of the elasticsearch module's bundled wait strategies. The +// module hard-codes a 60s deadline on its MultiStrategy and each inner +// FileStrategy/HTTPStrategy also defaults to 60s, which is too tight on a cold +// start (image pull + ES auto-config + cert generation) on a busy host. We +// have to mutate the strategies after creation because the module appends its +// configureWaitFor customizer after any user opts, leaving no other override +// hook. Failure to assert the expected types is loud rather than silent so a +// future testcontainers refactor surfaces immediately instead of regressing to +// a flaky 60s timeout. +func extendWaitStrategyTimeout(timeout time.Duration) testcontainers.ContainerHook { + return func(_ context.Context, c testcontainers.Container) error { + dc, ok := c.(*testcontainers.DockerContainer) + if !ok { + return fmt.Errorf("extendWaitStrategyTimeout: unexpected container type %T", c) + } + ms, ok := dc.WaitingFor.(*wait.MultiStrategy) + if !ok { + return fmt.Errorf("extendWaitStrategyTimeout: unexpected wait strategy type %T", dc.WaitingFor) + } + ms.WithDeadline(timeout) + for _, s := range ms.Strategies { + switch w := s.(type) { + case *wait.FileStrategy: + w.WithStartupTimeout(timeout) + case *wait.HTTPStrategy: + w.WithStartupTimeout(timeout) + } + } + return nil + } +} + func buildEsTestInfra(t *testing.T) *esTestInfra { t.Helper() @@ -40,6 +76,11 @@ func buildEsTestInfra(t *testing.T) *esTestInfra { t.Context(), "elasticsearch:8.10.2", elasticsearchcontainers.WithPassword("arbitraryPassword"), + testcontainers.WithAdditionalLifecycleHooks(testcontainers.ContainerLifecycleHooks{ + PostCreates: []testcontainers.ContainerHook{ + extendWaitStrategyTimeout(5 * time.Minute), + }, + }), ) must.NoError(t, err) must.NotNil(t, elasticsearchContainer) @@ -69,6 +110,14 @@ func TestElasticsearch_Container(T *testing.T) { T.SkipNow() } + // The elasticsearch:8.x images crash with SIGILL inside the bundled JDK + // when run under linux/arm64 on Docker Desktop for Mac, so the cert wait + // strategy times out and the suite flakes. Skip until ES ships a JDK + // that tolerates this host. + if runtime.GOARCH == "arm64" { + T.Skip("elasticsearch JDK crashes on linux/arm64 under Docker Desktop; skipping") + } + infra := buildEsTestInfra(T) T.Cleanup(func() { _ = infra.shutdown(context.Background()) }) @@ -336,6 +385,7 @@ func TestIndexManager_ensureIndices_CircuitBroken(T *testing.T) { err := im.ensureIndices(context.Background()) test.Error(t, err) test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("with unreachable server", func(t *testing.T) { @@ -350,6 +400,8 @@ func TestIndexManager_ensureIndices_CircuitBroken(T *testing.T) { err := im.ensureIndices(context.Background()) test.Error(t, err) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) }) } @@ -378,6 +430,8 @@ func TestIndexManager_ensureIndices_Unit(T *testing.T) { err := im.ensureIndices(context.Background()) test.NoError(t, err) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) }) T.Run("index does not exist and create succeeds", func(t *testing.T) { @@ -407,6 +461,8 @@ func TestIndexManager_ensureIndices_Unit(T *testing.T) { err := im.ensureIndices(context.Background()) test.NoError(t, err) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) }) T.Run("index does not exist and create fails", func(t *testing.T) { @@ -440,6 +496,8 @@ func TestIndexManager_ensureIndices_Unit(T *testing.T) { err := im.ensureIndices(context.Background()) test.Error(t, err) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) }) } @@ -555,6 +613,8 @@ func TestProvideIndexManager_Unit(T *testing.T) { im, err := ProvideIndexManager[example](context.Background(), logger, tracerProvider, cfg, "test", cb) test.NoError(t, err) test.NotNil(t, im) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) }) T.Run("fails when ensureIndices fails", func(t *testing.T) { @@ -605,5 +665,7 @@ func TestProvideIndexManager_Unit(T *testing.T) { im, err := ProvideIndexManager[example](context.Background(), logger, tracerProvider, cfg, "test", cb) test.Error(t, err) test.Nil(t, im) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) }) } diff --git a/search/text/elasticsearch/index_test.go b/search/text/elasticsearch/index_test.go index 7a0d620..7e783ef 100644 --- a/search/text/elasticsearch/index_test.go +++ b/search/text/elasticsearch/index_test.go @@ -79,6 +79,7 @@ func TestIndexManager_Index_CircuitBroken(T *testing.T) { err := im.Index(context.Background(), "id", map[string]string{"id": "test"}) test.Error(t, err) test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("with unmarshalable value", func(t *testing.T) { @@ -92,6 +93,7 @@ func TestIndexManager_Index_CircuitBroken(T *testing.T) { err := im.Index(context.Background(), "id", make(chan int)) test.Error(t, err) + test.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("with unreachable server", func(t *testing.T) { @@ -106,6 +108,8 @@ func TestIndexManager_Index_CircuitBroken(T *testing.T) { err := im.Index(context.Background(), "id", map[string]string{"id": "test"}) test.Error(t, err) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) }) } @@ -132,6 +136,8 @@ func TestIndexManager_Index_Unit(T *testing.T) { err := im.Index(context.Background(), "123", &example{ID: "123", Name: "test"}) test.NoError(t, err) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) }) T.Run("with non-success status code", func(t *testing.T) { @@ -154,6 +160,8 @@ func TestIndexManager_Index_Unit(T *testing.T) { err := im.Index(context.Background(), "123", &example{ID: "123", Name: "test"}) test.Error(t, err) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) }) } @@ -173,6 +181,7 @@ func TestIndexManager_Search_CircuitBroken(T *testing.T) { test.Error(t, err) test.Nil(t, results) test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("with empty query", func(t *testing.T) { @@ -188,6 +197,7 @@ func TestIndexManager_Search_CircuitBroken(T *testing.T) { test.Error(t, err) test.Nil(t, results) test.ErrorIs(t, err, ErrEmptyQueryProvided) + test.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("with unreachable server", func(t *testing.T) { @@ -203,6 +213,8 @@ func TestIndexManager_Search_CircuitBroken(T *testing.T) { results, err := im.Search(context.Background(), "test query") test.Error(t, err) test.Nil(t, results) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) }) } @@ -232,6 +244,8 @@ func TestIndexManager_Search_Unit(T *testing.T) { must.SliceLen(t, 1, results) test.EqOp(t, "123", results[0].ID) test.EqOp(t, "test", results[0].Name) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) }) T.Run("with error response", func(t *testing.T) { @@ -259,6 +273,8 @@ func TestIndexManager_Search_Unit(T *testing.T) { results, err := im.Search(context.Background(), "test") test.NoError(t, err) test.Nil(t, results) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) }) T.Run("with invalid JSON in success response", func(t *testing.T) { @@ -284,6 +300,8 @@ func TestIndexManager_Search_Unit(T *testing.T) { results, err := im.Search(context.Background(), "test") test.NoError(t, err) test.Nil(t, results) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) }) } @@ -313,6 +331,8 @@ func TestIndexManager_Search_ErrorResponseDecodeFailure_Unit(T *testing.T) { results, err := im.Search(context.Background(), "test") test.NoError(t, err) test.Nil(t, results) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) }) } @@ -342,6 +362,8 @@ func TestIndexManager_Search_SourceUnmarshalError_Unit(T *testing.T) { results, err := im.Search(context.Background(), "test") test.NoError(t, err) test.Nil(t, results) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) }) } @@ -360,6 +382,7 @@ func TestIndexManager_Delete_CircuitBroken(T *testing.T) { err := im.Delete(context.Background(), "id") test.Error(t, err) test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) + test.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("with unreachable server", func(t *testing.T) { @@ -374,6 +397,8 @@ func TestIndexManager_Delete_CircuitBroken(T *testing.T) { err := im.Delete(context.Background(), "some-id") test.Error(t, err) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) }) } @@ -400,6 +425,8 @@ func TestIndexManager_Delete_Unit(T *testing.T) { err := im.Delete(context.Background(), "123") test.NoError(t, err) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) }) } diff --git a/search/vector/pgvector/pgvector_test.go b/search/vector/pgvector/pgvector_test.go index 759a47f..8ad7393 100644 --- a/search/vector/pgvector/pgvector_test.go +++ b/search/vector/pgvector/pgvector_test.go @@ -48,7 +48,7 @@ func newCounterProviderMock(t *testing.T, results map[string]counterResult) *moc } } -const pgvectorImage = "pgvector/pgvector:pg16" +const pgvectorImage = "pgvector/pgvector:pg17" var runningContainerTests = strings.ToLower(os.Getenv("RUN_CONTAINER_TESTS")) == "true" @@ -168,6 +168,7 @@ func TestProvideIndex(T *testing.T) { _, err := ProvideIndex[doc](t.Context(), nil, nil, mp, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) must.Error(t, err) + test.SliceLen(t, 1, mp.NewInt64CounterCalls()) }) T.Run("error creating delete counter", func(t *testing.T) { @@ -180,6 +181,7 @@ func TestProvideIndex(T *testing.T) { _, err := ProvideIndex[doc](t.Context(), nil, nil, mp, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) must.Error(t, err) + test.SliceLen(t, 2, mp.NewInt64CounterCalls()) }) T.Run("error creating wipe counter", func(t *testing.T) { @@ -193,6 +195,7 @@ func TestProvideIndex(T *testing.T) { _, err := ProvideIndex[doc](t.Context(), nil, nil, mp, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) must.Error(t, err) + test.SliceLen(t, 3, mp.NewInt64CounterCalls()) }) T.Run("error creating query counter", func(t *testing.T) { @@ -207,6 +210,7 @@ func TestProvideIndex(T *testing.T) { _, err := ProvideIndex[doc](t.Context(), nil, nil, mp, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) must.Error(t, err) + test.SliceLen(t, 4, mp.NewInt64CounterCalls()) }) T.Run("error creating error counter", func(t *testing.T) { @@ -222,6 +226,7 @@ func TestProvideIndex(T *testing.T) { _, err := ProvideIndex[doc](t.Context(), nil, nil, mp, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) must.Error(t, err) + test.SliceLen(t, 5, mp.NewInt64CounterCalls()) }) T.Run("error creating latency histogram", func(t *testing.T) { @@ -238,6 +243,8 @@ func TestProvideIndex(T *testing.T) { _, err := ProvideIndex[doc](t.Context(), nil, nil, mp, &Config{Dimension: 3, Metric: vectorsearch.DistanceCosine}, &testDBClient{}, "idx", cbnoop.NewCircuitBreaker()) must.Error(t, err) + test.SliceLen(t, 5, mp.NewInt64CounterCalls()) + test.SliceLen(t, 1, mp.NewFloat64HistogramCalls()) }) } diff --git a/uploads/objectstorage/files_test.go b/uploads/objectstorage/files_test.go index 7181abc..9175c1c 100644 --- a/uploads/objectstorage/files_test.go +++ b/uploads/objectstorage/files_test.go @@ -117,6 +117,7 @@ func TestUploader_ReadFile(T *testing.T) { x, err := u.ReadFile(ctx, "anything.txt") test.ErrorIs(t, err, circuitbreaking.ErrCircuitBroken) test.Nil(t, x) + test.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("with mock circuit breaker on successful read", func(t *testing.T) { @@ -150,6 +151,8 @@ func TestUploader_ReadFile(T *testing.T) { x, err := u.ReadFile(ctx, exampleFilename) test.NoError(t, err) test.Eq(t, expectedContent, x) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.SucceededCalls()) }) } @@ -199,6 +202,7 @@ func TestUploader_SaveFile(T *testing.T) { } test.ErrorIs(t, u.SaveFile(ctx, "test_file.txt", []byte(t.Name())), circuitbreaking.ErrCircuitBroken) + test.SliceLen(t, 1, cb.CannotProceedCalls()) }) T.Run("with write error", func(t *testing.T) { @@ -228,6 +232,8 @@ func TestUploader_SaveFile(T *testing.T) { } test.Error(t, u.SaveFile(ctx, "test_file.txt", []byte(t.Name()))) + test.SliceLen(t, 1, cb.CannotProceedCalls()) + test.SliceLen(t, 1, cb.FailedCalls()) }) T.Run("can be read back after save", func(t *testing.T) { From 06efa546e050e2a3575676e6dd63806841c34c3a Mon Sep 17 00:00:00 2001 From: Jeffrey D <1289344+verygoodsoftwarenotvirus@users.noreply.github.com> Date: Sat, 11 Apr 2026 23:03:05 -0500 Subject: [PATCH 12/12] test: make container retries better --- cache/redis/redis_test.go | 29 +++--- .../mysql/tableaccess/access_manager_test.go | 10 +-- .../tableaccess/access_manager_test.go | 10 +-- distributedlock/postgres/postgres_test.go | 19 ++-- distributedlock/redis/redis_test.go | 20 ++++- messagequeue/pubsub/consumer_test.go | 15 ++-- messagequeue/redis/test_helpers_test.go | 19 ++-- .../text/elasticsearch/elasticsearch_test.go | 55 ++++++++---- search/vector/pgvector/pgvector_test.go | 19 ++-- search/vector/qdrant/qdrant_test.go | 5 +- testutils/containers/containers.go | 51 +++++++++++ testutils/containers/containers_test.go | 88 +++++++++++++++++++ 12 files changed, 264 insertions(+), 76 deletions(-) create mode 100644 testutils/containers/containers.go create mode 100644 testutils/containers/containers_test.go diff --git a/cache/redis/redis_test.go b/cache/redis/redis_test.go index 971f037..a977f62 100644 --- a/cache/redis/redis_test.go +++ b/cache/redis/redis_test.go @@ -15,11 +15,14 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" + "github.com/verygoodsoftwarenotvirus/platform/v5/testutils/containers" "github.com/redis/go-redis/v9" "github.com/shoenig/test" "github.com/shoenig/test/must" + "github.com/testcontainers/testcontainers-go" rediscontainers "github.com/testcontainers/testcontainers-go/modules/redis" + "github.com/testcontainers/testcontainers-go/wait" "go.opentelemetry.io/otel/metric" ) @@ -107,19 +110,23 @@ func newCounterProviderMock(t *testing.T, results map[string]counterResult) *moc func buildContainerBackedRedisConfig(t *testing.T) (config *Config, shutdownFunction func(context.Context) error) { t.Helper() - // Use a dedicated context that won't be cancelled for the container lifecycle containerCtx := t.Context() - redisContainer, err := rediscontainers.Run(containerCtx, - redisImage, - rediscontainers.WithLogLevel(rediscontainers.LogLevelNotice), - ) - if err != nil { - panic(err) - } - - // Wait a small amount to ensure container is fully ready - time.Sleep(100 * time.Millisecond) + // Explicitly wait for both the TCP port and the "Ready to accept connections" + // log line so we don't race the redis-server bootstrap. The module's default + // wait strategy is implementation-defined, so pin it here for predictability. + redisContainer, err := containers.StartWithRetry(containerCtx, func(ctx context.Context) (*rediscontainers.RedisContainer, error) { + return rediscontainers.Run(ctx, + redisImage, + rediscontainers.WithLogLevel(rediscontainers.LogLevelNotice), + testcontainers.WithWaitStrategyAndDeadline(2*time.Minute, wait.ForAll( + wait.ForListeningPort("6379/tcp"), + wait.ForLog("Ready to accept connections"), + )), + ) + }) + must.NoError(t, err) + must.NotNil(t, redisContainer) redisAddress, err := redisContainer.ConnectionString(containerCtx) must.NoError(t, err) diff --git a/database/mysql/tableaccess/access_manager_test.go b/database/mysql/tableaccess/access_manager_test.go index 9ed2cbb..9b99623 100644 --- a/database/mysql/tableaccess/access_manager_test.go +++ b/database/mysql/tableaccess/access_manager_test.go @@ -10,7 +10,7 @@ import ( "time" "github.com/verygoodsoftwarenotvirus/platform/v5/pointer" - "github.com/verygoodsoftwarenotvirus/platform/v5/retry" + "github.com/verygoodsoftwarenotvirus/platform/v5/testutils/containers" _ "github.com/go-sql-driver/mysql" "github.com/shoenig/test" @@ -66,11 +66,8 @@ func buildDatabaseConnectionForTest(t *testing.T, ctx context.Context) (*sql.DB, dbPassword := reverseString(dbUsername) dbName := splitReverseConcat(dbUsername) - var container *mysqlcontainers.MySQLContainer - policy := retry.NewExponentialBackoffPolicy(retry.Config{MaxAttempts: 5, InitialDelay: 1, UseJitter: false}) - err := policy.Execute(ctx, func(ctx context.Context) error { - var containerErr error - container, containerErr = mysqlcontainers.Run( + container, err := containers.StartWithRetry(ctx, func(ctx context.Context) (*mysqlcontainers.MySQLContainer, error) { + return mysqlcontainers.Run( ctx, defaultMySQLImage, mysqlcontainers.WithDatabase(dbName), @@ -78,7 +75,6 @@ func buildDatabaseConnectionForTest(t *testing.T, ctx context.Context) (*sql.DB, mysqlcontainers.WithPassword(dbPassword), testcontainers.WithWaitStrategyAndDeadline(2*time.Minute, wait.ForLog("ready for connections").WithOccurrence(2)), ) - return containerErr }) must.NoError(t, err) must.NotNil(t, container) diff --git a/database/postgres/tableaccess/access_manager_test.go b/database/postgres/tableaccess/access_manager_test.go index 5af3d4f..8c6c62b 100644 --- a/database/postgres/tableaccess/access_manager_test.go +++ b/database/postgres/tableaccess/access_manager_test.go @@ -10,7 +10,7 @@ import ( "time" "github.com/verygoodsoftwarenotvirus/platform/v5/pointer" - "github.com/verygoodsoftwarenotvirus/platform/v5/retry" + "github.com/verygoodsoftwarenotvirus/platform/v5/testutils/containers" _ "github.com/jackc/pgx/v5/stdlib" "github.com/shoenig/test" @@ -83,11 +83,8 @@ func buildDatabaseConnectionForTest(t *testing.T, ctx context.Context) (*sql.DB, dbUsername := fmt.Sprintf("%d", hashStringToNumber(t.Name())) - var container *postgres.PostgresContainer - policy := retry.NewExponentialBackoffPolicy(retry.Config{MaxAttempts: 5, InitialDelay: 1, UseJitter: false}) - err := policy.Execute(ctx, func(ctx context.Context) error { - var containerErr error - container, containerErr = postgres.Run( + container, err := containers.StartWithRetry(ctx, func(ctx context.Context) (*postgres.PostgresContainer, error) { + return postgres.Run( ctx, defaultPostgresImage, postgres.WithDatabase(splitReverseConcat(dbUsername)), @@ -95,7 +92,6 @@ func buildDatabaseConnectionForTest(t *testing.T, ctx context.Context) (*sql.DB, postgres.WithPassword(reverseString(dbUsername)), testcontainers.WithWaitStrategyAndDeadline(2*time.Minute, wait.ForLog("database system is ready to accept connections").WithOccurrence(2)), ) - return containerErr }) must.NoError(t, err) must.NotNil(t, container) diff --git a/distributedlock/postgres/postgres_test.go b/distributedlock/postgres/postgres_test.go index 4e5eb38..a7f07dc 100644 --- a/distributedlock/postgres/postgres_test.go +++ b/distributedlock/postgres/postgres_test.go @@ -16,6 +16,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/distributedlock" "github.com/verygoodsoftwarenotvirus/platform/v5/identifiers" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" + "github.com/verygoodsoftwarenotvirus/platform/v5/testutils/containers" "github.com/DATA-DOG/go-sqlmock" _ "github.com/jackc/pgx/v5/stdlib" @@ -51,14 +52,16 @@ func buildContainerBackedPostgres(t *testing.T) (client *testDBClient, shutdown t.Helper() ctx := t.Context() - container, err := postgrescontainer.Run( - ctx, - postgresImage, - postgrescontainer.WithDatabase("locktest"), - postgrescontainer.WithUsername("locktest"), - postgrescontainer.WithPassword("locktest"), - testcontainers.WithWaitStrategyAndDeadline(2*time.Minute, wait.ForLog("database system is ready to accept connections").WithOccurrence(2)), - ) + container, err := containers.StartWithRetry(ctx, func(ctx context.Context) (*postgrescontainer.PostgresContainer, error) { + return postgrescontainer.Run( + ctx, + postgresImage, + postgrescontainer.WithDatabase("locktest"), + postgrescontainer.WithUsername("locktest"), + postgrescontainer.WithPassword("locktest"), + testcontainers.WithWaitStrategyAndDeadline(2*time.Minute, wait.ForLog("database system is ready to accept connections").WithOccurrence(2)), + ) + }) must.NoError(t, err) must.NotNil(t, container) diff --git a/distributedlock/redis/redis_test.go b/distributedlock/redis/redis_test.go index 5007607..ced1b65 100644 --- a/distributedlock/redis/redis_test.go +++ b/distributedlock/redis/redis_test.go @@ -16,11 +16,14 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" + "github.com/verygoodsoftwarenotvirus/platform/v5/testutils/containers" "github.com/redis/go-redis/v9" "github.com/shoenig/test" "github.com/shoenig/test/must" + "github.com/testcontainers/testcontainers-go" rediscontainers "github.com/testcontainers/testcontainers-go/modules/redis" + "github.com/testcontainers/testcontainers-go/wait" "go.opentelemetry.io/otel/metric" ) @@ -32,10 +35,19 @@ func buildContainerBackedRedisConfig(t *testing.T) (cfg *Config, shutdown func(c t.Helper() ctx := t.Context() - container, err := rediscontainers.Run(ctx, - redisImage, - rediscontainers.WithLogLevel(rediscontainers.LogLevelNotice), - ) + // Explicitly wait for both the TCP port and the "Ready to accept connections" + // log line so we don't race the redis-server bootstrap. The module's default + // wait strategy is implementation-defined, so pin it here for predictability. + container, err := containers.StartWithRetry(ctx, func(ctx context.Context) (*rediscontainers.RedisContainer, error) { + return rediscontainers.Run(ctx, + redisImage, + rediscontainers.WithLogLevel(rediscontainers.LogLevelNotice), + testcontainers.WithWaitStrategyAndDeadline(2*time.Minute, wait.ForAll( + wait.ForListeningPort("6379/tcp"), + wait.ForLog("Ready to accept connections"), + )), + ) + }) must.NoError(t, err) must.NotNil(t, container) diff --git a/messagequeue/pubsub/consumer_test.go b/messagequeue/pubsub/consumer_test.go index 899f01a..427bba9 100644 --- a/messagequeue/pubsub/consumer_test.go +++ b/messagequeue/pubsub/consumer_test.go @@ -17,6 +17,7 @@ import ( mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" "github.com/verygoodsoftwarenotvirus/platform/v5/random" + "github.com/verygoodsoftwarenotvirus/platform/v5/testutils/containers" "cloud.google.com/go/pubsub/v2" "cloud.google.com/go/pubsub/v2/apiv1/pubsubpb" @@ -30,6 +31,8 @@ import ( "google.golang.org/grpc/credentials/insecure" ) +const pubsubEmulatorImage = "gcr.io/google.com/cloudsdktool/cloud-sdk:emulators" + var runningContainerTests = strings.ToLower(os.Getenv("RUN_CONTAINER_TESTS")) == "true" type pubsubTestInfra struct { @@ -51,11 +54,13 @@ func buildPubSubTestInfra(t *testing.T) *pubsubTestInfra { must.NoError(t, err) projectID := "project-" + randomID - pubsubContainer, err := tcpubsub.Run( - ctx, - "gcr.io/google.com/cloudsdktool/cloud-sdk:emulators", - tcpubsub.WithProjectID(projectID), - ) + pubsubContainer, err := containers.StartWithRetry(ctx, func(ctx context.Context) (*tcpubsub.Container, error) { + return tcpubsub.Run( + ctx, + pubsubEmulatorImage, + tcpubsub.WithProjectID(projectID), + ) + }) must.NoError(t, err) must.NotNil(t, pubsubContainer) diff --git a/messagequeue/redis/test_helpers_test.go b/messagequeue/redis/test_helpers_test.go index 3515ea4..19a73ac 100644 --- a/messagequeue/redis/test_helpers_test.go +++ b/messagequeue/redis/test_helpers_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "github.com/verygoodsoftwarenotvirus/platform/v5/testutils/containers" + "github.com/testcontainers/testcontainers-go" rediscontainers "github.com/testcontainers/testcontainers-go/modules/redis" "github.com/testcontainers/testcontainers-go/wait" @@ -30,12 +32,17 @@ func BuildContainerBackedRedisConfigForTest(t *testing.T) (config *Config, shutd } func BuildContainerBackedRedisConfig(ctx context.Context) (config *Config, shutdownFunc func(context.Context) error, err error) { - redisContainer, err := rediscontainers.Run( - ctx, - redisContainerImageToUse, - rediscontainers.WithLogLevel(rediscontainers.LogLevelNotice), - testcontainers.WithWaitStrategyAndDeadline(30*time.Second, wait.ForListeningPort("6379/tcp")), - ) + redisContainer, err := containers.StartWithRetry(ctx, func(ctx context.Context) (*rediscontainers.RedisContainer, error) { + return rediscontainers.Run( + ctx, + redisContainerImageToUse, + rediscontainers.WithLogLevel(rediscontainers.LogLevelNotice), + testcontainers.WithWaitStrategyAndDeadline(2*time.Minute, wait.ForAll( + wait.ForListeningPort("6379/tcp"), + wait.ForLog("Ready to accept connections"), + )), + ) + }) if err != nil { return nil, nil, fmt.Errorf("failed to build redis container: %w", err) } diff --git a/search/text/elasticsearch/elasticsearch_test.go b/search/text/elasticsearch/elasticsearch_test.go index 958e2e0..fcd27d2 100644 --- a/search/text/elasticsearch/elasticsearch_test.go +++ b/search/text/elasticsearch/elasticsearch_test.go @@ -17,6 +17,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/identifiers" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/logging" "github.com/verygoodsoftwarenotvirus/platform/v5/observability/tracing" + "github.com/verygoodsoftwarenotvirus/platform/v5/testutils/containers" "github.com/shoenig/test" "github.com/shoenig/test/must" @@ -25,6 +26,8 @@ import ( "github.com/testcontainers/testcontainers-go/wait" ) +const elasticsearchImage = "elasticsearch:8.10.2" + var runningContainerTests = strings.ToLower(os.Getenv("RUN_CONTAINER_TESTS")) == "true" // esTestInfra holds a single shared Elasticsearch container for all container- @@ -37,15 +40,27 @@ type esTestInfra struct { } // extendWaitStrategyTimeout returns a PostCreates lifecycle hook that extends -// the timeouts of the elasticsearch module's bundled wait strategies. The -// module hard-codes a 60s deadline on its MultiStrategy and each inner -// FileStrategy/HTTPStrategy also defaults to 60s, which is too tight on a cold -// start (image pull + ES auto-config + cert generation) on a busy host. We -// have to mutate the strategies after creation because the module appends its -// configureWaitFor customizer after any user opts, leaving no other override -// hook. Failure to assert the expected types is loud rather than silent so a -// future testcontainers refactor surfaces immediately instead of regressing to -// a flaky 60s timeout. +// the timeouts of the elasticsearch module's bundled wait strategies. +// +// Why this exists: the elasticsearch module appends its own configureWaitFor +// customizer AFTER user opts, via WithAdditionalWaitStrategyAndDeadline(60s, ...), +// which unconditionally clamps the outer MultiStrategy deadline to 60s. The +// inner HTTPStrategy and FileStrategy also each default to 60s, and +// HTTPStrategy.WaitUntilReady wraps its ctx in context.WithTimeout(60s), so +// extending only the outer deadline is insufficient — each inner strategy +// must be extended individually. Neither WithStartupTimeoutDefault on +// MultiStrategy nor passing WithWaitStrategyAndDeadline as a user opt works +// around this; both get overwritten by the module's append. +// +// A cold start (image pull + ES auto-config + cert generation) regularly +// exceeds 60s on a busy CI host, so 60s is too tight in practice. +// +// The type assertions are load-bearing: we have to touch concrete types +// (*wait.MultiStrategy and the inner *wait.HTTPStrategy / *wait.FileStrategy) +// because wait.Strategy has no interface method for mutating a timeout. A +// future testcontainers refactor that changes these types will fail loudly +// here rather than silently regressing to a flaky 60s ceiling — which is the +// right failure mode. func extendWaitStrategyTimeout(timeout time.Duration) testcontainers.ContainerHook { return func(_ context.Context, c testcontainers.Container) error { dc, ok := c.(*testcontainers.DockerContainer) @@ -72,16 +87,18 @@ func extendWaitStrategyTimeout(timeout time.Duration) testcontainers.ContainerHo func buildEsTestInfra(t *testing.T) *esTestInfra { t.Helper() - elasticsearchContainer, err := elasticsearchcontainers.Run( - t.Context(), - "elasticsearch:8.10.2", - elasticsearchcontainers.WithPassword("arbitraryPassword"), - testcontainers.WithAdditionalLifecycleHooks(testcontainers.ContainerLifecycleHooks{ - PostCreates: []testcontainers.ContainerHook{ - extendWaitStrategyTimeout(5 * time.Minute), - }, - }), - ) + elasticsearchContainer, err := containers.StartWithRetry(t.Context(), func(ctx context.Context) (*elasticsearchcontainers.ElasticsearchContainer, error) { + return elasticsearchcontainers.Run( + ctx, + elasticsearchImage, + elasticsearchcontainers.WithPassword("arbitraryPassword"), + testcontainers.WithAdditionalLifecycleHooks(testcontainers.ContainerLifecycleHooks{ + PostCreates: []testcontainers.ContainerHook{ + extendWaitStrategyTimeout(5 * time.Minute), + }, + }), + ) + }) must.NoError(t, err) must.NotNil(t, elasticsearchContainer) diff --git a/search/vector/pgvector/pgvector_test.go b/search/vector/pgvector/pgvector_test.go index 8ad7393..299157d 100644 --- a/search/vector/pgvector/pgvector_test.go +++ b/search/vector/pgvector/pgvector_test.go @@ -15,6 +15,7 @@ import ( "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics" mockmetrics "github.com/verygoodsoftwarenotvirus/platform/v5/observability/metrics/mock" vectorsearch "github.com/verygoodsoftwarenotvirus/platform/v5/search/vector" + "github.com/verygoodsoftwarenotvirus/platform/v5/testutils/containers" _ "github.com/jackc/pgx/v5/stdlib" "github.com/shoenig/test" @@ -72,14 +73,16 @@ func buildContainerBackedPgvector(t *testing.T) (client *testDBClient, shutdown t.Helper() ctx := t.Context() - container, err := postgrescontainer.Run( - ctx, - pgvectorImage, - postgrescontainer.WithDatabase("vectortest"), - postgrescontainer.WithUsername("vectortest"), - postgrescontainer.WithPassword("vectortest"), - testcontainers.WithWaitStrategyAndDeadline(2*time.Minute, wait.ForLog("database system is ready to accept connections").WithOccurrence(2)), - ) + container, err := containers.StartWithRetry(ctx, func(ctx context.Context) (*postgrescontainer.PostgresContainer, error) { + return postgrescontainer.Run( + ctx, + pgvectorImage, + postgrescontainer.WithDatabase("vectortest"), + postgrescontainer.WithUsername("vectortest"), + postgrescontainer.WithPassword("vectortest"), + testcontainers.WithWaitStrategyAndDeadline(2*time.Minute, wait.ForLog("database system is ready to accept connections").WithOccurrence(2)), + ) + }) must.NoError(t, err) must.NotNil(t, container) diff --git a/search/vector/qdrant/qdrant_test.go b/search/vector/qdrant/qdrant_test.go index 338ef34..3ced934 100644 --- a/search/vector/qdrant/qdrant_test.go +++ b/search/vector/qdrant/qdrant_test.go @@ -18,6 +18,7 @@ import ( platformerrors "github.com/verygoodsoftwarenotvirus/platform/v5/errors" "github.com/verygoodsoftwarenotvirus/platform/v5/identifiers" vectorsearch "github.com/verygoodsoftwarenotvirus/platform/v5/search/vector" + "github.com/verygoodsoftwarenotvirus/platform/v5/testutils/containers" "github.com/shoenig/test" "github.com/shoenig/test/must" @@ -967,7 +968,9 @@ func buildContainerBackedQdrant(t *testing.T) (cfg *Config, shutdown func(contex }, Started: true, } - container, err := testcontainers.GenericContainer(ctx, req) + container, err := containers.StartWithRetry(ctx, func(ctx context.Context) (testcontainers.Container, error) { + return testcontainers.GenericContainer(ctx, req) + }) must.NoError(t, err) must.NotNil(t, container) diff --git a/testutils/containers/containers.go b/testutils/containers/containers.go new file mode 100644 index 0000000..e1c0010 --- /dev/null +++ b/testutils/containers/containers.go @@ -0,0 +1,51 @@ +// Package containers provides shared helpers for starting testcontainers +// with uniform retry behavior. It exists so every container builder in the +// repo can opt into the same backoff policy instead of each rolling its own. +// +// Container startup flakes for many non-deterministic reasons — Docker daemon +// cold starts, port conflicts, image pull stalls, transient network blips — +// and a single attempt is too brittle for a large integration test suite. +package containers + +import ( + "context" + "time" + + "github.com/verygoodsoftwarenotvirus/platform/v5/retry" +) + +const ( + defaultMaxAttempts = 5 + defaultInitialDelay = time.Second +) + +// DefaultRetryConfig returns the retry.Config used by StartWithRetry. Callers +// that need bespoke retry behavior can start from this and tweak individual +// fields before calling retry.NewExponentialBackoffPolicy themselves. +func DefaultRetryConfig() retry.Config { + return retry.Config{ + MaxAttempts: defaultMaxAttempts, + InitialDelay: defaultInitialDelay, + UseJitter: false, + } +} + +// StartWithRetry invokes start with exponential backoff retry on failure. It +// is a thin wrapper over the retry package so that every container builder in +// the repo gets the same backoff policy for free. +// +// The callback receives the same ctx that was passed in, and is expected to +// return the concrete container type from its module's Run function (e.g. +// *postgres.PostgresContainer, *redis.RedisContainer). Callers handle the +// error themselves — typically via must.NoError(t, err) — so that this helper +// stays decoupled from the testing package. +func StartWithRetry[C any](ctx context.Context, start func(context.Context) (C, error)) (C, error) { + var container C + policy := retry.NewExponentialBackoffPolicy(DefaultRetryConfig()) + err := policy.Execute(ctx, func(ctx context.Context) error { + var startErr error + container, startErr = start(ctx) + return startErr + }) + return container, err +} diff --git a/testutils/containers/containers_test.go b/testutils/containers/containers_test.go new file mode 100644 index 0000000..976cd83 --- /dev/null +++ b/testutils/containers/containers_test.go @@ -0,0 +1,88 @@ +package containers + +import ( + "context" + "errors" + "testing" + + "github.com/shoenig/test" + "github.com/shoenig/test/must" +) + +type fakeContainer struct { + id int +} + +func TestDefaultRetryConfig(T *testing.T) { + T.Parallel() + + cfg := DefaultRetryConfig() + test.EqOp(T, uint(defaultMaxAttempts), cfg.MaxAttempts) + test.EqOp(T, defaultInitialDelay, cfg.InitialDelay) + test.False(T, cfg.UseJitter) +} + +func TestStartWithRetry(T *testing.T) { + T.Parallel() + + T.Run("succeeds on first attempt", func(t *testing.T) { + t.Parallel() + + var calls int + got, err := StartWithRetry(t.Context(), func(_ context.Context) (*fakeContainer, error) { + calls++ + return &fakeContainer{id: 1}, nil + }) + must.NoError(t, err) + must.NotNil(t, got) + test.EqOp(t, 1, got.id) + test.EqOp(t, 1, calls) + }) + + T.Run("retries transient failures then succeeds", func(t *testing.T) { + t.Parallel() + + var calls int + got, err := StartWithRetry(t.Context(), func(_ context.Context) (*fakeContainer, error) { + calls++ + if calls < 3 { + return nil, errors.New("flaky docker") + } + return &fakeContainer{id: calls}, nil + }) + must.NoError(t, err) + must.NotNil(t, got) + test.EqOp(t, 3, calls) + test.EqOp(t, 3, got.id) + }) + + T.Run("gives up after MaxAttempts and returns last error", func(t *testing.T) { + t.Parallel() + + var calls int + boom := errors.New("always broken") + got, err := StartWithRetry(t.Context(), func(_ context.Context) (*fakeContainer, error) { + calls++ + return nil, boom + }) + must.ErrorIs(t, err, boom) + must.Nil(t, got) + test.EqOp(t, defaultMaxAttempts, calls) + }) + + T.Run("aborts when context is cancelled", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(t.Context()) + cancel() + + var calls int + _, err := StartWithRetry(ctx, func(_ context.Context) (*fakeContainer, error) { + calls++ + return nil, errors.New("never reached") + }) + must.Error(t, err) + // retry policy exits before invoking the callback when ctx is already done. + test.EqOp(t, 0, calls) + }) +}