Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 99 additions & 8 deletions packages/gateway-v2/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/Infisical/infisical-merge/packages/api"
Expand Down Expand Up @@ -86,8 +87,9 @@ type GatewayConfig struct {
}

type pamSessionEntry struct {
cancel context.CancelFunc
conn *tls.Conn
cancel context.CancelFunc
conn *tls.Conn
lastActivity atomic.Int64
}

type Gateway struct {
Expand Down Expand Up @@ -166,20 +168,33 @@ func NewGateway(config *GatewayConfig) (*Gateway, error) {
}

// RegisterPAMSession registers an active PAM proxy connection for cancellation support
func (g *Gateway) RegisterPAMSession(sessionID string, cancel context.CancelFunc, conn *tls.Conn) {
// Returns a function that handlers should call when data flows through the connection
func (g *Gateway) RegisterPAMSession(sessionID string, cancel context.CancelFunc, conn *tls.Conn) func() {
entry := &pamSessionEntry{cancel: cancel, conn: conn}
entry.lastActivity.Store(time.Now().Unix())

g.pamSessionsMu.Lock()
defer g.pamSessionsMu.Unlock()
g.pamSessions[sessionID] = append(g.pamSessions[sessionID], &pamSessionEntry{cancel: cancel, conn: conn})
g.pamSessions[sessionID] = append(g.pamSessions[sessionID], entry)

return func() {
entry.lastActivity.Store(time.Now().Unix())
}
}

// DeregisterPAMSession removes a specific connection from the session registry.
// Returns true if this was the last connection for the session.
// The MongoDB proxy (if any) is NOT closed here — it persists across connections
// so that subsequent client connections (e.g. mongosh retries) find a warm topology.
// The proxy is cleaned up on session cancellation or gateway shutdown.
func (g *Gateway) DeregisterPAMSession(sessionID string, conn *tls.Conn) {
func (g *Gateway) DeregisterPAMSession(sessionID string, conn *tls.Conn) bool {
g.pamSessionsMu.Lock()
defer g.pamSessionsMu.Unlock()
entries := g.pamSessions[sessionID]

entries, exists := g.pamSessions[sessionID]
if !exists {
return false
}
for i, e := range entries {
if e.conn == conn {
g.pamSessions[sessionID] = append(entries[:i], entries[i+1:]...)
Expand All @@ -188,7 +203,9 @@ func (g *Gateway) DeregisterPAMSession(sessionID string, conn *tls.Conn) {
}
if len(g.pamSessions[sessionID]) == 0 {
delete(g.pamSessions, sessionID)
return true
}
return false
}

// CancelPAMSession kills all active connections for a PAM session
Expand Down Expand Up @@ -264,6 +281,51 @@ func (g *Gateway) closeMongoProxy(sessionID string) {
}
}

const pamIdleTimeout = 30 * time.Minute

// startIdleReaper periodically scans the PAM session registry and cancels
// sessions whose connections have had no data flow for pamIdleTimeout
func (g *Gateway) startIdleReaper(ctx context.Context) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
g.reapIdleSessions()
}
}
}

func (g *Gateway) reapIdleSessions() {
cutoff := time.Now().Add(-pamIdleTimeout).Unix()

g.pamSessionsMu.Lock()
var stale []string
for sessionID, entries := range g.pamSessions {
allIdle := true
for _, e := range entries {
if e.lastActivity.Load() > cutoff {
allIdle = false
break
}
}
if allIdle {
stale = append(stale, sessionID)
}
}
g.pamSessionsMu.Unlock()

for _, sessionID := range stale {
log.Info().Str("sessionId", sessionID).Dur("idleTimeout", pamIdleTimeout).Msg("Reaping idle PAM session")
g.CancelPAMSession(sessionID)
if err := g.pamSessionUploader.CleanupPAMSession(sessionID, "idle_timeout"); err != nil {
log.Error().Err(err).Str("sessionId", sessionID).Msg("Failed to cleanup reaped PAM session")
}
}
}

func (g *Gateway) registerHeartBeat(ctx context.Context, errCh chan error) {
sendHeartbeat := func() error {
if err := api.CallGatewayHeartBeatV2(g.httpClient); err != nil {
Expand Down Expand Up @@ -329,6 +391,8 @@ func (g *Gateway) Start(ctx context.Context) error {
// Start session uploader goroutine for PAM
g.pamSessionUploader.Start()

go g.startIdleReaper(ctx)

go func() {
for {
select {
Expand Down Expand Up @@ -489,6 +553,25 @@ func (g *Gateway) handleConnection(client *ssh.Client) error {
client.Close()
}()

// Keepalive on the relay SSH connection. If the relay drops silently,
// this closes the client so the reconnect loop in connectWithRetry kicks in
go func() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := util.SSHKeepalive(client, 15*time.Second); err != nil {
log.Warn().Err(err).Msg("Relay SSH keepalive failed, closing connection")
client.Close()
return
}
case <-g.ctx.Done():
return
}
}
}()

// Process incoming channels with context cancellation support
for {
select {
Expand Down Expand Up @@ -751,15 +834,23 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) {
return
} else if forwardConfig.Mode == ForwardModePAM {
sessionCtx, sessionCancel := context.WithCancel(g.ctx)
g.RegisterPAMSession(forwardConfig.PAMConfig.SessionId, sessionCancel, tlsConn)
defer g.DeregisterPAMSession(forwardConfig.PAMConfig.SessionId, tlsConn)
touchSession := g.RegisterPAMSession(forwardConfig.PAMConfig.SessionId, sessionCancel, tlsConn)
forwardConfig.PAMConfig.OnActivity = touchSession
if err := pam.HandlePAMProxy(sessionCtx, tlsConn, &forwardConfig.PAMConfig, g.httpClient); err != nil {
if err.Error() == "unexpected EOF" {
log.Debug().Err(err).Msg("PAM proxy handler ended with unexpected connection termination")
} else {
log.Error().Err(err).Msg("PAM proxy handler ended with error")
}
}
sessionCancel()
if lastConn := g.DeregisterPAMSession(forwardConfig.PAMConfig.SessionId, tlsConn); lastConn {
if err := forwardConfig.PAMConfig.SessionUploader.CleanupPAMSession(
forwardConfig.PAMConfig.SessionId, "connection_closed",
); err != nil {
log.Error().Err(err).Str("sessionId", forwardConfig.PAMConfig.SessionId).Msg("Failed to cleanup PAM session")
}
}
return
} else if forwardConfig.Mode == ForwardModePAMCancellation {
if err := pam.HandlePAMCancellation(g.ctx, tlsConn, &forwardConfig.PAMConfig, g.httpClient, g.CancelPAMSession); err != nil {
Expand Down
33 changes: 33 additions & 0 deletions packages/pam/handlers/ssh/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/Infisical/infisical-merge/packages/pam/session"
"github.com/Infisical/infisical-merge/packages/util"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/ssh"
)
Expand All @@ -26,6 +27,7 @@ type SSHProxyConfig struct {
SessionID string
SessionLogger session.SessionLogger
BlockedCommandPatterns []*regexp.Regexp // Regex patterns for command blocking (nil = no blocking)
OnActivity func() // Called when channel data flows
}

// SSHProxy handles proxying SSH connections with credential injection
Expand Down Expand Up @@ -123,6 +125,29 @@ func (p *SSHProxy) HandleConnection(ctx context.Context, clientConn net.Conn) er
// Discard global requests (not needed for basic remote access)
go ssh.DiscardRequests(clientRequests)

// SSH keepalive: detect dead connections where TCP goes silent. Probes both sides every 30s
go func() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := util.SSHKeepalive(clientSSHConn, 15*time.Second); err != nil {
log.Info().Err(err).Str("sessionID", sessionID).Msg("SSH keepalive to client failed, tearing down connection")
clientConn.Close()
return
}
if err := util.SSHKeepalive(serverSSHConn, 15*time.Second); err != nil {
log.Info().Err(err).Str("sessionID", sessionID).Msg("SSH keepalive to target failed, tearing down connection")
clientConn.Close()
return
}
case <-ctx.Done():
return
}
}
}()

// Handle channels from client (this is where actual SSH sessions happen)
for newChannel := range clientChannels {
go p.handleChannel(ctx, newChannel, serverSSHConn, sessionID)
Expand Down Expand Up @@ -500,6 +525,10 @@ func (p *SSHProxy) proxyData(src io.Reader, dst io.Writer, direction string, ses
for {
n, err := src.Read(buf)
if n > 0 {
if p.config.OnActivity != nil {
p.config.OnActivity()
}

// Check if this channel is a binary session (SFTP/SCP)
chState.mutex.Lock()
isBinary := chState.isBinarySession
Comment thread
x032205 marked this conversation as resolved.
Expand Down Expand Up @@ -744,6 +773,10 @@ func (p *SSHProxy) proxyClientToServerWithBlocking(src io.Reader, dst io.Writer,
for {
n, err := src.Read(buf)
if n > 0 {
if p.config.OnActivity != nil {
p.config.OnActivity()
}

chState.mutex.Lock()
isBinary := chState.isBinarySession
sftpParser := chState.sftpParser
Expand Down
Loading
Loading