Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ require (
github.com/gorilla/websocket v1.5.3 // indirect
github.com/jstemmer/go-junit-report/v2 v2.1.0 // indirect
github.com/kisielk/errcheck v1.9.0 // indirect
github.com/pires/go-proxyproto v0.8.1 // indirect
github.com/openai/openai-go/v3 v3.17.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/quic-go/qpack v0.6.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ github.com/openai/openai-go/v3 v3.17.0 h1:CfTkmQoItolSyW+bHOUF190KuX5+1Zv6MC0Gb4
github.com/openai/openai-go/v3 v3.17.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0=
github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
Expand Down
34 changes: 30 additions & 4 deletions proxy/servertcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ import (
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/syncutil"
"github.com/miekg/dns"

proxyproto "github.com/pires/go-proxyproto"
)

// initTCPListeners initializes TCP listeners with configured addresses.
func (p *Proxy) initTCPListeners(ctx context.Context) (err error) {
for _, addr := range p.TCPListenAddr {
var ln *net.TCPListener
var ln net.Listener
ln, err = p.listenTCP(ctx, addr)
if err != nil {
return fmt.Errorf("listening on tcp addr %s: %w", addr, err)
Expand All @@ -34,7 +36,7 @@ func (p *Proxy) initTCPListeners(ctx context.Context) (err error) {
}

// listenTCP returns a new TCP listener listening on addr.
func (p *Proxy) listenTCP(ctx context.Context, addr *net.TCPAddr) (ln *net.TCPListener, err error) {
func (p *Proxy) listenTCP(ctx context.Context, addr *net.TCPAddr) (ln net.Listener, err error) {
addrStr := addr.String()
p.logger.InfoContext(ctx, "creating tcp server socket", "addr", addrStr)

Expand All @@ -60,7 +62,29 @@ func (p *Proxy) listenTCP(ctx context.Context, addr *net.TCPAddr) (ln *net.TCPLi

p.logger.InfoContext(ctx, "listening to tcp", "addr", ln.Addr())

return ln, nil
return p.wrapProxyListener(ln), nil
}

// wrapProxyListener wraps a net.Listener with a proxyproto.Listener that
// implements the ConnPolicy callback. If the upstream address is in
// p.TrustedProxies, it returns proxyproto.USE; otherwise, it returns
// proxyproto.REJECT.
func (p *Proxy) wrapProxyListener(ln net.Listener) net.Listener {
Comment thread
peterverraedt marked this conversation as resolved.
return &proxyproto.Listener{
Listener: ln,
ConnPolicy: func(options proxyproto.ConnPolicyOptions) (proxyproto.Policy, error) {
if p.TrustedProxies != nil && p.TrustedProxies.Contains(netutil.NetAddrToAddrPort(options.Upstream).Addr()) {
// If a proxyproto header is present, use it to determine the
// upstream address.
return proxyproto.USE, nil
}

// Reject connections if the proxyproto header is present,
// with reason (will be logged):
// proxyproto: upstream connection sent PROXY header but isn't allowed to send one
return proxyproto.REJECT, nil
},
}
}

// initTLSListeners initializes TLS listeners with configured addresses.
Expand All @@ -78,7 +102,9 @@ func (p *Proxy) initTLSListeners(ctx context.Context) (err error) {
return fmt.Errorf("listening on tls addr %s: %w", addr, err)
}

l := tls.NewListener(tcpListen, p.TLSConfig)
proxyListen := p.wrapProxyListener(tcpListen)

l := tls.NewListener(proxyListen, p.TLSConfig)
p.tlsListen = append(p.tlsListen, l)

p.logger.InfoContext(ctx, "listening to tls", "addr", l.Addr())
Expand Down
42 changes: 42 additions & 0 deletions proxy/serverudp.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package proxy

import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"log/slog"
"net"
"net/netip"
Expand All @@ -14,6 +17,7 @@ import (
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/syncutil"
"github.com/miekg/dns"
proxyproto "github.com/pires/go-proxyproto"
)

// initUDPListeners initializes UDP listeners with configured addresses.
Expand Down Expand Up @@ -144,6 +148,8 @@ func (p *Proxy) udpHandlePacket(

p.logger.DebugContext(ctx, "handling new udp packet", "raddr", remoteAddr)

packet, remoteAddr = p.parseUDPProxyHeader(packet, remoteAddr)

req := &dns.Msg{}
err := req.Unpack(packet)
if err != nil {
Expand All @@ -162,6 +168,42 @@ func (p *Proxy) udpHandlePacket(
}
}

// parseUDPProxyHeader attempts to parse a proxy protocol header from a UDP
// packet. If the remote address is in p.TrustedProxies and a valid proxy
// protocol header is present, it returns the remaining packet data and the
// source address from the header. Otherwise, it returns the original packet
// and remote address unchanged.
func (p *Proxy) parseUDPProxyHeader(packet []byte, remoteAddr *net.UDPAddr) ([]byte, *net.UDPAddr) {
if p.TrustedProxies == nil || !p.TrustedProxies.Contains(netutil.NetAddrToAddrPort(remoteAddr).Addr()) {
return packet, remoteAddr
}

reader := bufio.NewReader(bytes.NewReader(packet))
header, err := proxyproto.Read(reader)
if err != nil {
// No proxy protocol header found; return packet as-is.
return packet, remoteAddr
}

// Read the remaining bytes after the proxy protocol header; these are the
// actual DNS payload.
remaining, err := io.ReadAll(reader)
if err != nil {
p.logger.Error("reading remaining udp data after proxy header", slogutil.KeyError, err)

return packet, remoteAddr
}

srcUDPAddr, ok := header.SourceAddr.(*net.UDPAddr)
if ok {
return remaining, srcUDPAddr
}

p.logger.Debug("proxy protocol header has unsupported source address type", "addr", header.SourceAddr)

return remaining, remoteAddr
}

// Writes a response to the UDP client
func (p *Proxy) respondUDP(d *DNSContext) error {
resp := d.Res
Expand Down