diff --git a/internal/controller/tcp/cira/handler.go b/internal/controller/tcp/cira/handler.go index 8910270e0..02f68dd75 100644 --- a/internal/controller/tcp/cira/handler.go +++ b/internal/controller/tcp/cira/handler.go @@ -130,6 +130,7 @@ func (h *APFHandler) OnGlobalRequest(request apf.GlobalRequest) bool { } // ShouldSendKeepAlive returns whether keep-alive should be sent based on global request count. +// Returns true only once, when the threshold is exactly reached, to avoid sending duplicate requests. func (h *APFHandler) ShouldSendKeepAlive() bool { - return h.globalRequestCount >= globalRequestThreshold + return h.globalRequestCount == globalRequestThreshold } diff --git a/internal/controller/tcp/cira/tunnel.go b/internal/controller/tcp/cira/tunnel.go index 97bd4cd9a..9b536dca6 100644 --- a/internal/controller/tcp/cira/tunnel.go +++ b/internal/controller/tcp/cira/tunnel.go @@ -3,6 +3,7 @@ package cira import ( "bytes" + "context" "crypto/tls" "encoding/binary" "encoding/hex" @@ -122,6 +123,7 @@ type connectionContext struct { session *apf.Session authenticated bool device *wsman.ConnectionEntry + devices devices.Feature log logger.Interface } @@ -141,6 +143,7 @@ func (s *Server) handleConnection(conn net.Conn) { conn: conn, tlsConn: tlsConn, handler: NewAPFHandler(s.devices, s.log), + devices: s.devices, session: &apf.Session{ Timer: time.NewTimer(apfSessionTimeout), }, @@ -156,6 +159,10 @@ func (s *Server) handleConnection(conn net.Conn) { func (ctx *connectionContext) cleanup() { deviceID := ctx.handler.DeviceID() if ctx.authenticated && deviceID != "" { + if err := ctx.devices.UpdateConnectionStatus(context.Background(), deviceID, false); err != nil { + ctx.log.Error("Failed to update disconnection status for device %s: %v", deviceID, err) + } + wsman.RemoveConnection(deviceID) } @@ -211,12 +218,28 @@ func (ctx *connectionContext) processNextMessage() (shouldReturn bool) { } if err := ctx.sendKeepAliveIfNeeded(messageType); err != nil { + ctx.log.Error("Keep-alive failed for device %s: %v", ctx.handler.DeviceID(), err) + return true } + ctx.updateLastSeenIfKeepAlive(messageType) + return false } +func (ctx *connectionContext) updateLastSeenIfKeepAlive(messageType byte) { + if !ctx.authenticated || messageType != apf.APF_KEEPALIVE_REQUEST { + return + } + + deviceID := ctx.handler.DeviceID() + + if err := ctx.devices.UpdateLastSeen(context.Background(), deviceID); err != nil { + ctx.log.Error("Failed to update last seen for device %s: %v", deviceID, err) + } +} + func (ctx *connectionContext) readData() ([]byte, error) { buf := make([]byte, readBufferSize) @@ -279,6 +302,10 @@ func (ctx *connectionContext) registerDevice() { wsman.SetConnectionEntry(deviceID, ctx.device) + if err := ctx.devices.UpdateConnectionStatus(context.Background(), deviceID, true); err != nil { + ctx.log.Error("Failed to update connection status for device %s: %v", deviceID, err) + } + ctx.log.Info("Device authenticated and registered: %s", deviceID) } diff --git a/internal/controller/tcp/cira/tunnel_test.go b/internal/controller/tcp/cira/tunnel_test.go index e746693e1..6d299e47e 100644 --- a/internal/controller/tcp/cira/tunnel_test.go +++ b/internal/controller/tcp/cira/tunnel_test.go @@ -1,14 +1,19 @@ package cira import ( + "errors" + "net" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "github.com/device-management-toolkit/go-wsman-messages/v2/pkg/apf" + "github.com/device-management-toolkit/console/internal/mocks" + "github.com/device-management-toolkit/console/internal/usecase/devices" "github.com/device-management-toolkit/console/internal/usecase/devices/wsman" "github.com/device-management-toolkit/console/pkg/logger" ) @@ -96,6 +101,35 @@ var cleanupTests = []cleanupTestCase{ }, } +func TestConnectionContext_cleanup_UpdateConnectionStatusError(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockDevices := mocks.NewMockDeviceManagementFeature(ctrl) + mockDevices.EXPECT().UpdateConnectionStatus(gomock.Any(), "test-device", false).Return(errors.New("db error")) + + log := logger.New("error") + handler := NewAPFHandler(mockDevices, log) + handler.deviceID = "test-device" + + session := &apf.Session{Timer: time.NewTimer(10 * time.Second)} + ctx := &connectionContext{ + session: session, + authenticated: true, + handler: handler, + devices: mockDevices, + log: log, + } + + wsman.SetConnectionEntry("test-device", &wsman.ConnectionEntry{}) + + require.NotPanics(t, func() { + ctx.cleanup() + }) + + assert.Nil(t, wsman.GetConnectionEntry("test-device"), "Connection should still be removed even on status update failure") +} + func TestConnectionContext_cleanup(t *testing.T) { t.Parallel() @@ -112,9 +146,19 @@ func TestConnectionContext_cleanup(t *testing.T) { func runCleanupTest(t *testing.T, tt cleanupTestCase) { t.Helper() + ctrl := gomock.NewController(t) + mockDevices := mocks.NewMockDeviceManagementFeature(ctrl) + + var devicesFeature devices.Feature + + if tt.authenticated && tt.deviceID != "" { + mockDevices.EXPECT().UpdateConnectionStatus(gomock.Any(), tt.deviceID, false).Return(nil) + devicesFeature = mockDevices + } + // Setup session := tt.setupSession() - ctx := setupConnectionContext(t, session, tt.authenticated, tt.deviceID) + ctx := setupConnectionContext(t, session, tt.authenticated, tt.deviceID, devicesFeature) setupConnectionsMap(t, tt.authenticated, tt.deviceID) @@ -128,18 +172,19 @@ func runCleanupTest(t *testing.T, tt cleanupTestCase) { verifyConnectionRemoved(t, tt.authenticated, tt.deviceID) } -func setupConnectionContext(t *testing.T, session *apf.Session, authenticated bool, deviceID string) *connectionContext { +func setupConnectionContext(t *testing.T, session *apf.Session, authenticated bool, deviceID string, devicesFeature devices.Feature) *connectionContext { t.Helper() // Create a proper APFHandler with mock deviceID log := logger.New("error") - handler := NewAPFHandler(nil, log) // devices.Feature can be nil for cleanup test - handler.deviceID = deviceID // Set deviceID directly for test + handler := NewAPFHandler(devicesFeature, log) + handler.deviceID = deviceID return &connectionContext{ session: session, authenticated: authenticated, handler: handler, + devices: devicesFeature, } } @@ -173,3 +218,65 @@ func verifyConnectionRemoved(t *testing.T, authenticated bool, deviceID string) assert.False(t, exists, "Connection should be removed from map") } } + +// fakeConn is a minimal net.Conn implementation for tests. +type fakeConn struct{ net.Conn } + +func TestConnectionContext_registerDevice(t *testing.T) { + t.Parallel() + + t.Run("successful registration updates connection status", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + mockDevices := mocks.NewMockDeviceManagementFeature(ctrl) + mockDevices.EXPECT().UpdateConnectionStatus(gomock.Any(), "dev-123", true).Return(nil) + + log := logger.New("error") + handler := NewAPFHandler(mockDevices, log) + handler.deviceID = "dev-123" + + ctx := &connectionContext{ + conn: &fakeConn{}, + handler: handler, + devices: mockDevices, + log: log, + } + + ctx.registerDevice() + + assert.True(t, ctx.authenticated) + assert.NotNil(t, ctx.device) + assert.NotNil(t, wsman.GetConnectionEntry("dev-123")) + + t.Cleanup(func() { wsman.RemoveConnection("dev-123") }) + }) + + t.Run("registration continues when UpdateConnectionStatus fails", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + + mockDevices := mocks.NewMockDeviceManagementFeature(ctrl) + mockDevices.EXPECT().UpdateConnectionStatus(gomock.Any(), "dev-456", true).Return(errors.New("db error")) + + log := logger.New("error") + handler := NewAPFHandler(mockDevices, log) + handler.deviceID = "dev-456" + + ctx := &connectionContext{ + conn: &fakeConn{}, + handler: handler, + devices: mockDevices, + log: log, + } + + ctx.registerDevice() + + assert.True(t, ctx.authenticated) + assert.NotNil(t, wsman.GetConnectionEntry("dev-456")) + + t.Cleanup(func() { wsman.RemoveConnection("dev-456") }) + }) +} diff --git a/internal/controller/ws/v1/interface.go b/internal/controller/ws/v1/interface.go index 7ea87f723..abcebe431 100644 --- a/internal/controller/ws/v1/interface.go +++ b/internal/controller/ws/v1/interface.go @@ -30,6 +30,8 @@ type Feature interface { GetCount(context.Context, string) (int, error) Get(ctx context.Context, top, skip int, tenantID string) ([]dto.Device, error) GetByID(ctx context.Context, guid, tenantID string, includeSecrets bool) (*dto.Device, error) + UpdateConnectionStatus(ctx context.Context, guid string, status bool) error + UpdateLastSeen(ctx context.Context, guid string) error GetDistinctTags(ctx context.Context, tenantID string) ([]string, error) GetByTags(ctx context.Context, tags, method string, limit, offset int, tenantID string) ([]dto.Device, error) Delete(ctx context.Context, guid, tenantID string) error diff --git a/internal/mocks/devicemanagement_mocks.go b/internal/mocks/devicemanagement_mocks.go index ee3694979..11a564515 100644 --- a/internal/mocks/devicemanagement_mocks.go +++ b/internal/mocks/devicemanagement_mocks.go @@ -410,6 +410,34 @@ func (mr *MockDeviceManagementRepositoryMockRecorder) Update(ctx, d any) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockDeviceManagementRepository)(nil).Update), ctx, d) } +// UpdateConnectionStatus mocks base method. +func (m *MockDeviceManagementRepository) UpdateConnectionStatus(ctx context.Context, guid string, status bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateConnectionStatus", ctx, guid, status) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateConnectionStatus indicates an expected call of UpdateConnectionStatus. +func (mr *MockDeviceManagementRepositoryMockRecorder) UpdateConnectionStatus(ctx, guid, status any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateConnectionStatus", reflect.TypeOf((*MockDeviceManagementRepository)(nil).UpdateConnectionStatus), ctx, guid, status) +} + +// UpdateLastSeen mocks base method. +func (m *MockDeviceManagementRepository) UpdateLastSeen(ctx context.Context, guid string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateLastSeen", ctx, guid) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateLastSeen indicates an expected call of UpdateLastSeen. +func (mr *MockDeviceManagementRepositoryMockRecorder) UpdateLastSeen(ctx, guid any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLastSeen", reflect.TypeOf((*MockDeviceManagementRepository)(nil).UpdateLastSeen), ctx, guid) +} + // MockDeviceManagementFeature is a mock of Feature interface. type MockDeviceManagementFeature struct { ctrl *gomock.Controller @@ -1002,3 +1030,31 @@ func (mr *MockDeviceManagementFeatureMockRecorder) Update(ctx, d any) *gomock.Ca mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockDeviceManagementFeature)(nil).Update), ctx, d) } + +// UpdateConnectionStatus mocks base method. +func (m *MockDeviceManagementFeature) UpdateConnectionStatus(ctx context.Context, guid string, status bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateConnectionStatus", ctx, guid, status) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateConnectionStatus indicates an expected call of UpdateConnectionStatus. +func (mr *MockDeviceManagementFeatureMockRecorder) UpdateConnectionStatus(ctx, guid, status any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateConnectionStatus", reflect.TypeOf((*MockDeviceManagementFeature)(nil).UpdateConnectionStatus), ctx, guid, status) +} + +// UpdateLastSeen mocks base method. +func (m *MockDeviceManagementFeature) UpdateLastSeen(ctx context.Context, guid string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateLastSeen", ctx, guid) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateLastSeen indicates an expected call of UpdateLastSeen. +func (mr *MockDeviceManagementFeatureMockRecorder) UpdateLastSeen(ctx, guid any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLastSeen", reflect.TypeOf((*MockDeviceManagementFeature)(nil).UpdateLastSeen), ctx, guid) +} diff --git a/internal/mocks/wsv1_mocks.go b/internal/mocks/wsv1_mocks.go index 7ab304aa4..3ca7c33e3 100644 --- a/internal/mocks/wsv1_mocks.go +++ b/internal/mocks/wsv1_mocks.go @@ -691,3 +691,31 @@ func (mr *MockFeatureMockRecorder) Update(ctx, d any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockFeature)(nil).Update), ctx, d) } + +// UpdateConnectionStatus mocks base method. +func (m *MockFeature) UpdateConnectionStatus(ctx context.Context, guid string, status bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateConnectionStatus", ctx, guid, status) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateConnectionStatus indicates an expected call of UpdateConnectionStatus. +func (mr *MockFeatureMockRecorder) UpdateConnectionStatus(ctx, guid, status any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateConnectionStatus", reflect.TypeOf((*MockFeature)(nil).UpdateConnectionStatus), ctx, guid, status) +} + +// UpdateLastSeen mocks base method. +func (m *MockFeature) UpdateLastSeen(ctx context.Context, guid string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateLastSeen", ctx, guid) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateLastSeen indicates an expected call of UpdateLastSeen. +func (mr *MockFeatureMockRecorder) UpdateLastSeen(ctx, guid any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateLastSeen", reflect.TypeOf((*MockFeature)(nil).UpdateLastSeen), ctx, guid) +} diff --git a/internal/usecase/devices/interfaces.go b/internal/usecase/devices/interfaces.go index f99c07462..1fe76e625 100644 --- a/internal/usecase/devices/interfaces.go +++ b/internal/usecase/devices/interfaces.go @@ -44,12 +44,16 @@ type ( Update(ctx context.Context, d *entity.Device) (bool, error) Insert(ctx context.Context, d *entity.Device) (string, error) GetByColumn(ctx context.Context, columnName, queryValue, tenantID string) ([]entity.Device, error) + UpdateConnectionStatus(ctx context.Context, guid string, status bool) error + UpdateLastSeen(ctx context.Context, guid string) error } Feature interface { // Repository/Database Calls GetCount(context.Context, string) (int, error) Get(ctx context.Context, top, skip int, tenantID string) ([]dto.Device, error) GetByID(ctx context.Context, guid, tenantID string, includeSecrets bool) (*dto.Device, error) + UpdateConnectionStatus(ctx context.Context, guid string, status bool) error + UpdateLastSeen(ctx context.Context, guid string) error GetDistinctTags(ctx context.Context, tenantID string) ([]string, error) GetByTags(ctx context.Context, tags, method string, limit, offset int, tenantID string) ([]dto.Device, error) Delete(ctx context.Context, guid, tenantID string) error diff --git a/internal/usecase/devices/repo.go b/internal/usecase/devices/repo.go index e464fd13b..c0844314c 100644 --- a/internal/usecase/devices/repo.go +++ b/internal/usecase/devices/repo.go @@ -145,6 +145,24 @@ func (uc *UseCase) GetByTags(ctx context.Context, tags, method string, limit, of return d1, nil } +func (uc *UseCase) UpdateConnectionStatus(ctx context.Context, guid string, status bool) error { + err := uc.repo.UpdateConnectionStatus(ctx, strings.ToLower(guid), status) + if err != nil { + return ErrDatabase.Wrap("UpdateConnectionStatus", "uc.repo.UpdateConnectionStatus", err) + } + + return nil +} + +func (uc *UseCase) UpdateLastSeen(ctx context.Context, guid string) error { + err := uc.repo.UpdateLastSeen(ctx, strings.ToLower(guid)) + if err != nil { + return ErrDatabase.Wrap("UpdateLastSeen", "uc.repo.UpdateLastSeen", err) + } + + return nil +} + func (uc *UseCase) Delete(ctx context.Context, guid, tenantID string) error { isSuccessful, err := uc.repo.Delete(ctx, strings.ToLower(guid), tenantID) if err != nil { diff --git a/internal/usecase/devices/repo_test.go b/internal/usecase/devices/repo_test.go index e0c1e2afb..c29ba6de6 100644 --- a/internal/usecase/devices/repo_test.go +++ b/internal/usecase/devices/repo_test.go @@ -728,3 +728,132 @@ func TestUpdate_UUIDNormalization(t *testing.T) { require.Equal(t, "aaf0c395-c2a2-992e-5655-48210b50d8c9", result.GUID) }) } + +func TestUpdateConnectionStatus(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + guid string + status bool + mock func(*mocks.MockDeviceManagementRepository) + err error + }{ + { + name: "successful connection status update - connected", + guid: "device-guid-123", + status: true, + mock: func(repo *mocks.MockDeviceManagementRepository) { + repo.EXPECT(). + UpdateConnectionStatus(context.Background(), "device-guid-123", true). + Return(nil) + }, + err: nil, + }, + { + name: "successful connection status update - disconnected", + guid: "device-guid-123", + status: false, + mock: func(repo *mocks.MockDeviceManagementRepository) { + repo.EXPECT(). + UpdateConnectionStatus(context.Background(), "device-guid-123", false). + Return(nil) + }, + err: nil, + }, + { + name: "mixed-case GUID is normalized to lowercase", + guid: "DEVICE-GUID-123", + status: true, + mock: func(repo *mocks.MockDeviceManagementRepository) { + repo.EXPECT(). + UpdateConnectionStatus(context.Background(), "device-guid-123", true). + Return(nil) + }, + err: nil, + }, + { + name: "database error", + guid: "device-guid-123", + status: true, + mock: func(repo *mocks.MockDeviceManagementRepository) { + repo.EXPECT(). + UpdateConnectionStatus(context.Background(), "device-guid-123", true). + Return(devices.ErrDatabase) + }, + err: devices.ErrDatabase, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + useCase, repo, _ := devicesTest(t) + + tc.mock(repo) + + err := useCase.UpdateConnectionStatus(context.Background(), tc.guid, tc.status) + + require.IsType(t, tc.err, err) + }) + } +} + +func TestUpdateLastSeen(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + guid string + mock func(*mocks.MockDeviceManagementRepository) + err error + }{ + { + name: "successful update", + guid: "device-guid-123", + mock: func(repo *mocks.MockDeviceManagementRepository) { + repo.EXPECT(). + UpdateLastSeen(context.Background(), "device-guid-123"). + Return(nil) + }, + err: nil, + }, + { + name: "mixed-case GUID is normalized to lowercase", + guid: "DEVICE-GUID-123", + mock: func(repo *mocks.MockDeviceManagementRepository) { + repo.EXPECT(). + UpdateLastSeen(context.Background(), "device-guid-123"). + Return(nil) + }, + err: nil, + }, + { + name: "database error", + guid: "device-guid-123", + mock: func(repo *mocks.MockDeviceManagementRepository) { + repo.EXPECT(). + UpdateLastSeen(context.Background(), "device-guid-123"). + Return(devices.ErrDatabase) + }, + err: devices.ErrDatabase, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + useCase, repo, _ := devicesTest(t) + + tc.mock(repo) + + err := useCase.UpdateLastSeen(context.Background(), tc.guid) + + require.IsType(t, tc.err, err) + }) + } +} diff --git a/internal/usecase/sqldb/device.go b/internal/usecase/sqldb/device.go index cd0f870c9..a3bf3f18a 100644 --- a/internal/usecase/sqldb/device.go +++ b/internal/usecase/sqldb/device.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "strings" + "time" "github.com/device-management-toolkit/console/internal/entity" "github.com/device-management-toolkit/console/pkg/consoleerrors" @@ -364,6 +365,55 @@ func (r *DeviceRepo) Update(_ context.Context, d *entity.Device) (bool, error) { return rowsAffected > 0, nil } +// UpdateConnectionStatus updates only the connection status and timestamps for a device. +func (r *DeviceRepo) UpdateConnectionStatus(_ context.Context, guid string, status bool) error { + now := time.Now().Format("2006-01-02 15:04:05") + + builder := r.Builder. + Update("devices"). + Set("connectionstatus", status). + Where("guid = ?", guid) + + if status { + builder = builder.Set("lastconnected", now) + } else { + builder = builder.Set("lastdisconnected", now) + } + + sqlQuery, args, err := builder.ToSql() + if err != nil { + return ErrDeviceDatabase.Wrap("UpdateConnectionStatus", "r.Builder", err) + } + + _, err = r.Pool.ExecContext(context.Background(), sqlQuery, args...) + if err != nil { + return ErrDeviceDatabase.Wrap("UpdateConnectionStatus", "r.Pool.Exec", err) + } + + return nil +} + +// UpdateLastSeen updates the lastseen timestamp for a device. +func (r *DeviceRepo) UpdateLastSeen(_ context.Context, guid string) error { + now := time.Now().Format("2006-01-02 15:04:05") + + sqlQuery, args, err := r.Builder. + Update("devices"). + Set("lastseen", now). + Where("guid = ?", guid). + ToSql() + if err != nil { + return ErrDeviceDatabase.Wrap("UpdateLastSeen", "r.Builder", err) + } + + _, err = r.Pool.ExecContext(context.Background(), sqlQuery, args...) + if err != nil { + return ErrDeviceDatabase.Wrap("UpdateLastSeen", "r.Pool.Exec", err) + } + + return nil +} + // Insert -. func (r *DeviceRepo) Insert(_ context.Context, d *entity.Device) (string, error) { insertBuilder := r.Builder. diff --git a/internal/usecase/sqldb/device_test.go b/internal/usecase/sqldb/device_test.go index 870117be5..1f31250ee 100644 --- a/internal/usecase/sqldb/device_test.go +++ b/internal/usecase/sqldb/device_test.go @@ -47,7 +47,10 @@ func setupDeviceTable(t *testing.T) *sql.DB { mebxpassword TEXT, usetls BOOLEAN NOT NULL DEFAULT FALSE, allowselfsigned BOOLEAN NOT NULL DEFAULT FALSE, - certhash TEXT NOT NULL DEFAULT '' + certhash TEXT NOT NULL DEFAULT '', + lastconnected TEXT, + lastdisconnected TEXT, + lastseen TEXT ); `) require.NoError(t, err) @@ -1094,3 +1097,196 @@ func TestDeviceRepo_GetByColumn(t *testing.T) { }) } } + +func TestDeviceRepo_UpdateConnectionStatus(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(dbConn *sql.DB) + guid string + status bool + err error + verify func(t *testing.T, dbConn *sql.DB) + }{ + { + name: "Set connected status to true", + setup: func(dbConn *sql.DB) { + _, err := dbConn.ExecContext(context.Background(), + `INSERT INTO devices (guid, hostname, tags, mpsinstance, connectionstatus, mpsusername, tenantid, friendlyname, dnssuffix, deviceinfo, username, password, usetls, allowselfsigned) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + "guid1", "hostname1", "tag1", "mps1", false, "mpsuser1", "tenant1", "friendly1", "dns1", "info1", "user1", "pass1", true, false) + require.NoError(t, err) + }, + guid: "guid1", + status: true, + err: nil, + verify: func(t *testing.T, dbConn *sql.DB) { + t.Helper() + + var ( + connStatus bool + lastConnected sql.NullString + ) + + err := dbConn.QueryRowContext(context.Background(), + "SELECT connectionstatus, lastconnected FROM devices WHERE guid = ?", "guid1"). + Scan(&connStatus, &lastConnected) + require.NoError(t, err) + assert.True(t, connStatus) + assert.True(t, lastConnected.Valid, "lastconnected should be set") + }, + }, + { + name: "Set connected status to false", + setup: func(dbConn *sql.DB) { + _, err := dbConn.ExecContext(context.Background(), + `INSERT INTO devices (guid, hostname, tags, mpsinstance, connectionstatus, mpsusername, tenantid, friendlyname, dnssuffix, deviceinfo, username, password, usetls, allowselfsigned) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + "guid2", "hostname2", "tag2", "mps2", true, "mpsuser2", "tenant1", "friendly2", "dns2", "info2", "user2", "pass2", true, false) + require.NoError(t, err) + }, + guid: "guid2", + status: false, + err: nil, + verify: func(t *testing.T, dbConn *sql.DB) { + t.Helper() + + var ( + connStatus bool + lastDisconnected sql.NullString + ) + + err := dbConn.QueryRowContext(context.Background(), + "SELECT connectionstatus, lastdisconnected FROM devices WHERE guid = ?", "guid2"). + Scan(&connStatus, &lastDisconnected) + require.NoError(t, err) + assert.False(t, connStatus) + assert.True(t, lastDisconnected.Valid, "lastdisconnected should be set") + }, + }, + { + name: "Update non-existent device - no error", + setup: func(_ *sql.DB) {}, + guid: "nonexistent", + status: true, + err: nil, + verify: func(_ *testing.T, _ *sql.DB) {}, + }, + { + name: QueryExecutionErrorTestName, + setup: func(_ *sql.DB) {}, + guid: "guid1", + status: true, + err: &sqldb.DatabaseError{}, + verify: func(_ *testing.T, _ *sql.DB) {}, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + dbConn := setupDeviceTable(t) + defer dbConn.Close() + + tc.setup(dbConn) + + sqlConfig := CreateSQLConfig(dbConn, tc.name == QueryExecutionErrorTestName) + + mockLog := mocks.NewMockLogger(nil) + repo := sqldb.NewDeviceRepo(sqlConfig, mockLog) + + err := repo.UpdateConnectionStatus(context.Background(), tc.guid, tc.status) + + if tc.err == nil { + require.NoError(t, err) + } else { + require.Error(t, err) + + var dbErr sqldb.DatabaseError + assert.True(t, errors.As(err, &dbErr)) + } + + tc.verify(t, dbConn) + }) + } +} + +func TestDeviceRepo_UpdateLastSeen(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(dbConn *sql.DB) + guid string + err error + verify func(t *testing.T, dbConn *sql.DB) + }{ + { + name: "Successfully updates lastseen", + setup: func(dbConn *sql.DB) { + _, err := dbConn.ExecContext(context.Background(), + `INSERT INTO devices (guid, hostname, tags, mpsinstance, connectionstatus, mpsusername, tenantid, friendlyname, dnssuffix, deviceinfo, username, password, usetls, allowselfsigned) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + "guid1", "hostname1", "tag1", "mps1", false, "mpsuser1", "tenant1", "friendly1", "dns1", "info1", "user1", "pass1", true, false) + require.NoError(t, err) + }, + guid: "guid1", + err: nil, + verify: func(t *testing.T, dbConn *sql.DB) { + t.Helper() + + var lastSeen sql.NullString + + err := dbConn.QueryRowContext(context.Background(), + "SELECT lastseen FROM devices WHERE guid = ?", "guid1"). + Scan(&lastSeen) + require.NoError(t, err) + assert.True(t, lastSeen.Valid, "lastseen should be set") + }, + }, + { + name: "Update non-existent device - no error", + setup: func(_ *sql.DB) {}, + guid: "nonexistent", + err: nil, + verify: func(_ *testing.T, _ *sql.DB) {}, + }, + { + name: QueryExecutionErrorTestName, + setup: func(_ *sql.DB) {}, + guid: "guid1", + err: &sqldb.DatabaseError{}, + verify: func(_ *testing.T, _ *sql.DB) {}, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + dbConn := setupDeviceTable(t) + defer dbConn.Close() + + tc.setup(dbConn) + + sqlConfig := CreateSQLConfig(dbConn, tc.name == QueryExecutionErrorTestName) + + mockLog := mocks.NewMockLogger(nil) + repo := sqldb.NewDeviceRepo(sqlConfig, mockLog) + + err := repo.UpdateLastSeen(context.Background(), tc.guid) + + if tc.err == nil { + require.NoError(t, err) + } else { + require.Error(t, err) + + var dbErr sqldb.DatabaseError + assert.True(t, errors.As(err, &dbErr)) + } + + tc.verify(t, dbConn) + }) + } +}