diff --git a/.gitignore b/.gitignore index 2891130d..c654c50c 100644 --- a/.gitignore +++ b/.gitignore @@ -14,5 +14,5 @@ infisical /agent-testing .vscode/ -# PAM CLI session artifacts (local testing only) +.idea/ /session/ diff --git a/packages/pam/handlers/oracle/ATTRIBUTION.md b/packages/pam/handlers/oracle/ATTRIBUTION.md new file mode 100644 index 00000000..8a3a8c60 --- /dev/null +++ b/packages/pam/handlers/oracle/ATTRIBUTION.md @@ -0,0 +1,23 @@ +This package contains code adapted from [sijms/go-ora](https://github.com/sijms/go-ora). + +MIT License + +Copyright (c) 2020 Samy Sultan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/packages/pam/handlers/oracle/constants.go b/packages/pam/handlers/oracle/constants.go new file mode 100644 index 00000000..deeb4c41 --- /dev/null +++ b/packages/pam/handlers/oracle/constants.go @@ -0,0 +1,4 @@ +package oracle + +// Decoy — can be any string. The gateway replaces it with the real credential before Oracle sees it. +const ProxyPasswordPlaceholder = "password" diff --git a/packages/pam/handlers/oracle/o5logon.go b/packages/pam/handlers/oracle/o5logon.go new file mode 100644 index 00000000..ca66ea51 --- /dev/null +++ b/packages/pam/handlers/oracle/o5logon.go @@ -0,0 +1,108 @@ +package oracle + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "crypto/sha512" + "encoding/hex" + "fmt" +) + +const ( + ORA1017InvalidCredentials = 1017 +) + +func PKCS5Padding(cipherText []byte, blockSize int) []byte { + padding := blockSize - len(cipherText)%blockSize + padtext := bytes.Repeat([]byte{byte(padding)}, padding) + return append(cipherText, padtext...) +} + +func generateSpeedyKey(buffer, key []byte, turns int) []byte { + mac := hmac.New(sha512.New, key) + mac.Write(append(buffer, 0, 0, 0, 1)) + firstHash := mac.Sum(nil) + tempHash := make([]byte, len(firstHash)) + copy(tempHash, firstHash) + for index1 := 2; index1 <= turns; index1++ { + mac.Reset() + mac.Write(tempHash) + tempHash = mac.Sum(nil) + for index2 := 0; index2 < 64; index2++ { + firstHash[index2] = firstHash[index2] ^ tempHash[index2] + } + } + return firstHash +} + +func decryptSessionKey(padding bool, encKey []byte, sessionKeyHex string) ([]byte, error) { + result, err := hex.DecodeString(sessionKeyHex) + if err != nil { + return nil, err + } + blk, err := aes.NewCipher(encKey) + if err != nil { + return nil, err + } + dec := cipher.NewCBCDecrypter(blk, make([]byte, 16)) + output := make([]byte, len(result)) + dec.CryptBlocks(output, result) + cutLen := 0 + if padding { + num := int(output[len(output)-1]) + if num < dec.BlockSize() { + apply := true + for x := len(output) - num; x < len(output); x++ { + if output[x] != uint8(num) { + apply = false + break + } + } + if apply { + cutLen = int(output[len(output)-1]) + } + } + } + return output[:len(output)-cutLen], nil +} + +func encryptSessionKey(padding bool, encKey []byte, sessionKey []byte) (string, error) { + blk, err := aes.NewCipher(encKey) + if err != nil { + return "", err + } + enc := cipher.NewCBCEncrypter(blk, make([]byte, 16)) + originalLen := len(sessionKey) + sessionKey = PKCS5Padding(sessionKey, blk.BlockSize()) + output := make([]byte, len(sessionKey)) + enc.CryptBlocks(output, sessionKey) + if !padding { + return fmt.Sprintf("%X", output[:originalLen]), nil + } + return fmt.Sprintf("%X", output), nil +} + +func encryptPassword(password, key []byte, padding bool) (string, error) { + buff1 := make([]byte, 0x10) + if _, err := rand.Read(buff1); err != nil { + return "", err + } + buffer := append(buff1, password...) + return encryptSessionKey(padding, key, buffer) +} + +func deriveServerKey(password string, salt []byte, vGenCount int) (key []byte, speedy []byte, err error) { + message := append([]byte(nil), salt...) + message = append(message, []byte("AUTH_PBKDF2_SPEEDY_KEY")...) + speedy = generateSpeedyKey(message, []byte(password), vGenCount) + + buffer := append([]byte(nil), speedy...) + buffer = append(buffer, salt...) + h := sha512.New() + h.Write(buffer) + key = h.Sum(nil)[:32] + return +} diff --git a/packages/pam/handlers/oracle/o5logon_server.go b/packages/pam/handlers/oracle/o5logon_server.go new file mode 100644 index 00000000..7df71436 --- /dev/null +++ b/packages/pam/handlers/oracle/o5logon_server.go @@ -0,0 +1,163 @@ +package oracle + +import ( + "fmt" + "net" +) + +const ( + TTCMsgAuthRequest = 0x03 + TTCMsgError = 0x04 +) + +const ( + AuthSubOpPhaseOne = 0x76 + AuthSubOpPhaseTwo = 0x73 +) + +type AuthPhaseTwo struct { + EClientSessKey string + EPassword string +} + +func readDataPayload(conn net.Conn, use32BitLen bool) ([]byte, error) { + raw, err := ReadFullPacket(conn, use32BitLen) + if err != nil { + return nil, err + } + if PacketTypeOf(raw) == PacketTypeMarker { + return readDataPayload(conn, use32BitLen) + } + if PacketTypeOf(raw) != PacketTypeData { + return nil, fmt.Errorf("expected DATA packet, got type=%d", raw[4]) + } + pkt, err := ParseDataPacket(raw, use32BitLen) + if err != nil { + return nil, err + } + return pkt.Payload, nil +} + +func writeDataPayload(conn net.Conn, payload []byte, use32BitLen bool) error { + d := &DataPacket{Payload: payload} + _, err := conn.Write(d.Bytes(use32BitLen)) + return err +} + +func ParseAuthPhaseTwo(payload []byte) (*AuthPhaseTwo, error) { + r := NewTTCReader(payload) + op, err := r.GetByte() + if err != nil { + return nil, err + } + if op != TTCMsgAuthRequest { + return nil, fmt.Errorf("phase2 unexpected opcode 0x%02X", op) + } + sub, err := r.GetByte() + if err != nil { + return nil, err + } + if sub != AuthSubOpPhaseTwo { + return nil, fmt.Errorf("phase2 unexpected sub-op 0x%02X", sub) + } + if _, err := r.GetByte(); err != nil { + return nil, err + } + + out := &AuthPhaseTwo{} + + hasUser, err := r.GetByte() + if err != nil { + return nil, err + } + var userLen int + if hasUser == 1 { + userLen, err = r.GetInt(4, true, true) + if err != nil { + return nil, err + } + } else { + if _, err := r.GetByte(); err != nil { + return nil, err + } + } + + if _, err := r.GetInt(4, true, true); err != nil { + return nil, err + } + if _, err := r.GetByte(); err != nil { + return nil, err + } + count, err := r.GetInt(4, true, true) + if err != nil { + return nil, err + } + if _, err := r.GetByte(); err != nil { + return nil, err + } + if _, err := r.GetByte(); err != nil { + return nil, err + } + if hasUser == 1 && userLen > 0 { + // go-ora prefixes username with CLR length byte; JDBC thin sends it raw. + peek, perr := r.PeekByte() + if perr != nil { + return nil, fmt.Errorf("peek phase2 username: %w", perr) + } + if int(peek) == userLen && peek < 0x20 { + if _, err := r.GetByte(); err != nil { + return nil, fmt.Errorf("consume phase2 username length prefix: %w", err) + } + } + if _, err := r.GetBytes(userLen); err != nil { + return nil, fmt.Errorf("read phase2 username bytes: %w", err) + } + } + + for i := 0; i < count; i++ { + k, v, _, err := r.GetKeyVal() + if err != nil { + return nil, fmt.Errorf("phase2 KVP #%d: %w", i, err) + } + switch string(k) { + case "AUTH_SESSKEY": + out.EClientSessKey = string(v) + case "AUTH_PASSWORD": + out.EPassword = string(v) + } + } + return out, nil +} + +func BuildErrorPacket(oraCode int, message string) []byte { + b := NewTTCBuilder() + b.PutBytes(TTCMsgError) + b.PutInt(0, 4, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 4, true, true) + b.PutInt(int64(oraCode), 4, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 1, true, true) + b.PutInt(0, 1, true, true) + b.PutInt(0, 1, true, true) + b.PutInt(0, 1, true, true) + b.PutInt(0, 1, true, true) + b.PutInt(0, 4, true, true) + b.PutInt(0, 4, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 4, true, true) + b.PutInt(0, 4, true, true) + b.PutInt(0, 4, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 2, true, true) + b.PutString(message) + b.PutInt(0, 2, true, true) + return b.Bytes() +} + +func WriteErrorToClient(conn net.Conn, oraCode int, message string, use32BitLen bool) error { + return writeDataPayload(conn, BuildErrorPacket(oraCode, message), use32BitLen) +} diff --git a/packages/pam/handlers/oracle/proxy.go b/packages/pam/handlers/oracle/proxy.go new file mode 100644 index 00000000..71916b6a --- /dev/null +++ b/packages/pam/handlers/oracle/proxy.go @@ -0,0 +1,64 @@ +package oracle + +import ( + "context" + "crypto/tls" + "fmt" + "net" + + "github.com/Infisical/infisical-merge/packages/pam/session" +) + +type OracleProxyConfig struct { + TargetAddr string + InjectUsername string + InjectPassword string + InjectDatabase string + EnableTLS bool + TLSConfig *tls.Config + SessionID string + SessionLogger session.SessionLogger +} + +type OracleProxy struct { + config OracleProxyConfig +} + +func NewOracleProxy(config OracleProxyConfig) *OracleProxy { + return &OracleProxy{config: config} +} + +func (p *OracleProxy) HandleConnection(ctx context.Context, clientConn net.Conn) error { + return p.handleConnectionProxied(ctx, clientConn) +} + +func relayWithTap(src, dst net.Conn, tap *QueryExtractor, errCh chan<- error) { + buf := make([]byte, 32*1024) + for { + n, err := src.Read(buf) + if n > 0 { + if _, werr := dst.Write(buf[:n]); werr != nil { + errCh <- werr + return + } + tap.Feed(buf[:n]) + } + if err != nil { + errCh <- err + return + } + } +} + +func splitHostPort(addr string) (string, int, error) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return "", 0, err + } + var port int + _, err = fmt.Sscanf(portStr, "%d", &port) + if err != nil { + return "", 0, fmt.Errorf("bad port %q: %w", portStr, err) + } + return host, port, nil +} diff --git a/packages/pam/handlers/oracle/proxy_auth.go b/packages/pam/handlers/oracle/proxy_auth.go new file mode 100644 index 00000000..325f32b1 --- /dev/null +++ b/packages/pam/handlers/oracle/proxy_auth.go @@ -0,0 +1,800 @@ +package oracle + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/binary" + "encoding/hex" + "fmt" + "io" + "net" + "strconv" + "time" + + "github.com/rs/zerolog/log" +) + +func (p *OracleProxy) handleConnectionProxied(ctx context.Context, clientConn net.Conn) error { + defer clientConn.Close() + defer func() { + if err := p.config.SessionLogger.Close(); err != nil { + log.Error().Err(err).Str("sessionID", p.config.SessionID).Msg("Failed to close session logger") + } + }() + + log.Info().Str("sessionID", p.config.SessionID).Str("target", p.config.TargetAddr).Msg("Oracle PAM session started (proxied auth)") + + // Keep raw TCP ref — Oracle TCPS may require a second TLS handshake mid-flow. + rawUpstream, tlsUpstream, err := dialUpstreamRaw(ctx, p.config) + if err != nil { + log.Error().Err(err).Str("sessionID", p.config.SessionID).Msg("Failed to dial Oracle upstream") + _ = WriteRefuseToClient(clientConn, "(DESCRIPTION=(ERR=12564)(VSNNUM=0)(ERROR_STACK=(ERROR=(CODE=12564)(EMFI=4))))") + return fmt.Errorf("upstream dial: %w", err) + } + var upstreamConn net.Conn + if tlsUpstream != nil { + upstreamConn = tlsUpstream + } else { + upstreamConn = rawUpstream + } + defer func() { upstreamConn.Close() }() + + connectRaw, err := ReadFullPacket(clientConn, false) + if err != nil { + return fmt.Errorf("read client CONNECT: %w", err) + } + if PacketTypeOf(connectRaw) != PacketTypeConnect { + return fmt.Errorf("expected CONNECT, got type=%d", connectRaw[4]) + } + if p.config.InjectDatabase == "" { + return fmt.Errorf("InjectDatabase (service name) is required but empty") + } + connectRaw = rewriteConnectServiceName(connectRaw, p.config.InjectDatabase) + if _, err := upstreamConn.Write(connectRaw); err != nil { + return fmt.Errorf("forward CONNECT: %w", err) + } + + // go-ora sends connect-data as a separate 16-bit-framed packet when > 230 bytes. + connectDataInline := true + if len(connectRaw) >= 28 { + cdLen := int(binary.BigEndian.Uint16(connectRaw[24:26])) + cdOff := int(binary.BigEndian.Uint16(connectRaw[26:28])) + if cdLen > 0 && cdOff+cdLen > len(connectRaw) { + connectDataInline = false + } + } + + var acceptRaw []byte + resendConsumedSupplement := false + for attempt := 0; acceptRaw == nil; attempt++ { + pkt, err := ReadFullPacket(upstreamConn, false) + if err != nil { + return fmt.Errorf("read upstream handshake packet (attempt %d): %w", attempt, err) + } + pktType := PacketTypeOf(pkt) + var origFlag byte + if len(pkt) > 5 { + origFlag = pkt[5] + } + log.Info().Str("sessionID", p.config.SessionID).Uint8("pktType", uint8(pktType)).Int("pktLen", len(pkt)).Uint8("flag", origFlag).Msg("Proxy: upstream handshake packet") + + // RESEND flag 0x08: tear down current TLS, do a fresh handshake on the raw socket. + if p.config.EnableTLS && pktType == PacketTypeResend && origFlag&0x08 != 0 { + tc, terr := upgradeToTLS(ctx, rawUpstream, p.config) + if terr != nil { + return fmt.Errorf("upstream TLS upgrade after RESEND(flag=0x08): %w", terr) + } + upstreamConn = tc + log.Info().Str("sessionID", p.config.SessionID).Str("tlsVersion", tlsVersionString(tc.ConnectionState().Version)).Str("cipher", tls.CipherSuiteName(tc.ConnectionState().CipherSuite)).Msg("Proxy: upstream TLS re-handshook on RESEND(flag=0x08)") + } + + // Mask byte 5 so thin clients don't try TLS upgrade on plain TCP. + if p.config.EnableTLS && len(pkt) > 5 { + pkt[5] = 0x00 + } + if _, werr := clientConn.Write(pkt); werr != nil { + return fmt.Errorf("forward upstream handshake packet: %w", werr) + } + switch pktType { + case PacketTypeAccept: + acceptRaw = pkt + case PacketTypeRefuse: + return fmt.Errorf("upstream REFUSE during handshake") + case PacketTypeRedirect: + return fmt.Errorf("upstream REDIRECT during handshake (not supported)") + case PacketTypeResend: + resendConsumedSupplement = true + supplement, err := ReadFullPacket(clientConn, false) + if err != nil { + return fmt.Errorf("read client supplement after RESEND: %w", err) + } + log.Info().Str("sessionID", p.config.SessionID).Int("supplementLen", len(supplement)).Uint8("supplType", uint8(PacketTypeOf(supplement))).Msg("Proxy: forwarding client supplement after RESEND") + if _, werr := upstreamConn.Write(supplement); werr != nil { + return fmt.Errorf("forward client supplement: %w", werr) + } + } + } + + var acceptVersion uint16 + if len(acceptRaw) >= 10 { + acceptVersion = binary.BigEndian.Uint16(acceptRaw[8:10]) + } + use32Bit := acceptVersion >= 315 + log.Info().Str("sessionID", p.config.SessionID).Uint16("acceptVersion", acceptVersion).Bool("use32Bit", use32Bit).Msg("Proxy: ACCEPT forwarded") + + if !connectDataInline && !resendConsumedSupplement { + supplement, err := ReadFullPacket(clientConn, false) + if err != nil { + return fmt.Errorf("read connect-data supplement: %w", err) + } + log.Info().Str("sessionID", p.config.SessionID).Int("supplementLen", len(supplement)).Msg("Proxy: forwarding connect-data supplement") + if _, err := upstreamConn.Write(supplement); err != nil { + return fmt.Errorf("forward connect-data supplement: %w", err) + } + } + + p1Payload, err := proxyUntilAuthRequest(clientConn, upstreamConn, use32Bit, p.config.SessionID) + if err != nil { + return fmt.Errorf("pre-auth proxy: %w", err) + } + log.Info().Str("sessionID", p.config.SessionID).Int("p1Len", len(p1Payload)).Msg("Proxy: auth-request boundary reached") + + p1Forward := p1Payload + if p.config.InjectUsername != "" { + rewritten, rerr := rewritePhase1User(p1Payload, p.config.InjectUsername) + if rerr != nil { + return fmt.Errorf("rewrite phase 1 username: %w", rerr) + } + p1Forward = rewritten + } + if err := writeDataPayload(upstreamConn, p1Forward, use32Bit); err != nil { + return fmt.Errorf("forward phase 1 request: %w", err) + } + + p1RespUpstream, err := readDataPayload(upstreamConn, use32Bit) + if err != nil { + return fmt.Errorf("read upstream phase 1 response: %w", err) + } + state, p1RespTranslated, err := translatePhase1Response(p1RespUpstream, p.config.InjectPassword) + if err != nil { + _ = WriteErrorToClient(clientConn, ORA1017InvalidCredentials, "ORA-01017: invalid username/password; logon denied", use32Bit) + return fmt.Errorf("translate phase 1 response: %w", err) + } + if err := writeDataPayload(clientConn, p1RespTranslated, use32Bit); err != nil { + return fmt.Errorf("write translated phase 1 response: %w", err) + } + log.Info().Str("sessionID", p.config.SessionID).Msg("Proxy: phase-1 response translated and forwarded") + + p2ReqClient, err := readDataPayload(clientConn, use32Bit) + if err != nil { + return fmt.Errorf("read client phase 2 request: %w", err) + } + p2ReqTranslated, err := translatePhase2Request(p2ReqClient, state, p.config.InjectPassword) + if err != nil { + _ = WriteErrorToClient(clientConn, ORA1017InvalidCredentials, "ORA-01017: invalid username/password; logon denied", use32Bit) + return fmt.Errorf("translate phase 2 request: %w", err) + } + // Oracle cross-checks phase-2 username against phase-1. + if p.config.InjectUsername != "" { + rewritten, rerr := rewritePhase2User(p2ReqTranslated, p.config.InjectUsername) + if rerr != nil { + return fmt.Errorf("rewrite phase 2 username: %w", rerr) + } + p2ReqTranslated = rewritten + } + if err := writeDataPayload(upstreamConn, p2ReqTranslated, use32Bit); err != nil { + return fmt.Errorf("forward phase 2 request: %w", err) + } + log.Info().Str("sessionID", p.config.SessionID).Msg("Proxy: phase-2 request translated and forwarded") + + // AUTH_SVR_RESPONSE is keyed on session material (not password) — forward unchanged. + p2RespRaw, err := ReadFullPacket(upstreamConn, use32Bit) + if err != nil { + return fmt.Errorf("read upstream phase 2 response: %w", err) + } + if _, err := clientConn.Write(p2RespRaw); err != nil { + return fmt.Errorf("forward phase 2 response: %w", err) + } + log.Info().Str("sessionID", p.config.SessionID).Msg("Proxy: phase-2 response forwarded; client authenticated") + + c2u, u2c := NewQueryExtractorPair(p.config.SessionLogger, p.config.SessionID, use32Bit) + defer c2u.Stop() + defer u2c.Stop() + + errCh := make(chan error, 2) + go relayWithTap(clientConn, upstreamConn, c2u, errCh) + go relayWithTap(upstreamConn, clientConn, u2c, errCh) + + select { + case rerr := <-errCh: + if rerr != nil && rerr != io.EOF { + log.Debug().Err(rerr).Str("sessionID", p.config.SessionID).Msg("Oracle relay ended") + } + case <-ctx.Done(): + log.Info().Str("sessionID", p.config.SessionID).Msg("Oracle session cancelled by context") + } + log.Info().Str("sessionID", p.config.SessionID).Msg("Oracle PAM session ended") + return nil +} + +// Includes legacy RSA-CBC suites needed by Oracle 19c / AWS RDS. +var oracleUpstreamCiphers = []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, +} + +// TLS 1.0–1.2 only: Oracle TCPS has no TLS-1.3 restart mechanism; RDS negotiates down to 1.0. +func buildOracleTLSConfig(base *tls.Config, host string) *tls.Config { + cfg := base.Clone() + if cfg.ServerName == "" { + cfg.ServerName = host + } + cfg.MinVersion = tls.VersionTLS10 + cfg.MaxVersion = tls.VersionTLS12 + cfg.CipherSuites = oracleUpstreamCiphers + return cfg +} + +func dialUpstreamRaw(ctx context.Context, cfg OracleProxyConfig) (rawConn net.Conn, tlsConn *tls.Conn, err error) { + host, _, err := splitHostPort(cfg.TargetAddr) + if err != nil { + return nil, nil, fmt.Errorf("invalid target addr: %w", err) + } + d := &net.Dialer{Timeout: 15 * time.Second} + rawConn, err = d.DialContext(ctx, "tcp", cfg.TargetAddr) + if err != nil { + return nil, nil, err + } + if !cfg.EnableTLS { + return rawConn, nil, nil + } + if cfg.TLSConfig == nil { + rawConn.Close() + return nil, nil, fmt.Errorf("upstream TLS requested but no TLSConfig provided") + } + tlsCfg := buildOracleTLSConfig(cfg.TLSConfig, host) + tc := tls.Client(rawConn, tlsCfg) + if err := tc.HandshakeContext(ctx); err != nil { + rawConn.Close() + return nil, nil, fmt.Errorf("upstream TLS handshake: %w", err) + } + return rawConn, tc, nil +} + +func tlsVersionString(v uint16) string { + switch v { + case tls.VersionTLS10: + return "TLS1.0" + case tls.VersionTLS11: + return "TLS1.1" + case tls.VersionTLS12: + return "TLS1.2" + case tls.VersionTLS13: + return "TLS1.3" + default: + return fmt.Sprintf("0x%04x", v) + } +} + +func upgradeToTLS(ctx context.Context, rawConn net.Conn, cfg OracleProxyConfig) (*tls.Conn, error) { + host, _, err := splitHostPort(cfg.TargetAddr) + if err != nil { + return nil, fmt.Errorf("invalid target addr: %w", err) + } + tlsCfg := buildOracleTLSConfig(cfg.TLSConfig, host) + tc := tls.Client(rawConn, tlsCfg) + if err := tc.HandshakeContext(ctx); err != nil { + return nil, fmt.Errorf("upstream TLS handshake: %w", err) + } + return tc, nil +} + +func proxyUntilAuthRequest(client, upstream net.Conn, use32Bit bool, sessionID string) ([]byte, error) { + type result struct { + payload []byte + err error + } + done := make(chan result, 2) + stop := make(chan struct{}) + + go func() { + for { + select { + case <-stop: + return + default: + } + pkt, err := ReadFullPacket(upstream, use32Bit) + if err != nil { + select { + case done <- result{err: fmt.Errorf("read upstream: %w", err)}: + default: + } + return + } + if _, werr := client.Write(pkt); werr != nil { + select { + case done <- result{err: fmt.Errorf("write client: %w", werr)}: + default: + } + return + } + log.Debug().Str("sessionID", sessionID).Uint8("type", uint8(PacketTypeOf(pkt))).Int("len", len(pkt)).Msg("Proxy pre-auth: upstream → client") + } + }() + + go func() { + for { + select { + case <-stop: + return + default: + } + pkt, err := ReadFullPacket(client, use32Bit) + if err != nil { + select { + case done <- result{err: fmt.Errorf("read client: %w", err)}: + default: + } + return + } + pktType := PacketTypeOf(pkt) + if pktType == PacketTypeData { + payload, perr := extractDataPayload(pkt) + if perr == nil && len(payload) >= 2 && + payload[0] == TTCMsgAuthRequest && payload[1] == AuthSubOpPhaseOne { + select { + case done <- result{payload: payload}: + default: + } + return + } + } + if _, werr := upstream.Write(pkt); werr != nil { + select { + case done <- result{err: fmt.Errorf("write upstream: %w", werr)}: + default: + } + return + } + log.Debug().Str("sessionID", sessionID).Uint8("type", uint8(pktType)).Int("len", len(pkt)).Msg("Proxy pre-auth: client → upstream") + } + }() + + res := <-done + close(stop) + // Unblock the other goroutine so it doesn't steal the phase-1 response. + if uc, ok := upstream.(interface{ SetReadDeadline(time.Time) error }); ok { + _ = uc.SetReadDeadline(time.Now().Add(-1 * time.Second)) + } + time.Sleep(50 * time.Millisecond) + if uc, ok := upstream.(interface{ SetReadDeadline(time.Time) error }); ok { + _ = uc.SetReadDeadline(time.Time{}) + } + if res.err != nil { + return nil, res.err + } + return res.payload, nil +} + +func extractDataPayload(pkt []byte) ([]byte, error) { + const headerLen = 10 + if len(pkt) < headerLen { + return nil, fmt.Errorf("packet too short: %d", len(pkt)) + } + return pkt[headerLen:], nil +} + +func rewriteAuthRequestUser(payload []byte, expectedSubOp byte, newUser string) ([]byte, error) { + r := NewTTCReader(payload) + op, err := r.GetByte() + if err != nil { + return nil, fmt.Errorf("opcode: %w", err) + } + if op != TTCMsgAuthRequest { + return nil, fmt.Errorf("unexpected opcode 0x%02X", op) + } + sub, err := r.GetByte() + if err != nil { + return nil, err + } + if sub != expectedSubOp { + return nil, fmt.Errorf("unexpected sub-op 0x%02X (want 0x%02X)", sub, expectedSubOp) + } + if _, err := r.GetByte(); err != nil { + return nil, err + } + + hasUser, err := r.GetByte() + if err != nil { + return nil, err + } + if hasUser != 1 { + return payload, nil + } + origUserLen, err := r.GetInt(4, true, true) + if err != nil { + return nil, fmt.Errorf("userLen: %w", err) + } + if origUserLen <= 0 { + return payload, nil + } + + middleStart := r.Pos() + + if _, err := r.GetInt(4, true, true); err != nil { + return nil, fmt.Errorf("mode: %w", err) + } + if _, err := r.GetByte(); err != nil { + return nil, fmt.Errorf("marker after mode: %w", err) + } + if _, err := r.GetInt(4, true, true); err != nil { + return nil, fmt.Errorf("count: %w", err) + } + if _, err := r.GetByte(); err != nil { + return nil, fmt.Errorf("marker 1: %w", err) + } + if _, err := r.GetByte(); err != nil { + return nil, fmt.Errorf("marker 2: %w", err) + } + middleEnd := r.Pos() + + // go-ora prefixes user bytes with a CLR-length byte; JDBC thin omits it. + peek, perr := r.PeekByte() + if perr != nil { + return nil, fmt.Errorf("peek user: %w", perr) + } + usedCLRPrefix := int(peek) == origUserLen && peek < 0x20 + if usedCLRPrefix { + if _, err := r.GetByte(); err != nil { + return nil, fmt.Errorf("consume user CLR length: %w", err) + } + } + if _, err := r.GetBytes(origUserLen); err != nil { + return nil, fmt.Errorf("user bytes: %w", err) + } + userEnd := r.Pos() + + newUserBytes := []byte(newUser) + newUserLen := len(newUserBytes) + + out := make([]byte, 0, len(payload)+16) + out = append(out, payload[:3]...) + out = append(out, 0x01) + lb := NewTTCBuilder() + lb.PutInt(int64(newUserLen), 4, true, true) + out = append(out, lb.Bytes()...) + out = append(out, payload[middleStart:middleEnd]...) + if usedCLRPrefix { + out = append(out, byte(newUserLen)) + } + out = append(out, newUserBytes...) + out = append(out, payload[userEnd:]...) + return out, nil +} + +func rewriteConnectServiceName(pkt []byte, newName string) []byte { + marker := []byte("SERVICE_NAME=") + idx := bytes.Index(pkt, marker) + if idx < 0 { + return pkt + } + valStart := idx + len(marker) + valEnd := bytes.IndexByte(pkt[valStart:], ')') + if valEnd < 0 { + return pkt + } + valEnd += valStart + + oldVal := pkt[valStart:valEnd] + newVal := []byte(newName) + if bytes.Equal(oldVal, newVal) { + return pkt + } + + out := make([]byte, 0, len(pkt)+len(newVal)-len(oldVal)) + out = append(out, pkt[:valStart]...) + out = append(out, newVal...) + out = append(out, pkt[valEnd:]...) + + binary.BigEndian.PutUint16(out[0:2], uint16(len(out))) + if len(out) >= 26 { + oldCDLen := binary.BigEndian.Uint16(pkt[24:26]) + binary.BigEndian.PutUint16(out[24:26], uint16(int(oldCDLen)+len(newVal)-len(oldVal))) + } + return out +} + +func rewritePhase1User(payload []byte, newUser string) ([]byte, error) { + return rewriteAuthRequestUser(payload, AuthSubOpPhaseOne, newUser) +} + +func rewritePhase2User(payload []byte, newUser string) ([]byte, error) { + return rewriteAuthRequestUser(payload, AuthSubOpPhaseTwo, newUser) +} + +type ProxyAuthState struct { + Salt []byte + Pbkdf2CSKSalt string + Pbkdf2VGenCount int + Pbkdf2SDerCount int + RealKey []byte + PlaceholderKey []byte + ServerSessKey []byte +} + +func translatePhase1Response(payload []byte, realPassword string) (*ProxyAuthState, []byte, error) { + kvs, trailer, err := parseAuthRespKVPList(payload) + if err != nil { + return nil, nil, fmt.Errorf("parse upstream phase 1: %w", err) + } + + var eSessKey, vfrData, cskSalt, vGenStr, sDerStr string + for _, kv := range kvs { + switch kv.Key { + case "AUTH_SESSKEY": + eSessKey = kv.Value + case "AUTH_VFR_DATA": + vfrData = kv.Value + case "AUTH_PBKDF2_CSK_SALT": + cskSalt = kv.Value + case "AUTH_PBKDF2_VGEN_COUNT": + vGenStr = kv.Value + case "AUTH_PBKDF2_SDER_COUNT": + sDerStr = kv.Value + } + } + if eSessKey == "" || vfrData == "" { + return nil, nil, fmt.Errorf("upstream phase 1 missing AUTH_SESSKEY or AUTH_VFR_DATA") + } + salt, err := hex.DecodeString(vfrData) + if err != nil { + return nil, nil, fmt.Errorf("decode salt: %w", err) + } + vGen, _ := strconv.Atoi(vGenStr) + if vGen == 0 { + vGen = 4096 + } + sDer, _ := strconv.Atoi(sDerStr) + if sDer == 0 { + sDer = 3 + } + + realKey, _, err := deriveServerKey(realPassword, salt, vGen) + if err != nil { + return nil, nil, fmt.Errorf("derive real key: %w", err) + } + placeholderKey, _, err := deriveServerKey(ProxyPasswordPlaceholder, salt, vGen) + if err != nil { + return nil, nil, fmt.Errorf("derive placeholder key: %w", err) + } + + serverSessKey, err := decryptSessionKey(false, realKey, eSessKey) + if err != nil { + return nil, nil, fmt.Errorf("decrypt upstream server session key: %w", err) + } + newESessKey, err := encryptSessionKey(false, placeholderKey, serverSessKey) + if err != nil { + return nil, nil, fmt.Errorf("re-encrypt server session key: %w", err) + } + + for i := range kvs { + if kvs[i].Key == "AUTH_SESSKEY" { + kvs[i].Value = newESessKey + break + } + } + + rebuilt := rebuildAuthRespPayload(kvs, trailer) + + state := &ProxyAuthState{ + Salt: salt, + Pbkdf2CSKSalt: cskSalt, + Pbkdf2VGenCount: vGen, + Pbkdf2SDerCount: sDer, + RealKey: realKey, + PlaceholderKey: placeholderKey, + ServerSessKey: serverSessKey, + } + return state, rebuilt, nil +} + +func translatePhase2Request(payload []byte, state *ProxyAuthState, realPassword string) ([]byte, error) { + p2, err := ParseAuthPhaseTwo(payload) + if err != nil { + return nil, fmt.Errorf("parse client phase 2: %w", err) + } + + if p2.EClientSessKey == "" || p2.EPassword == "" { + return nil, fmt.Errorf("client phase 2 missing AUTH_SESSKEY or AUTH_PASSWORD") + } + + clientSessKey, err := decryptSessionKey(false, state.PlaceholderKey, p2.EClientSessKey) + if err != nil { + return nil, fmt.Errorf("decrypt client session key: %w", err) + } + if len(clientSessKey) != len(state.ServerSessKey) { + return nil, fmt.Errorf("client session key length mismatch: got %d want %d", len(clientSessKey), len(state.ServerSessKey)) + } + newEClientSessKey, err := encryptSessionKey(false, state.RealKey, clientSessKey) + if err != nil { + return nil, fmt.Errorf("re-encrypt client session key: %w", err) + } + + // encKey derives from session keys + CSK salt, not the password. + encKey, err := deriveProxyPasswordEncKey(clientSessKey, state.ServerSessKey, state.Pbkdf2CSKSalt, state.Pbkdf2SDerCount) + if err != nil { + return nil, fmt.Errorf("derive enc key: %w", err) + } + newEPassword, err := encryptPassword([]byte(realPassword), encKey, true) + if err != nil { + return nil, fmt.Errorf("encrypt real password: %w", err) + } + + rebuilt, err := rebuildPhase2Request(payload, newEClientSessKey, newEPassword) + if err != nil { + return nil, fmt.Errorf("rebuild phase 2: %w", err) + } + return rebuilt, nil +} + +func deriveProxyPasswordEncKey(clientSessKey, serverSessKey []byte, pbkdf2CSKSaltHex string, sderCount int) ([]byte, error) { + buffer := append([]byte(nil), clientSessKey...) + buffer = append(buffer, serverSessKey...) + keyBuffer := []byte(fmt.Sprintf("%X", buffer)) + cskSalt, err := hex.DecodeString(pbkdf2CSKSaltHex) + if err != nil { + return nil, fmt.Errorf("decode pbkdf2 salt: %w", err) + } + full := generateSpeedyKey(cskSalt, keyBuffer, sderCount) + if len(full) < 32 { + return nil, fmt.Errorf("speedy key too short: %d", len(full)) + } + return full[:32], nil +} + +type parsedKVP struct { + Key string + Value string + Flag int +} + +func parseAuthRespKVPList(payload []byte) (kvs []parsedKVP, trailer []byte, err error) { + r := NewTTCReader(payload) + op, err := r.GetByte() + if err != nil { + return nil, nil, err + } + if op != 0x08 { + return nil, nil, fmt.Errorf("expected auth response opcode 0x08, got 0x%02X", op) + } + dictLen, err := r.GetInt(4, true, true) + if err != nil { + return nil, nil, fmt.Errorf("dict len: %w", err) + } + for i := 0; i < dictLen; i++ { + keyLen, err := r.GetInt(4, true, true) + if err != nil { + return nil, nil, fmt.Errorf("kvp %d key len: %w", i, err) + } + var keyBytes []byte + if keyLen > 0 { + keyBytes, err = r.GetClr() + if err != nil { + return nil, nil, fmt.Errorf("kvp %d key: %w", i, err) + } + if len(keyBytes) > keyLen { + keyBytes = keyBytes[:keyLen] + } + } + valLen, err := r.GetInt(4, true, true) + if err != nil { + return nil, nil, fmt.Errorf("kvp %d val len: %w", i, err) + } + var valBytes []byte + if valLen > 0 { + valBytes, err = r.GetClr() + if err != nil { + return nil, nil, fmt.Errorf("kvp %d val: %w", i, err) + } + if len(valBytes) > valLen { + valBytes = valBytes[:valLen] + } + } + flag, err := r.GetInt(4, true, true) + if err != nil { + return nil, nil, fmt.Errorf("kvp %d flag: %w", i, err) + } + kvs = append(kvs, parsedKVP{ + Key: string(bytes.TrimRight(keyBytes, "\x00")), + Value: string(valBytes), + Flag: flag, + }) + } + trailer = make([]byte, r.Remaining()) + rem, _ := r.GetBytes(r.Remaining()) + copy(trailer, rem) + return kvs, trailer, nil +} + +func rebuildAuthRespPayload(kvs []parsedKVP, trailer []byte) []byte { + b := NewTTCBuilder() + b.PutBytes(0x08) + b.PutUint(uint64(len(kvs)), 4, true, true) + for _, kv := range kvs { + b.PutKeyValString(kv.Key, kv.Value, uint32(kv.Flag)) + } + b.PutBytes(trailer...) + return b.Bytes() +} + +func rebuildPhase2Request(payload []byte, newESessKey, newEPassword string) ([]byte, error) { + out := make([]byte, 0, len(payload)+128) + out = append(out, payload...) + + out, err := replaceKVPValue(out, "AUTH_SESSKEY", newESessKey) + if err != nil { + return nil, fmt.Errorf("replace AUTH_SESSKEY: %w", err) + } + out, err = replaceKVPValue(out, "AUTH_PASSWORD", newEPassword) + if err != nil { + return nil, fmt.Errorf("replace AUTH_PASSWORD: %w", err) + } + return out, nil +} + +func replaceKVPValue(payload []byte, key, newValue string) ([]byte, error) { + keyBytes := []byte(key) + idx := bytes.Index(payload, keyBytes) + if idx < 0 { + return nil, fmt.Errorf("key %q not found", key) + } + pos := idx + len(keyBytes) + if pos >= len(payload) { + return nil, fmt.Errorf("truncated after key") + } + vSizeByte := payload[pos] + pos++ + var vLen int + if vSizeByte == 0 { + vLen = 0 + } else if int(vSizeByte) <= 8 { + for i := 0; i < int(vSizeByte); i++ { + vLen = (vLen << 8) | int(payload[pos+i]) + } + pos += int(vSizeByte) + } else { + return nil, fmt.Errorf("invalid val_len size byte %d", vSizeByte) + } + if vLen > 0 { + if pos >= len(payload) || int(payload[pos]) != vLen { + return nil, fmt.Errorf("CLR length byte mismatch for %q: got %d want %d", key, payload[pos], vLen) + } + pos++ + valBodyStart := pos + valBodyEnd := valBodyStart + vLen + // PutClr handles chunked 0xFE form for values > 0xFC bytes. + newVal := []byte(newValue) + vb := NewTTCBuilder() + vb.PutUint(uint64(len(newVal)), 4, true, true) + vb.PutClr(newVal) + newValSection := vb.Bytes() + oldStart := idx + len(keyBytes) + oldEnd := valBodyEnd + out := make([]byte, 0, len(payload)+len(newValSection)) + out = append(out, payload[:oldStart]...) + out = append(out, newValSection...) + out = append(out, payload[oldEnd:]...) + return out, nil + } + return payload, fmt.Errorf("unexpected empty value for %q", key) +} diff --git a/packages/pam/handlers/oracle/query_logger.go b/packages/pam/handlers/oracle/query_logger.go new file mode 100644 index 00000000..2860b452 --- /dev/null +++ b/packages/pam/handlers/oracle/query_logger.go @@ -0,0 +1,276 @@ +package oracle + +import ( + "bytes" + "encoding/binary" + "fmt" + "sync" + "time" + + "github.com/Infisical/infisical-merge/packages/pam/session" + "github.com/rs/zerolog/log" +) + +const ( + ttcFuncOALL8 = 0x5E + ttcFuncOCOMMIT = 0x0E + ttcFuncORLLBK = 0x0F + ttcMsgFunction = 0x03 +) + +type pendingQuery struct { + sql string + timestamp time.Time +} + +// Best-effort SQL extraction from the byte stream. +type QueryExtractor struct { + logger session.SessionLogger + sessionID string + direction string + ch chan []byte + stopCh chan struct{} + wg sync.WaitGroup + use32Bit bool + pair *pairState +} + +type pairState struct { + mu sync.Mutex + pending *pendingQuery +} + +func NewQueryExtractorPair(logger session.SessionLogger, sessionID string, use32Bit bool) (clientToUpstream, upstreamToClient *QueryExtractor) { + p := &pairState{} + clientToUpstream = newExtractor(logger, sessionID, "client->upstream", use32Bit, p) + upstreamToClient = newExtractor(logger, sessionID, "upstream->client", use32Bit, p) + return +} + +func newExtractor(logger session.SessionLogger, sessionID, direction string, use32Bit bool, pair *pairState) *QueryExtractor { + e := &QueryExtractor{ + logger: logger, + sessionID: sessionID, + direction: direction, + ch: make(chan []byte, 64), + stopCh: make(chan struct{}), + use32Bit: use32Bit, + pair: pair, + } + e.wg.Add(1) + go e.loop() + return e +} + +func (e *QueryExtractor) Feed(data []byte) { + if len(data) == 0 { + return + } + cp := make([]byte, len(data)) + copy(cp, data) + select { + case e.ch <- cp: + default: + } +} + +func (e *QueryExtractor) Stop() { + close(e.stopCh) + e.wg.Wait() +} + +func (e *QueryExtractor) loop() { + defer e.wg.Done() + var buffer bytes.Buffer + + for { + select { + case <-e.stopCh: + return + case chunk := <-e.ch: + buffer.Write(chunk) + e.drain(&buffer) + } + } +} + +func (e *QueryExtractor) drain(buf *bytes.Buffer) { + for { + if buf.Len() < 8 { + return + } + head := buf.Bytes()[:8] + var length uint32 + if e.use32Bit { + length = binary.BigEndian.Uint32(head) + } else { + length = uint32(binary.BigEndian.Uint16(head)) + } + if length < 8 || length > 16*1024*1024 { + buf.Reset() + return + } + if buf.Len() < int(length) { + return + } + packet := make([]byte, length) + if _, err := buf.Read(packet); err != nil { + return + } + e.handlePacket(packet) + } +} + +func (e *QueryExtractor) handlePacket(raw []byte) { + if PacketTypeOf(raw) != PacketTypeData { + return + } + d, err := ParseDataPacket(raw, e.use32Bit) + if err != nil { + return + } + if len(d.Payload) < 1 { + return + } + switch e.direction { + case "client->upstream": + e.handleClientRequest(d.Payload) + case "upstream->client": + e.handleServerResponse(d.Payload) + } +} + +func (e *QueryExtractor) handleClientRequest(payload []byte) { + // Clients often piggyback an OCLOSE before the new function call; scan for + // the function-call+opcode marker pair instead of parsing from offset 0. + if idx := findBytePair(payload, ttcMsgFunction, ttcFuncOALL8); idx >= 0 { + r := NewTTCReader(payload[idx+2:]) + if sqlText := tryExtractSQL(r); sqlText != "" { + e.pair.mu.Lock() + e.pair.pending = &pendingQuery{sql: sqlText, timestamp: time.Now()} + e.pair.mu.Unlock() + } + return + } + if findBytePair(payload, ttcMsgFunction, ttcFuncOCOMMIT) >= 0 { + e.recordLiteral("COMMIT") + return + } + if findBytePair(payload, ttcMsgFunction, ttcFuncORLLBK) >= 0 { + e.recordLiteral("ROLLBACK") + return + } +} + +func findBytePair(data []byte, b1, b2 byte) int { + for i := 0; i+1 < len(data); i++ { + if data[i] == b1 && data[i+1] == b2 { + return i + } + } + return -1 +} + +func (e *QueryExtractor) recordLiteral(sql string) { + e.pair.mu.Lock() + e.pair.pending = &pendingQuery{sql: sql, timestamp: time.Now()} + e.pair.mu.Unlock() +} + +// tryExtractSQL uses a longest-printable-run heuristic because OALL8 headers +// vary across client drivers and bind patterns. +func tryExtractSQL(r *TTCReader) string { + remaining := r.Remaining() + if remaining <= 0 { + return "" + } + buf, err := r.GetBytes(remaining) + if err != nil { + return "" + } + return longestPrintableRun(buf) +} + +func longestPrintableRun(data []byte) string { + bestStart, bestLen := 0, 0 + curStart, curLen := 0, 0 + for i, b := range data { + printable := b == '\t' || b == '\n' || b == '\r' || (b >= 0x20 && b <= 0x7E) + if printable { + if curLen == 0 { + curStart = i + } + curLen++ + if curLen > bestLen { + bestLen = curLen + bestStart = curStart + } + } else { + curLen = 0 + } + } + if bestLen < 4 { + return "" + } + return string(data[bestStart : bestStart+bestLen]) +} + +func (e *QueryExtractor) handleServerResponse(payload []byte) { + e.pair.mu.Lock() + pending := e.pair.pending + e.pair.pending = nil + e.pair.mu.Unlock() + if pending == nil { + return + } + output := extractResponseOutcome(payload) + err := e.logger.LogEntry(session.SessionLogEntry{ + Timestamp: pending.timestamp, + Input: pending.sql, + Output: output, + }) + if err != nil { + log.Debug().Err(err).Str("sessionID", e.sessionID).Msg("session log entry dropped") + } +} + +func extractResponseOutcome(payload []byte) string { + r := NewTTCReader(payload) + for r.Remaining() > 0 { + op, err := r.GetByte() + if err != nil { + break + } + if op == 0x04 { + for i := 0; i < 3; i++ { + if _, err := r.GetInt(4, true, true); err != nil { + return "OK" + } + } + code, err := r.GetInt(4, true, true) + if err != nil || code == 0 { + return "OK" + } + return ora(code) + } + } + return "" +} + +func ora(code int) string { + switch code { + case 0: + return "OK" + case 1: + return "ERROR: ORA-00001: unique constraint violated" + case 900: + return "ERROR: ORA-00900: invalid SQL statement" + case 942: + return "ERROR: ORA-00942: table or view does not exist" + case 1017: + return "ERROR: ORA-01017: invalid username/password" + case 28000: + return "ERROR: ORA-28000: the account is locked" + } + return fmt.Sprintf("ERROR: ORA-%05d", code) +} diff --git a/packages/pam/handlers/oracle/tns.go b/packages/pam/handlers/oracle/tns.go new file mode 100644 index 00000000..93d01f54 --- /dev/null +++ b/packages/pam/handlers/oracle/tns.go @@ -0,0 +1,115 @@ +package oracle + +import ( + "encoding/binary" + "errors" + "fmt" + "io" +) + +type PacketType uint8 + +const ( + PacketTypeConnect PacketType = 1 + PacketTypeAccept PacketType = 2 + PacketTypeRefuse PacketType = 4 + PacketTypeRedirect PacketType = 5 + PacketTypeData PacketType = 6 + PacketTypeResend PacketType = 11 + PacketTypeMarker PacketType = 12 +) + +// use32BitLen: 32-bit length framing after ACCEPT (version >= 315), 16-bit before. +func ReadFullPacket(r io.Reader, use32BitLen bool) ([]byte, error) { + head := make([]byte, 8) + if _, err := io.ReadFull(r, head); err != nil { + return nil, err + } + var length uint32 + if use32BitLen { + length = binary.BigEndian.Uint32(head) + } else { + length = uint32(binary.BigEndian.Uint16(head)) + } + if length < 8 { + return nil, fmt.Errorf("invalid TNS packet length: %d", length) + } + if length > 1<<22 { + return nil, fmt.Errorf("TNS packet too large: %d", length) + } + buf := make([]byte, length) + copy(buf, head) + if length > 8 { + if _, err := io.ReadFull(r, buf[8:]); err != nil { + return nil, err + } + } + return buf, nil +} + +func PacketTypeOf(packet []byte) PacketType { + if len(packet) < 5 { + return 0 + } + return PacketType(packet[4]) +} + +type DataPacket struct { + DataFlag uint16 + Payload []byte +} + +func ParseDataPacket(raw []byte, use32BitLen bool) (*DataPacket, error) { + if len(raw) < 10 || PacketType(raw[4]) != PacketTypeData { + return nil, errors.New("not a DATA packet") + } + return &DataPacket{ + DataFlag: binary.BigEndian.Uint16(raw[8:]), + Payload: append([]byte(nil), raw[10:]...), + }, nil +} + +func (d *DataPacket) Bytes(use32BitLen bool) []byte { + length := uint32(10 + len(d.Payload)) + out := make([]byte, length) + if use32BitLen { + binary.BigEndian.PutUint32(out, length) + } else { + binary.BigEndian.PutUint16(out, uint16(length)) + } + out[4] = byte(PacketTypeData) + out[5] = 0 + binary.BigEndian.PutUint16(out[8:], d.DataFlag) + copy(out[10:], d.Payload) + return out +} + +type RefusePacket struct { + UserReason uint8 + SystemReason uint8 + Message string +} + +func (r *RefusePacket) Bytes() []byte { + msg := []byte(r.Message) + length := uint32(12 + len(msg)) + out := make([]byte, length) + binary.BigEndian.PutUint16(out, uint16(length)) + out[4] = byte(PacketTypeRefuse) + out[5] = 0 + out[8] = r.UserReason + out[9] = r.SystemReason + binary.BigEndian.PutUint16(out[10:], uint16(len(msg))) + copy(out[12:], msg) + return out +} + +func WriteRefuseToClient(w io.Writer, message string) error { + pkt := &RefusePacket{ + UserReason: 0, + SystemReason: 0, + Message: message, + } + _, err := w.Write(pkt.Bytes()) + return err +} diff --git a/packages/pam/handlers/oracle/ttc.go b/packages/pam/handlers/oracle/ttc.go new file mode 100644 index 00000000..30461e59 --- /dev/null +++ b/packages/pam/handlers/oracle/ttc.go @@ -0,0 +1,302 @@ +package oracle + +import ( + "bytes" + "encoding/binary" + "errors" + "io" +) + +type TTCBuilder struct { + buf bytes.Buffer + useBigClrChunks bool + clrChunkSize int +} + +func NewTTCBuilder() *TTCBuilder { + return &TTCBuilder{useBigClrChunks: true, clrChunkSize: 0x7FFF} +} + +func (b *TTCBuilder) Bytes() []byte { return b.buf.Bytes() } + +func (b *TTCBuilder) PutBytes(data ...byte) { b.buf.Write(data) } + +func (b *TTCBuilder) PutUint(num uint64, size uint8, bigEndian, compress bool) { + if size == 1 { + b.buf.WriteByte(uint8(num)) + return + } + if compress { + temp := make([]byte, 8) + binary.BigEndian.PutUint64(temp, num) + temp = bytes.TrimLeft(temp, "\x00") + if size > uint8(len(temp)) { + size = uint8(len(temp)) + } + if size == 0 { + b.buf.WriteByte(0) + return + } + b.buf.WriteByte(size) + b.buf.Write(temp) + return + } + temp := make([]byte, size) + if bigEndian { + switch size { + case 2: + binary.BigEndian.PutUint16(temp, uint16(num)) + case 4: + binary.BigEndian.PutUint32(temp, uint32(num)) + case 8: + binary.BigEndian.PutUint64(temp, num) + } + } else { + switch size { + case 2: + binary.LittleEndian.PutUint16(temp, uint16(num)) + case 4: + binary.LittleEndian.PutUint32(temp, uint32(num)) + case 8: + binary.LittleEndian.PutUint64(temp, num) + } + } + b.buf.Write(temp) +} + +func (b *TTCBuilder) PutInt(num int64, size uint8, bigEndian, compress bool) { + if compress { + temp := make([]byte, 8) + binary.BigEndian.PutUint64(temp, uint64(num)) + temp = bytes.TrimLeft(temp, "\x00") + if size > uint8(len(temp)) { + size = uint8(len(temp)) + } + if size == 0 { + b.buf.WriteByte(0) + return + } + b.buf.WriteByte(size) + b.buf.Write(temp[:size]) + return + } + b.PutUint(uint64(num), size, bigEndian, false) +} + +func (b *TTCBuilder) PutClr(data []byte) { + dataLen := len(data) + if dataLen == 0 { + b.buf.WriteByte(0) + return + } + if dataLen > 0xFC { + b.buf.WriteByte(0xFE) + start := 0 + for start < dataLen { + end := start + b.clrChunkSize + if end > dataLen { + end = dataLen + } + chunk := data[start:end] + if b.useBigClrChunks { + b.PutInt(int64(len(chunk)), 4, true, true) + } else { + b.buf.WriteByte(uint8(len(chunk))) + } + b.buf.Write(chunk) + start += b.clrChunkSize + } + b.buf.WriteByte(0) + return + } + b.buf.WriteByte(uint8(dataLen)) + b.buf.Write(data) +} + +func (b *TTCBuilder) PutString(s string) { b.PutClr([]byte(s)) } + +func (b *TTCBuilder) PutKeyVal(key, val []byte, num uint32) { + if len(key) == 0 { + b.buf.WriteByte(0) + } else { + b.PutUint(uint64(len(key)), 4, true, true) + b.PutClr(key) + } + if len(val) == 0 { + b.buf.WriteByte(0) + } else { + b.PutUint(uint64(len(val)), 4, true, true) + b.PutClr(val) + } + b.PutInt(int64(num), 4, true, true) +} + +func (b *TTCBuilder) PutKeyValString(key, val string, num uint32) { + b.PutKeyVal([]byte(key), []byte(val), num) +} + +type TTCReader struct { + buf []byte + pos int + useBigClrChunks bool +} + +func NewTTCReader(payload []byte) *TTCReader { + return &TTCReader{buf: payload, useBigClrChunks: true} +} + +func (r *TTCReader) Remaining() int { return len(r.buf) - r.pos } + +func (r *TTCReader) Pos() int { return r.pos } + +func (r *TTCReader) read(n int) ([]byte, error) { + if r.pos+n > len(r.buf) { + return nil, io.ErrUnexpectedEOF + } + out := r.buf[r.pos : r.pos+n] + r.pos += n + return out, nil +} + +func (r *TTCReader) GetByte() (uint8, error) { + b, err := r.read(1) + if err != nil { + return 0, err + } + return b[0], nil +} + +func (r *TTCReader) PeekByte() (uint8, error) { + if r.pos >= len(r.buf) { + return 0, io.ErrUnexpectedEOF + } + return r.buf[r.pos], nil +} + +func (r *TTCReader) GetBytes(n int) ([]byte, error) { + b, err := r.read(n) + if err != nil { + return nil, err + } + out := make([]byte, len(b)) + copy(out, b) + return out, nil +} + +func (r *TTCReader) GetInt64(size int, compress, bigEndian bool) (int64, error) { + negFlag := false + if compress { + sb, err := r.read(1) + if err != nil { + return 0, err + } + size = int(sb[0]) + if size&0x80 > 0 { + negFlag = true + size = size & 0x7F + } + bigEndian = true + } + if size == 0 { + return 0, nil + } + if size > 8 { + return 0, errors.New("invalid size for GetInt64") + } + rb, err := r.read(size) + if err != nil { + return 0, err + } + temp := make([]byte, 8) + var v int64 + if bigEndian { + copy(temp[8-size:], rb) + v = int64(binary.BigEndian.Uint64(temp)) + } else { + copy(temp[:size], rb) + v = int64(binary.LittleEndian.Uint64(temp)) + } + if negFlag { + v = -v + } + return v, nil +} + +func (r *TTCReader) GetInt(size int, compress, bigEndian bool) (int, error) { + v, err := r.GetInt64(size, compress, bigEndian) + return int(v), err +} + +func (r *TTCReader) GetClr() ([]byte, error) { + nb, err := r.GetByte() + if err != nil { + return nil, err + } + if nb == 0 || nb == 0xFF || nb == 0xFD { + return nil, nil + } + if nb != 0xFE { + out, err := r.read(int(nb)) + if err != nil { + return nil, err + } + ret := make([]byte, len(out)) + copy(ret, out) + return ret, nil + } + var buf bytes.Buffer + for { + var chunkSize int + if r.useBigClrChunks { + chunkSize, err = r.GetInt(4, true, true) + } else { + b, err2 := r.GetByte() + err = err2 + chunkSize = int(b) + } + if err != nil { + return nil, err + } + if chunkSize == 0 { + break + } + chunk, err := r.read(chunkSize) + if err != nil { + return nil, err + } + buf.Write(chunk) + } + return buf.Bytes(), nil +} + +func (r *TTCReader) GetDlc() ([]byte, error) { + length, err := r.GetInt(4, true, true) + if err != nil { + return nil, err + } + if length <= 0 { + _, _ = r.GetClr() + return nil, nil + } + out, err := r.GetClr() + if err != nil { + return nil, err + } + if len(out) > length { + out = out[:length] + } + return out, nil +} + +func (r *TTCReader) GetKeyVal() (key, val []byte, num int, err error) { + key, err = r.GetDlc() + if err != nil { + return + } + val, err = r.GetDlc() + if err != nil { + return + } + num, err = r.GetInt(4, true, true) + return +} diff --git a/packages/pam/local/database-proxy.go b/packages/pam/local/database-proxy.go index 9f51f4e3..f7e607db 100644 --- a/packages/pam/local/database-proxy.go +++ b/packages/pam/local/database-proxy.go @@ -10,6 +10,7 @@ import ( "syscall" "time" + "github.com/Infisical/infisical-merge/packages/pam/handlers/oracle" "github.com/Infisical/infisical-merge/packages/pam/session" "github.com/Infisical/infisical-merge/packages/util" "github.com/go-resty/resty/v2" @@ -125,6 +126,9 @@ func StartDatabaseLocalProxy(accessToken string, accessParams PAMAccessParams, p util.PrintfStderr("sqlserver://%s@localhost:%d?database=%s&encrypt=false&trustServerCertificate=true", username, proxy.port, database) case session.ResourceTypeMongodb: util.PrintfStderr("mongodb://localhost:%d/%s?serverSelectionTimeoutMS=15000", proxy.port, database) + case session.ResourceTypeOracle: + util.PrintfStderr("oracle://%s:%s@localhost:%d/%s", accessParams.AccountName, oracle.ProxyPasswordPlaceholder, proxy.port, database) + util.PrintfStderr("\n\nNote: the password shown is a protocol placeholder required by Oracle, not a secret.") default: util.PrintfStderr("localhost:%d", proxy.port) } diff --git a/packages/pam/pam-proxy.go b/packages/pam/pam-proxy.go index 2995e99b..f37d49be 100644 --- a/packages/pam/pam-proxy.go +++ b/packages/pam/pam-proxy.go @@ -18,6 +18,7 @@ import ( "github.com/Infisical/infisical-merge/packages/pam/handlers/mongodb" "github.com/Infisical/infisical-merge/packages/pam/handlers/mssql" "github.com/Infisical/infisical-merge/packages/pam/handlers/mysql" + "github.com/Infisical/infisical-merge/packages/pam/handlers/oracle" "github.com/Infisical/infisical-merge/packages/pam/handlers/rdp" "github.com/Infisical/infisical-merge/packages/pam/handlers/redis" "github.com/Infisical/infisical-merge/packages/pam/handlers/ssh" @@ -54,6 +55,7 @@ func GetSupportedResourceTypes() []string { session.ResourceTypeKubernetes, session.ResourceTypeRedis, session.ResourceTypeMongodb, + session.ResourceTypeOracle, } // Only advertise RDP when the real bridge is compiled in. A stub // build would otherwise accept RDP session routing and fail every @@ -388,6 +390,24 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("authMethod", credentials.AuthMethod). Msg("Starting Kubernetes PAM proxy") return proxy.HandleConnection(ctx, conn) + case session.ResourceTypeOracle: + oracleConfig := oracle.OracleProxyConfig{ + TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), + InjectUsername: credentials.Username, + InjectPassword: credentials.Password, + InjectDatabase: credentials.Database, + EnableTLS: credentials.SSLEnabled, + TLSConfig: tlsConfig, + SessionID: pamConfig.SessionId, + SessionLogger: sessionLogger, + } + proxy := oracle.NewOracleProxy(oracleConfig) + log.Info(). + Str("sessionId", pamConfig.SessionId). + Str("target", oracleConfig.TargetAddr). + Bool("sslEnabled", credentials.SSLEnabled). + Msg("Starting Oracle PAM proxy") + return proxy.HandleConnection(ctx, conn) case session.ResourceTypeMongodb: mongoConfig := mongodb.MongoDBProxyConfig{ Host: credentials.ConnectionString, diff --git a/packages/pam/session/uploader.go b/packages/pam/session/uploader.go index 6f43781c..405cccc6 100644 --- a/packages/pam/session/uploader.go +++ b/packages/pam/session/uploader.go @@ -31,6 +31,7 @@ const ( ResourceTypeSSH = "ssh" ResourceTypeKubernetes = "kubernetes" ResourceTypeMongodb = "mongodb" + ResourceTypeOracle = "oracledb" ResourceTypeWindows = "windows" ) @@ -75,7 +76,7 @@ func NewSessionUploader(httpClient *resty.Client, credentialsManager *Credential func ParseSessionFilename(filename string) (*SessionFileInfo, error) { // Try new format first: pam_session_{sessionID}_{resourceType}_expires_{timestamp}.enc // Build regex pattern using constants - resourceTypePattern := fmt.Sprintf("(%s|%s|%s|%s|%s|%s|%s|%s)", ResourceTypeSSH, ResourceTypePostgres, ResourceTypeRedis, ResourceTypeMysql, ResourceTypeMssql, ResourceTypeKubernetes, ResourceTypeMongodb, ResourceTypeWindows) + resourceTypePattern := fmt.Sprintf("(%s|%s|%s|%s|%s|%s|%s|%s|%s)", ResourceTypeSSH, ResourceTypePostgres, ResourceTypeRedis, ResourceTypeMysql, ResourceTypeMssql, ResourceTypeKubernetes, ResourceTypeMongodb, ResourceTypeOracle, ResourceTypeWindows) newFormatRegex := regexp.MustCompile(fmt.Sprintf(`^pam_session_(.+)_%s_expires_(\d+)\.enc$`, resourceTypePattern)) matches := newFormatRegex.FindStringSubmatch(filename)