diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 07c32120..b8f4e127 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -14,6 +14,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/Infisical/infisical-merge/packages/api" @@ -86,8 +87,9 @@ type GatewayConfig struct { } type pamSessionEntry struct { - cancel context.CancelFunc - conn *tls.Conn + cancel context.CancelFunc + conn *tls.Conn + lastActivity atomic.Int64 } type Gateway struct { @@ -166,20 +168,33 @@ func NewGateway(config *GatewayConfig) (*Gateway, error) { } // RegisterPAMSession registers an active PAM proxy connection for cancellation support -func (g *Gateway) RegisterPAMSession(sessionID string, cancel context.CancelFunc, conn *tls.Conn) { +// Returns a function that handlers should call when data flows through the connection +func (g *Gateway) RegisterPAMSession(sessionID string, cancel context.CancelFunc, conn *tls.Conn) func() { + entry := &pamSessionEntry{cancel: cancel, conn: conn} + entry.lastActivity.Store(time.Now().Unix()) + g.pamSessionsMu.Lock() defer g.pamSessionsMu.Unlock() - g.pamSessions[sessionID] = append(g.pamSessions[sessionID], &pamSessionEntry{cancel: cancel, conn: conn}) + g.pamSessions[sessionID] = append(g.pamSessions[sessionID], entry) + + return func() { + entry.lastActivity.Store(time.Now().Unix()) + } } // DeregisterPAMSession removes a specific connection from the session registry. +// Returns true if this was the last connection for the session. // The MongoDB proxy (if any) is NOT closed here — it persists across connections // so that subsequent client connections (e.g. mongosh retries) find a warm topology. // The proxy is cleaned up on session cancellation or gateway shutdown. -func (g *Gateway) DeregisterPAMSession(sessionID string, conn *tls.Conn) { +func (g *Gateway) DeregisterPAMSession(sessionID string, conn *tls.Conn) bool { g.pamSessionsMu.Lock() defer g.pamSessionsMu.Unlock() - entries := g.pamSessions[sessionID] + + entries, exists := g.pamSessions[sessionID] + if !exists { + return false + } for i, e := range entries { if e.conn == conn { g.pamSessions[sessionID] = append(entries[:i], entries[i+1:]...) @@ -188,7 +203,9 @@ func (g *Gateway) DeregisterPAMSession(sessionID string, conn *tls.Conn) { } if len(g.pamSessions[sessionID]) == 0 { delete(g.pamSessions, sessionID) + return true } + return false } // CancelPAMSession kills all active connections for a PAM session @@ -264,6 +281,51 @@ func (g *Gateway) closeMongoProxy(sessionID string) { } } +const pamIdleTimeout = 30 * time.Minute + +// startIdleReaper periodically scans the PAM session registry and cancels +// sessions whose connections have had no data flow for pamIdleTimeout +func (g *Gateway) startIdleReaper(ctx context.Context) { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + g.reapIdleSessions() + } + } +} + +func (g *Gateway) reapIdleSessions() { + cutoff := time.Now().Add(-pamIdleTimeout).Unix() + + g.pamSessionsMu.Lock() + var stale []string + for sessionID, entries := range g.pamSessions { + allIdle := true + for _, e := range entries { + if e.lastActivity.Load() > cutoff { + allIdle = false + break + } + } + if allIdle { + stale = append(stale, sessionID) + } + } + g.pamSessionsMu.Unlock() + + for _, sessionID := range stale { + log.Info().Str("sessionId", sessionID).Dur("idleTimeout", pamIdleTimeout).Msg("Reaping idle PAM session") + g.CancelPAMSession(sessionID) + if err := g.pamSessionUploader.CleanupPAMSession(sessionID, "idle_timeout"); err != nil { + log.Error().Err(err).Str("sessionId", sessionID).Msg("Failed to cleanup reaped PAM session") + } + } +} + func (g *Gateway) registerHeartBeat(ctx context.Context, errCh chan error) { sendHeartbeat := func() error { if err := api.CallGatewayHeartBeatV2(g.httpClient); err != nil { @@ -329,6 +391,8 @@ func (g *Gateway) Start(ctx context.Context) error { // Start session uploader goroutine for PAM g.pamSessionUploader.Start() + go g.startIdleReaper(ctx) + go func() { for { select { @@ -489,6 +553,25 @@ func (g *Gateway) handleConnection(client *ssh.Client) error { client.Close() }() + // Keepalive on the relay SSH connection. If the relay drops silently, + // this closes the client so the reconnect loop in connectWithRetry kicks in + go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := util.SSHKeepalive(client, 15*time.Second); err != nil { + log.Warn().Err(err).Msg("Relay SSH keepalive failed, closing connection") + client.Close() + return + } + case <-g.ctx.Done(): + return + } + } + }() + // Process incoming channels with context cancellation support for { select { @@ -751,8 +834,8 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { return } else if forwardConfig.Mode == ForwardModePAM { sessionCtx, sessionCancel := context.WithCancel(g.ctx) - g.RegisterPAMSession(forwardConfig.PAMConfig.SessionId, sessionCancel, tlsConn) - defer g.DeregisterPAMSession(forwardConfig.PAMConfig.SessionId, tlsConn) + touchSession := g.RegisterPAMSession(forwardConfig.PAMConfig.SessionId, sessionCancel, tlsConn) + forwardConfig.PAMConfig.OnActivity = touchSession if err := pam.HandlePAMProxy(sessionCtx, tlsConn, &forwardConfig.PAMConfig, g.httpClient); err != nil { if err.Error() == "unexpected EOF" { log.Debug().Err(err).Msg("PAM proxy handler ended with unexpected connection termination") @@ -760,6 +843,14 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { log.Error().Err(err).Msg("PAM proxy handler ended with error") } } + sessionCancel() + if lastConn := g.DeregisterPAMSession(forwardConfig.PAMConfig.SessionId, tlsConn); lastConn { + if err := forwardConfig.PAMConfig.SessionUploader.CleanupPAMSession( + forwardConfig.PAMConfig.SessionId, "connection_closed", + ); err != nil { + log.Error().Err(err).Str("sessionId", forwardConfig.PAMConfig.SessionId).Msg("Failed to cleanup PAM session") + } + } return } else if forwardConfig.Mode == ForwardModePAMCancellation { if err := pam.HandlePAMCancellation(g.ctx, tlsConn, &forwardConfig.PAMConfig, g.httpClient, g.CancelPAMSession); err != nil { diff --git a/packages/pam/handlers/ssh/proxy.go b/packages/pam/handlers/ssh/proxy.go index 0d2657af..8d9a6195 100644 --- a/packages/pam/handlers/ssh/proxy.go +++ b/packages/pam/handlers/ssh/proxy.go @@ -11,6 +11,7 @@ import ( "time" "github.com/Infisical/infisical-merge/packages/pam/session" + "github.com/Infisical/infisical-merge/packages/util" "github.com/rs/zerolog/log" "golang.org/x/crypto/ssh" ) @@ -26,6 +27,7 @@ type SSHProxyConfig struct { SessionID string SessionLogger session.SessionLogger BlockedCommandPatterns []*regexp.Regexp // Regex patterns for command blocking (nil = no blocking) + OnActivity func() // Called when channel data flows } // SSHProxy handles proxying SSH connections with credential injection @@ -123,6 +125,29 @@ func (p *SSHProxy) HandleConnection(ctx context.Context, clientConn net.Conn) er // Discard global requests (not needed for basic remote access) go ssh.DiscardRequests(clientRequests) + // SSH keepalive: detect dead connections where TCP goes silent. Probes both sides every 30s + go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := util.SSHKeepalive(clientSSHConn, 15*time.Second); err != nil { + log.Info().Err(err).Str("sessionID", sessionID).Msg("SSH keepalive to client failed, tearing down connection") + clientConn.Close() + return + } + if err := util.SSHKeepalive(serverSSHConn, 15*time.Second); err != nil { + log.Info().Err(err).Str("sessionID", sessionID).Msg("SSH keepalive to target failed, tearing down connection") + clientConn.Close() + return + } + case <-ctx.Done(): + return + } + } + }() + // Handle channels from client (this is where actual SSH sessions happen) for newChannel := range clientChannels { go p.handleChannel(ctx, newChannel, serverSSHConn, sessionID) @@ -500,6 +525,10 @@ func (p *SSHProxy) proxyData(src io.Reader, dst io.Writer, direction string, ses for { n, err := src.Read(buf) if n > 0 { + if p.config.OnActivity != nil { + p.config.OnActivity() + } + // Check if this channel is a binary session (SFTP/SCP) chState.mutex.Lock() isBinary := chState.isBinarySession @@ -744,6 +773,10 @@ func (p *SSHProxy) proxyClientToServerWithBlocking(src io.Reader, dst io.Writer, for { n, err := src.Read(buf) if n > 0 { + if p.config.OnActivity != nil { + p.config.OnActivity() + } + chState.mutex.Lock() isBinary := chState.isBinarySession sftpParser := chState.sftpParser diff --git a/packages/pam/pam-proxy.go b/packages/pam/pam-proxy.go index 2995e99b..3240e682 100644 --- a/packages/pam/pam-proxy.go +++ b/packages/pam/pam-proxy.go @@ -38,6 +38,7 @@ type GatewayPAMConfig struct { CredentialsManager *session.CredentialsManager SessionUploader *session.SessionUploader GetMongoProxy MongoProxyGetter // Session-level MongoDB proxy sharing + OnActivity func() // Called on data flow } type PAMCapabilitiesResponse struct { @@ -104,19 +105,40 @@ func HandlePAMCancellation(ctx context.Context, conn *tls.Conn, pamConfig *Gatew // Kill the active proxy connection if it exists in the registry if cancelled := cancelSession(pamConfig.SessionId); cancelled { log.Info().Str("sessionId", pamConfig.SessionId).Msg("Active proxy session cancelled via registry") + if err := pamConfig.SessionUploader.CleanupPAMSession(pamConfig.SessionId, "cancellation"); err != nil { + log.Error().Err(err).Str("sessionId", pamConfig.SessionId).Msg("Failed to cleanup PAM session") + } } else { log.Info().Str("sessionId", pamConfig.SessionId).Msg("No active proxy session found in registry (may have already ended)") } - if err := pamConfig.SessionUploader.CleanupPAMSession(pamConfig.SessionId, "cancellation"); err != nil { - log.Error().Err(err).Str("sessionId", pamConfig.SessionId).Msg("Failed to cleanup PAM session") - } - conn.Close() return nil } +// activityConn wraps a net.Conn and calls onActivity on every successful read or write +type activityConn struct { + net.Conn + onActivity func() +} + +func (c *activityConn) Read(b []byte) (int, error) { + n, err := c.Conn.Read(b) + if n > 0 { + c.onActivity() + } + return n, err +} + +func (c *activityConn) Write(b []byte) (int, error) { + n, err := c.Conn.Write(b) + if n > 0 { + c.onActivity() + } + return n, err +} + // compilePolicyPatterns compiles regex pattern strings, logging warnings for any that fail. func compilePolicyPatterns(config *api.PAMPolicyRuleConfig, sessionID string, ruleType string) []*regexp.Regexp { if config == nil || len(config.Patterns) == 0 { @@ -161,10 +183,6 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Time("expiryTime", pamConfig.ExpiryTime). Msg("PAM session expired, closing connection") - if err := pamConfig.SessionUploader.CleanupPAMSession(pamConfig.SessionId, "expiry"); err != nil { - log.Error().Err(err).Str("sessionId", pamConfig.SessionId).Msg("Failed to cleanup PAM session on expiry") - } - conn.Close() case <-ctx.Done(): // Context cancelled, exit gracefully @@ -177,10 +195,6 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Time("expiryTime", pamConfig.ExpiryTime). Msg("PAM session already expired, closing connection immediately") - if err := pamConfig.SessionUploader.CleanupPAMSession(pamConfig.SessionId, "already_expired"); err != nil { - log.Error().Err(err).Str("sessionId", pamConfig.SessionId).Msg("Failed to cleanup already expired PAM session") - } - conn.Close() } }() @@ -242,6 +256,12 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo } } + // Wrap the connection so every read/write resets the idle reaper timer + var handlerConn net.Conn = conn + if pamConfig.OnActivity != nil { + handlerConn = &activityConn{Conn: conn, onActivity: pamConfig.OnActivity} + } + switch pamConfig.ResourceType { case session.ResourceTypePostgres: proxyConfig := handlers.PostgresProxyConfig{ @@ -260,7 +280,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("target", proxyConfig.TargetAddr). Bool("sslEnabled", credentials.SSLEnabled). Msg("Starting PostgreSQL PAM proxy") - return proxy.HandleConnection(ctx, conn) + return proxy.HandleConnection(ctx, handlerConn) case session.ResourceTypeMysql: mysqlConfig := mysql.MysqlProxyConfig{ TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), @@ -279,7 +299,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("target", mysqlConfig.TargetAddr). Bool("sslEnabled", credentials.SSLEnabled). Msg("Starting MySQL PAM proxy") - return proxy.HandleConnection(ctx, conn) + return proxy.HandleConnection(ctx, handlerConn) case session.ResourceTypeMssql: mssqlConfig := mssql.MssqlProxyConfig{ TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), @@ -298,7 +318,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("target", mssqlConfig.TargetAddr). Bool("sslEnabled", credentials.SSLEnabled). Msg("Starting MSSQL PAM proxy") - return proxy.HandleConnection(ctx, conn) + return proxy.HandleConnection(ctx, handlerConn) case session.ResourceTypeRedis: redisConfig := redis.RedisProxyConfig{ TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), @@ -316,7 +336,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("target", redisConfig.TargetAddr). Bool("sslEnabled", credentials.SSLEnabled). Msg("Starting Redis PAM proxy") - return proxy.HandleConnection(ctx, conn) + return proxy.HandleConnection(ctx, handlerConn) case session.ResourceTypeSSH: // Compile command blocking patterns from policy rules var blockedCommandPatterns []*regexp.Regexp @@ -334,6 +354,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo SessionID: pamConfig.SessionId, SessionLogger: sessionLogger, BlockedCommandPatterns: blockedCommandPatterns, + OnActivity: pamConfig.OnActivity, } proxy := ssh.NewSSHProxy(sshConfig) log.Info(). @@ -387,7 +408,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("target", kubernetesConfig.TargetApiServer). Str("authMethod", credentials.AuthMethod). Msg("Starting Kubernetes PAM proxy") - return proxy.HandleConnection(ctx, conn) + return proxy.HandleConnection(ctx, handlerConn) case session.ResourceTypeMongodb: mongoConfig := mongodb.MongoDBProxyConfig{ Host: credentials.ConnectionString, @@ -412,7 +433,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo return fmt.Errorf("MongoDB proxy init: %w", err) } - return proxy.HandleConnection(ctx, conn, sessionLogger) + return proxy.HandleConnection(ctx, handlerConn, sessionLogger) case session.ResourceTypeWindows: if credentials.Port <= 0 || credentials.Port > 65535 { return fmt.Errorf("rdp: target port %d out of range", credentials.Port) @@ -431,7 +452,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("sessionId", pamConfig.SessionId). Str("target", fmt.Sprintf("%s:%d", credentials.Host, credentials.Port)). Msg("Starting RDP PAM proxy") - return proxy.HandleConnection(ctx, conn) + return proxy.HandleConnection(ctx, handlerConn) default: return fmt.Errorf("unsupported resource type: %s", pamConfig.ResourceType) } diff --git a/packages/util/ssh.go b/packages/util/ssh.go new file mode 100644 index 00000000..d207a75d --- /dev/null +++ b/packages/util/ssh.go @@ -0,0 +1,23 @@ +package util + +import ( + "fmt" + "time" + + "golang.org/x/crypto/ssh" +) + +// SSHKeepalive sends an SSH keepalive request and waits up to timeout for a response +func SSHKeepalive(conn ssh.Conn, timeout time.Duration) error { + errCh := make(chan error, 1) + go func() { + _, _, err := conn.SendRequest("keepalive@openssh.com", true, nil) + errCh <- err + }() + select { + case err := <-errCh: + return err + case <-time.After(timeout): + return fmt.Errorf("no keepalive response within %v", timeout) + } +}