diff --git a/client/client.go b/client/client.go index 5e1b754..0e1ffad 100644 --- a/client/client.go +++ b/client/client.go @@ -19,8 +19,6 @@ package client import ( - "context" - "crypto/tls" "errors" "fmt" "io" @@ -190,7 +188,7 @@ func (ts *TunnelServer) effectiveMaxQnameLen() int { // either call the step-by-step Initiate* methods (for embedding in frameworks // like xray-core) or call ListenAndServe for a fully managed session. type Tunnel struct { - Resolver Resolver + Resolvers []Resolver TunnelServer TunnelServer // Session configuration. Zero values use defaults. @@ -219,9 +217,9 @@ type Tunnel struct { // NewTunnel creates a Tunnel with the given resolver and server configuration. // Zero-value fields use sensible defaults. -func NewTunnel(resolver Resolver, tunnelServer TunnelServer) (*Tunnel, error) { +func NewTunnel(resolvers []Resolver, tunnelServer TunnelServer) (*Tunnel, error) { t := &Tunnel{ - Resolver: resolver, + Resolvers: resolvers, TunnelServer: tunnelServer, } t.wireConfig = tunnelServer.wireConfig() @@ -293,79 +291,22 @@ func (t *Tunnel) effectiveKCPWindowSize() int { // InitiateResolverConnection creates the underlying transport connection // based on the Resolver configuration. func (t *Tunnel) InitiateResolverConnection() error { - r := t.Resolver - switch r.ResolverType { - case ResolverTypeUDP: - addr, err := net.ResolveUDPAddr("udp", r.ResolverAddr) - if err != nil { - return err - } - t.remoteAddr = addr - if r.UDPSharedSocket { - lc := net.ListenConfig{Control: r.DialerControl} - conn, err := lc.ListenPacket(context.Background(), "udp", ":0") - if err != nil { - return err - } - t.resolverConn = conn - } else { - workers := r.UDPWorkers - if workers <= 0 { - workers = DefaultUDPWorkers - } - timeout := r.UDPTimeout - if timeout <= 0 { - timeout = DefaultUDPResponseTimeout - } - conn, forgedStats, err := NewUDPPacketConn(addr, r.DialerControl, workers, timeout, !r.UDPAcceptErrors, t.effectivePacketQueueSize(), t.effectiveQueueOverflowMode()) - if err != nil { - return err - } - t.forgedStats = forgedStats - t.resolverConn = conn - } - return nil - - case ResolverTypeDOH: - t.remoteAddr = turbotunnel.DummyAddr{} - var rt http.RoundTripper - if r.RoundTripper != nil { - rt = r.RoundTripper - } else if r.UTLSClientHelloID != nil { - rt = NewUTLSRoundTripper(nil, r.UTLSClientHelloID) - } else { - rt = http.DefaultTransport - } - conn, err := NewHTTPPacketConn(rt, r.ResolverAddr, 8, t.effectivePacketQueueSize(), t.effectiveQueueOverflowMode()) + if len(t.Resolvers) > 1 { + conn, err := NewMultiResolver(t.Resolvers, SelectionRoundRobin, t.effectivePacketQueueSize(), t.effectiveQueueOverflowMode()) if err != nil { return err } t.resolverConn = conn - return nil - - case ResolverTypeDOT: t.remoteAddr = turbotunnel.DummyAddr{} - var dialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) - if r.UTLSClientHelloID != nil { - id := r.UTLSClientHelloID - dialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - return UTLSDialContext(ctx, network, addr, nil, id) - } - } else { - dialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - return tls.DialWithDialer(&net.Dialer{}, network, addr, nil) - } - } - conn, err := NewTLSPacketConn(r.ResolverAddr, dialTLSContext, t.effectivePacketQueueSize(), t.effectiveQueueOverflowMode()) - if err != nil { - return err - } - t.resolverConn = conn return nil - - default: - return fmt.Errorf("unsupported resolver type: %s", r.ResolverType) } + conn, addr, err := GetResolverConnection(t.Resolvers[0], t.effectivePacketQueueSize(), t.effectiveQueueOverflowMode()) + if err != nil { + return err + } + t.resolverConn = conn + t.remoteAddr = addr + return nil } // InitiateDNSPacketConn wraps the resolver connection with DNS encoding. @@ -589,6 +530,24 @@ func (t *Tunnel) closeTransportLayers() { t.forgedStats = nil } +// MultiResolverStats returns per-resolver health and count snapshots when the +// active resolver transport is MultiResolver; otherwise it returns nil. +func (t *Tunnel) MultiResolverStats() []ResolverStat { + if mr, ok := t.resolverConn.(*MultiResolver); ok { + return mr.ResolverStats() + } + return nil +} + +// MultiResolverValidInvalidCounts returns valid/invalid counters per resolver +// address when the active resolver transport is MultiResolver. +func (t *Tunnel) MultiResolverValidInvalidCounts() map[string][2]int64 { + if mr, ok := t.resolverConn.(*MultiResolver); ok { + return mr.ValidInvalidCounts() + } + return nil +} + // resetTransportLayers tears down existing transport layers and creates fresh // ones. Used during reconnect to ensure a clean transport stack. func (t *Tunnel) resetTransportLayers() error { @@ -882,10 +841,10 @@ func NewOutbound(resolvers []Resolver, tunnelServers []TunnelServer) *Outbound { // Start begins accepting connections on bind and forwarding them through the // first resolver/server pair. func (o *Outbound) Start(bind string) error { - resolver := o.Resolvers[0] + tunnelServer := o.TunnelServers[0] - tunnel, err := NewTunnel(resolver, tunnelServer) + tunnel, err := NewTunnel(o.Resolvers, tunnelServer) if err != nil { return fmt.Errorf("failed to create tunnel: %w", err) } diff --git a/client/multi_resolver.go b/client/multi_resolver.go new file mode 100644 index 0000000..c46f791 --- /dev/null +++ b/client/multi_resolver.go @@ -0,0 +1,610 @@ +package client + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/net2share/vaydns/dns" + "github.com/net2share/vaydns/turbotunnel" + log "github.com/sirupsen/logrus" +) + +// SelectionMode controls which resolver MultiResolver picks for outgoing packets. +type SelectionMode string + +const ( + // SelectionRoundRobin rotates through preferred resolvers in order. + SelectionRoundRobin SelectionMode = "roundrobin" + // SelectionBest picks the resolver with the best observed score. + SelectionBest SelectionMode = "best" + // SelectionSmart currently aliases best-score selection. + SelectionSmart SelectionMode = "smart" +) + +// ResolverState is the current health state of one resolver. +type ResolverState string + +const ( + ResolverStateUnknown ResolverState = "unknown" + ResolverStateHealthy ResolverState = "healthy" + ResolverStateRateLimited ResolverState = "rate_limited" + ResolverStateDown ResolverState = "down" +) + +const ( + pendingResponseTimeout = 5 * time.Second + healthTickInterval = 1 * time.Second + downTimeoutThreshold = int64(8) + rateLimitThreshold = int64(5) + probeInterval = 1 * time.Second +) + +// ResolverStat is a snapshot of one resolver's counters and health state. +type ResolverStat struct { + Address string + State ResolverState + ValidCount int64 + InvalidCount int64 + TimeoutCount int64 + LastWrite time.Time + LastValid time.Time +} + +type resolverEntry struct { + name string + addr net.Addr + conn net.PacketConn + + validCount atomic.Int64 + invalidCount atomic.Int64 + timeoutCount atomic.Int64 + + mu sync.Mutex + pending map[uint16]time.Time + lastWrite time.Time + lastValid time.Time + lastProbe time.Time + state ResolverState +} + +func (e *resolverEntry) writePacket(b []byte) (int, error) { + now := time.Now() + e.trackOutgoingID(b, now) + n, err := e.conn.WriteTo(b, e.addr) + if err != nil { + e.invalidCount.Add(1) + } + e.mu.Lock() + e.lastWrite = now + e.mu.Unlock() + return n, err +} + +func (e *resolverEntry) trackOutgoingID(b []byte, now time.Time) { + msg, err := dns.MessageFromWireFormat(b) + if err != nil { + return + } + e.mu.Lock() + e.pending[msg.ID] = now + e.mu.Unlock() +} + +func (e *resolverEntry) readPacket() multiReadResult { + var result multiReadResult + result.entry = e + result.n, result.addr, result.err = e.conn.ReadFrom(result.buf[:]) + if result.err == nil { + e.evaluateIncoming(result.buf[:result.n]) + } + return result +} + +func (e *resolverEntry) evaluateIncoming(packet []byte) { + resp, err := dns.MessageFromWireFormat(packet) + if err != nil { + e.invalidCount.Add(1) + e.recomputeState(time.Now()) + return + } + + e.mu.Lock() + delete(e.pending, resp.ID) + e.mu.Unlock() + + if isValidDNSResponse(resp) { + e.validCount.Add(1) + e.timeoutCount.Store(0) + e.mu.Lock() + e.lastValid = time.Now() + e.state = ResolverStateHealthy + e.mu.Unlock() + return + } + + e.invalidCount.Add(1) + if isRateLimitedResponse(resp) { + e.mu.Lock() + e.state = ResolverStateRateLimited + e.mu.Unlock() + } + e.recomputeState(time.Now()) +} + +func (e *resolverEntry) expirePending(now time.Time) { + expired := int64(0) + e.mu.Lock() + for id, t := range e.pending { + if now.Sub(t) >= pendingResponseTimeout { + delete(e.pending, id) + expired++ + } + } + e.mu.Unlock() + if expired > 0 { + e.timeoutCount.Add(expired) + e.invalidCount.Add(expired) + } + e.recomputeState(now) +} + +func (e *resolverEntry) recomputeState(now time.Time) { + e.mu.Lock() + defer e.mu.Unlock() + + timeouts := e.timeoutCount.Load() + invalid := e.invalidCount.Load() + valid := e.validCount.Load() + + switch { + case timeouts >= downTimeoutThreshold: + e.state = ResolverStateDown + case invalid >= rateLimitThreshold && valid == 0: + e.state = ResolverStateRateLimited + case valid > 0 && now.Sub(e.lastValid) <= 30*time.Second: + e.state = ResolverStateHealthy + case valid == 0: + e.state = ResolverStateUnknown + default: + e.state = ResolverStateUnknown + } + + // Slow decay to avoid sticky penalties. + if invalid > 0 { + e.invalidCount.Store(invalid - 1) + } + if timeouts > 0 { + e.timeoutCount.Store(timeouts - 1) + } +} + +func (e *resolverEntry) stateSnapshot() ResolverState { + e.mu.Lock() + defer e.mu.Unlock() + return e.state +} + +func (e *resolverEntry) markProbe(now time.Time) { + e.mu.Lock() + e.lastProbe = now + e.mu.Unlock() +} + +func (e *resolverEntry) canProbe(now time.Time) bool { + e.mu.Lock() + defer e.mu.Unlock() + return now.Sub(e.lastProbe) >= probeInterval +} + +func (e *resolverEntry) snapshot() ResolverStat { + e.mu.Lock() + defer e.mu.Unlock() + return ResolverStat{ + Address: e.name, + State: e.state, + ValidCount: e.validCount.Load(), + InvalidCount: e.invalidCount.Load(), + TimeoutCount: e.timeoutCount.Load(), + LastWrite: e.lastWrite, + LastValid: e.lastValid, + } +} + +type multiReadResult struct { + buf [4096]byte + n int + addr net.Addr + err error + entry *resolverEntry +} + +// MultiResolver is a net.PacketConn that multiplexes across multiple DNS +// resolver transport connections. It tracks per-resolver health from valid and +// invalid responses, avoids down resolvers for primary traffic, and probes +// unhealthy resolvers by duplicating selected packets. +type MultiResolver struct { + entries []*resolverEntry + mode SelectionMode + mu sync.Mutex + rrIndex int + probeRR int + recvChan chan multiReadResult + closed chan struct{} + closeOnce sync.Once +} + +// NewMultiResolver creates a MultiResolver from a slice of Resolver configs. +func NewMultiResolver(resolvers []Resolver, mode SelectionMode, queueSize int, overflowMode turbotunnel.QueueOverflowMode) (*MultiResolver, error) { + if len(resolvers) == 0 { + return nil, fmt.Errorf("at least one resolver is required") + } + + entries := make([]*resolverEntry, 0, len(resolvers)) + for _, r := range resolvers { + conn, addr, err := GetResolverConnection(r, queueSize, overflowMode) + if err != nil { + for _, e := range entries { + e.conn.Close() + } + return nil, fmt.Errorf("resolver %s %s: %w", r.ResolverType, r.ResolverAddr, err) + } + entries = append(entries, &resolverEntry{ + name: r.ResolverAddr, + addr: addr, + conn: conn, + pending: make(map[uint16]time.Time), + state: ResolverStateUnknown, + }) + } + + mr := &MultiResolver{ + entries: entries, + mode: mode, + recvChan: make(chan multiReadResult, len(entries)*4), + closed: make(chan struct{}), + } + for _, e := range entries { + entry := e + go func() { + for { + res := entry.readPacket() + select { + case mr.recvChan <- res: + case <-mr.closed: + return + } + if res.err != nil { + return + } + } + }() + } + go mr.healthWorker() + return mr, nil +} + +func (mr *MultiResolver) healthWorker() { + ticker := time.NewTicker(healthTickInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + now := time.Now() + stats := make([]ResolverStat, 0, len(mr.entries)) + for _, e := range mr.entries { + e.expirePending(now) + stats = append(stats, e.snapshot()) + } + log.Trace("\n" + renderResolverStatsTable(stats, now)) + case <-mr.closed: + return + } + } +} + +func renderResolverStatsTable(stats []ResolverStat, now time.Time) string { + var b strings.Builder + b.WriteString("+-------------------------+--------------+--------+---------+---------+-------------+-------------+\n") + b.WriteString("| resolver | state | valid | invalid | timeout | last_write | last_valid |\n") + b.WriteString("+-------------------------+--------------+--------+---------+---------+-------------+-------------+\n") + for _, s := range stats { + lastWriteAgo := "-" + if !s.LastWrite.IsZero() { + lastWriteAgo = now.Sub(s.LastWrite).Truncate(time.Second).String() + } + lastValidAgo := "-" + if !s.LastValid.IsZero() { + lastValidAgo = now.Sub(s.LastValid).Truncate(time.Second).String() + } + b.WriteString(fmt.Sprintf("| %-23.23s | %-12s | %6d | %7d | %7d | %11s | %11s |\n", + s.Address, + s.State, + s.ValidCount, + s.InvalidCount, + s.TimeoutCount, + lastWriteAgo, + lastValidAgo, + )) + } + b.WriteString("+-------------------------+--------------+--------+---------+---------+-------------+-------------+") + return b.String() +} + +func isValidDNSResponse(resp dns.Message) bool { + if resp.Flags&0x8000 == 0 { + return false + } + return (resp.Flags & 0x000f) == dns.RcodeNoError +} + +func isRateLimitedResponse(resp dns.Message) bool { + rcode := resp.Flags & 0x000f + return rcode == dns.RcodeRefused || rcode == dns.RcodeServerFailure +} + +// ReadFrom receives a packet from whichever resolver responds first. +func (mr *MultiResolver) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + select { + case <-mr.closed: + return 0, nil, net.ErrClosed + case res := <-mr.recvChan: + if res.err != nil { + return 0, res.addr, res.err + } + n = copy(b, res.buf[:res.n]) + return n, turbotunnel.DummyAddr{}, nil + } +} + +// WriteTo sends b to the selected primary resolver and may duplicate b to one +// unhealthy resolver as a probe to detect recovery. +func (mr *MultiResolver) WriteTo(b []byte, _ net.Addr) (n int, err error) { + select { + case <-mr.closed: + return 0, net.ErrClosed + default: + } + + primary := mr.selectPrimary() + n, err = primary.writePacket(b) + if err != nil { + return n, err + } + + if probe := mr.selectProbeTarget(primary); probe != nil { + probe.markProbe(time.Now()) + _, _ = probe.writePacket(b) + } + return n, nil +} + +func (mr *MultiResolver) selectPrimary() *resolverEntry { + if mr.mode == SelectionRoundRobin { + if e := mr.selectRoundRobinHealthy(); e != nil { + return e + } + return mr.entries[0] + } + if e := mr.selectBestScore(); e != nil { + return e + } + return mr.entries[0] +} + +func (mr *MultiResolver) selectRoundRobinHealthy() *resolverEntry { + mr.mu.Lock() + defer mr.mu.Unlock() + + if len(mr.entries) == 0 { + return nil + } + + start := mr.rrIndex + for i := 0; i < len(mr.entries); i++ { + idx := (start + i) % len(mr.entries) + state := mr.entries[idx].stateSnapshot() + if state == ResolverStateHealthy || state == ResolverStateUnknown { + mr.rrIndex = (idx + 1) % len(mr.entries) + return mr.entries[idx] + } + } + idx := start % len(mr.entries) + mr.rrIndex = (idx + 1) % len(mr.entries) + return mr.entries[idx] +} + +func (mr *MultiResolver) selectBestScore() *resolverEntry { + if len(mr.entries) == 0 { + return nil + } + best := mr.entries[0] + bestScore := resolverScore(best) + for _, e := range mr.entries[1:] { + s := resolverScore(e) + if s > bestScore { + best = e + bestScore = s + } + } + return best +} + +func resolverScore(e *resolverEntry) int64 { + statePenalty := int64(0) + switch e.stateSnapshot() { + case ResolverStateHealthy: + statePenalty = 0 + case ResolverStateUnknown: + statePenalty = 5 + case ResolverStateRateLimited: + statePenalty = 15 + case ResolverStateDown: + statePenalty = 30 + } + return e.validCount.Load()*4 - e.invalidCount.Load()*2 - e.timeoutCount.Load()*3 - statePenalty +} + +func (mr *MultiResolver) selectProbeTarget(primary *resolverEntry) *resolverEntry { + mr.mu.Lock() + defer mr.mu.Unlock() + + now := time.Now() + for i := 0; i < len(mr.entries); i++ { + idx := (mr.probeRR + i) % len(mr.entries) + e := mr.entries[idx] + if e == primary { + continue + } + state := e.stateSnapshot() + if state == ResolverStateHealthy { + continue + } + if e.canProbe(now) { + mr.probeRR = (idx + 1) % len(mr.entries) + return e + } + } + return nil +} + +// ResolverStats returns current resolver health counters. +func (mr *MultiResolver) ResolverStats() []ResolverStat { + stats := make([]ResolverStat, 0, len(mr.entries)) + for _, e := range mr.entries { + stats = append(stats, e.snapshot()) + } + return stats +} + +// ValidInvalidCounts returns valid/invalid counts by resolver address. +func (mr *MultiResolver) ValidInvalidCounts() map[string][2]int64 { + out := make(map[string][2]int64, len(mr.entries)) + for _, e := range mr.entries { + out[e.name] = [2]int64{e.validCount.Load(), e.invalidCount.Load()} + } + return out +} + +// Close closes all underlying connections and stops the reader goroutines. +func (mr *MultiResolver) Close() error { + mr.closeOnce.Do(func() { + close(mr.closed) + for _, e := range mr.entries { + e.conn.Close() + } + }) + return nil +} + +// LocalAddr returns the local address of the first underlying connection. +func (mr *MultiResolver) LocalAddr() net.Addr { + return mr.entries[0].conn.LocalAddr() +} + +// SetDeadline sets a deadline on all underlying connections. +func (mr *MultiResolver) SetDeadline(t time.Time) error { + var last error + for _, e := range mr.entries { + if err := e.conn.SetDeadline(t); err != nil { + last = err + } + } + return last +} + +// SetReadDeadline sets a read deadline on all underlying connections. +func (mr *MultiResolver) SetReadDeadline(t time.Time) error { + var last error + for _, e := range mr.entries { + if err := e.conn.SetReadDeadline(t); err != nil { + last = err + } + } + return last +} + +// SetWriteDeadline sets a write deadline on all underlying connections. +func (mr *MultiResolver) SetWriteDeadline(t time.Time) error { + var last error + for _, e := range mr.entries { + if err := e.conn.SetWriteDeadline(t); err != nil { + last = err + } + } + return last +} + +// getResolverConnection creates the underlying transport net.PacketConn for r. +func GetResolverConnection(r Resolver, queueSize int, overflowMode turbotunnel.QueueOverflowMode) (net.PacketConn, net.Addr, error) { + switch r.ResolverType { + case ResolverTypeUDP: + addr, err := net.ResolveUDPAddr("udp", r.ResolverAddr) + if err != nil { + return nil, nil, err + } + if r.UDPSharedSocket { + lc := net.ListenConfig{Control: r.DialerControl} + conn, err := lc.ListenPacket(context.Background(), "udp", ":0") + if err != nil { + return nil, nil, err + } + return conn, addr, nil + } + workers := r.UDPWorkers + if workers <= 0 { + workers = DefaultUDPWorkers + } + timeout := r.UDPTimeout + if timeout <= 0 { + timeout = DefaultUDPResponseTimeout + } + conn, _, err := NewUDPPacketConn(addr, r.DialerControl, workers, timeout, !r.UDPAcceptErrors, queueSize, overflowMode) + if err != nil { + return nil, nil, err + } + return conn, addr, nil + + case ResolverTypeDOH: + var rt http.RoundTripper + if r.RoundTripper != nil { + rt = r.RoundTripper + } else if r.UTLSClientHelloID != nil { + rt = NewUTLSRoundTripper(nil, r.UTLSClientHelloID) + } else { + rt = http.DefaultTransport + } + conn, err := NewHTTPPacketConn(rt, r.ResolverAddr, 8, queueSize, overflowMode) + if err != nil { + return nil, nil, err + } + return conn, turbotunnel.DummyAddr{}, nil + + case ResolverTypeDOT: + var dialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + if r.UTLSClientHelloID != nil { + id := r.UTLSClientHelloID + dialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return UTLSDialContext(ctx, network, addr, nil, id) + } + } else { + dialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return tls.DialWithDialer(&net.Dialer{}, network, addr, nil) + } + } + conn, err := NewTLSPacketConn(r.ResolverAddr, dialTLSContext, queueSize, overflowMode) + if err != nil { + return nil, nil, err + } + return conn, turbotunnel.DummyAddr{}, nil + + default: + return nil, nil, fmt.Errorf("unsupported resolver type: %s", r.ResolverType) + } +} diff --git a/dns/dns.go b/dns/dns.go index 8f35a6a..a6d0a62 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -47,8 +47,9 @@ var ( const ( // https://tools.ietf.org/html/rfc1035#section-3.2.2 - RRTypeA = 1 - RRTypeNS = 2 + RRTypeA = 1 + RRTypeNS = 2 + RRTypeCNAME = 5 RRTypeMX = 15 RRTypeTXT = 16 @@ -66,6 +67,7 @@ const ( RcodeServerFailure = 2 // a.k.a. SERVFAIL RcodeNameError = 3 // a.k.a. NXDOMAIN RcodeNotImplemented = 4 // a.k.a. NOTIMPL + RcodeRefused = 5 // a.k.a. REFUSED // https://tools.ietf.org/html/rfc6891#section-9 ExtendedRcodeBadVers = 16 // a.k.a. BADVERS ) diff --git a/vaydns-client/main.go b/vaydns-client/main.go index f6e9ac5..65bb893 100644 --- a/vaydns-client/main.go +++ b/vaydns-client/main.go @@ -20,6 +20,16 @@ import ( log "github.com/sirupsen/logrus" ) +type StringSliceFlag []string + +func (s *StringSliceFlag) String() string { + return fmt.Sprint(*s) +} + +func (s *StringSliceFlag) Set(value string) error { + *s = append(*s, value) + return nil +} func readKeyFromFile(filename string) ([]byte, error) { f, err := os.Open(filename) if err != nil { @@ -30,13 +40,13 @@ func readKeyFromFile(filename string) ([]byte, error) { } func main() { - var dohURL string - var dotAddr string + var dohURLs StringSliceFlag + var dotAddrs StringSliceFlag var domainArg string var listenAddr string var pubkeyFilename string var pubkeyString string - var udpAddr string + var udpAddrs StringSliceFlag var utlsDistribution string var maxQnameLen int var maxNumLabels int @@ -91,11 +101,11 @@ Known TLS fingerprints for -utls are: fmt.Fprintln(flag.CommandLine.Output(), line.String()) } } - flag.StringVar(&dohURL, "doh", "", "URL of DoH resolver") - flag.StringVar(&dotAddr, "dot", "", "address of DoT resolver") + flag.Var(&dohURLs, "doh", "URL of DoH resolver") + flag.Var(&dotAddrs, "dot", "address of DoT resolver") flag.StringVar(&pubkeyString, "pubkey", "", fmt.Sprintf("server public key (%d hex digits)", noise.KeyLen*2)) flag.StringVar(&pubkeyFilename, "pubkey-file", "", "read server public key from file") - flag.StringVar(&udpAddr, "udp", "", "address of UDP DNS resolver") + flag.Var(&udpAddrs, "udp", "address of UDP DNS resolver") flag.StringVar(&utlsDistribution, "utls", "4*random,3*Firefox_120,1*Firefox_105,3*Chrome_120,1*Chrome_102,1*iOS_14,1*iOS_13", "choose TLS fingerprint from weighted distribution") @@ -185,32 +195,42 @@ Known TLS fingerprints for -utls are: if utlsClientHelloID != nil { log.Infof("uTLS fingerprint %s %s", utlsClientHelloID.Client, utlsClientHelloID.Version) } - - // Select resolver transport. - var resolverType client.ResolverType - var resolverAddr string - transportCount := 0 - if dohURL != "" { - resolverType = client.ResolverTypeDOH - resolverAddr = dohURL - transportCount++ - } - if dotAddr != "" { - resolverType = client.ResolverTypeDOT - resolverAddr = dotAddr - transportCount++ - } - if udpAddr != "" { - resolverType = client.ResolverTypeUDP - resolverAddr = udpAddr - transportCount++ - } - if transportCount == 0 { - fmt.Fprintf(os.Stderr, "one of -doh, -dot, or -udp is required\n") + udpTimeout, err := time.ParseDuration(udpTimeoutStr) + if err != nil { + fmt.Fprintf(os.Stderr, "invalid -udp-timeout: %v\n", err) os.Exit(1) } - if transportCount > 1 { - fmt.Fprintf(os.Stderr, "only one of -doh, -dot, and -udp may be given\n") + // Select resolver transport. + resolvers := make([]client.Resolver, 0) + + for _, dohURL := range dohURLs { + resolver := client.Resolver{ + ResolverType: client.ResolverTypeDOH, + ResolverAddr: dohURL, + } + resolver.UTLSClientHelloID = utlsClientHelloID + resolvers = append(resolvers, resolver) + } + for _, dotAddr := range dotAddrs { + resolvers = append(resolvers, client.Resolver{ + ResolverType: client.ResolverTypeDOT, + ResolverAddr: dotAddr, + }) + } + for _, udpAddr := range udpAddrs { + resolver := client.Resolver{ + ResolverType: client.ResolverTypeUDP, + ResolverAddr: udpAddr, + } + resolver.UDPWorkers = udpWorkers + resolver.UDPSharedSocket = udpSharedSocket + resolver.UDPTimeout = udpTimeout + resolver.UDPAcceptErrors = udpAcceptErrors + resolvers = append(resolvers, resolver) + } + + if len(resolvers) == 0 { + fmt.Fprintf(os.Stderr, "one of -doh, -dot, or -udp is required\n") os.Exit(1) } @@ -245,11 +265,6 @@ Known TLS fingerprints for -utls are: fmt.Fprintf(os.Stderr, "invalid -open-stream-timeout: %v\n", err) os.Exit(1) } - udpTimeout, err := time.ParseDuration(udpTimeoutStr) - if err != nil { - fmt.Fprintf(os.Stderr, "invalid -udp-timeout: %v\n", err) - os.Exit(1) - } // Validate. if keepAlive >= idleTimeout { @@ -323,16 +338,7 @@ Known TLS fingerprints for -utls are: } // Build resolver. - resolver, err := client.NewResolver(resolverType, resolverAddr) - if err != nil { - fmt.Fprintf(os.Stderr, "resolver: %v\n", err) - os.Exit(1) - } - resolver.UTLSClientHelloID = utlsClientHelloID - resolver.UDPWorkers = udpWorkers - resolver.UDPSharedSocket = udpSharedSocket - resolver.UDPTimeout = udpTimeout - resolver.UDPAcceptErrors = udpAcceptErrors + if udpAcceptErrors { if udpSharedSocket { log.Warnf("-udp-accept-errors has no effect when -udp-shared-socket is set") @@ -355,7 +361,7 @@ Known TLS fingerprints for -utls are: ts.RecordType = recordTypeStr // Build tunnel. - tunnel, err := client.NewTunnel(resolver, ts) + tunnel, err := client.NewTunnel(resolvers, ts) if err != nil { fmt.Fprintf(os.Stderr, "%v\n", err) os.Exit(1)