Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 32 additions & 73 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
package client

import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -190,7 +188,7 @@ func (ts *TunnelServer) effectiveMaxQnameLen() int {
// either call the step-by-step Initiate* methods (for embedding in frameworks
// like xray-core) or call ListenAndServe for a fully managed session.
type Tunnel struct {
Resolver Resolver
Resolvers []Resolver
TunnelServer TunnelServer

// Session configuration. Zero values use defaults.
Expand Down Expand Up @@ -219,9 +217,9 @@ type Tunnel struct {

// NewTunnel creates a Tunnel with the given resolver and server configuration.
// Zero-value fields use sensible defaults.
func NewTunnel(resolver Resolver, tunnelServer TunnelServer) (*Tunnel, error) {
func NewTunnel(resolvers []Resolver, tunnelServer TunnelServer) (*Tunnel, error) {
t := &Tunnel{
Resolver: resolver,
Resolvers: resolvers,
TunnelServer: tunnelServer,
}
t.wireConfig = tunnelServer.wireConfig()
Expand Down Expand Up @@ -293,79 +291,22 @@ func (t *Tunnel) effectiveKCPWindowSize() int {
// InitiateResolverConnection creates the underlying transport connection
// based on the Resolver configuration.
func (t *Tunnel) InitiateResolverConnection() error {
r := t.Resolver
switch r.ResolverType {
case ResolverTypeUDP:
addr, err := net.ResolveUDPAddr("udp", r.ResolverAddr)
if err != nil {
return err
}
t.remoteAddr = addr
if r.UDPSharedSocket {
lc := net.ListenConfig{Control: r.DialerControl}
conn, err := lc.ListenPacket(context.Background(), "udp", ":0")
if err != nil {
return err
}
t.resolverConn = conn
} else {
workers := r.UDPWorkers
if workers <= 0 {
workers = DefaultUDPWorkers
}
timeout := r.UDPTimeout
if timeout <= 0 {
timeout = DefaultUDPResponseTimeout
}
conn, forgedStats, err := NewUDPPacketConn(addr, r.DialerControl, workers, timeout, !r.UDPAcceptErrors, t.effectivePacketQueueSize(), t.effectiveQueueOverflowMode())
if err != nil {
return err
}
t.forgedStats = forgedStats
t.resolverConn = conn
}
return nil

case ResolverTypeDOH:
t.remoteAddr = turbotunnel.DummyAddr{}
var rt http.RoundTripper
if r.RoundTripper != nil {
rt = r.RoundTripper
} else if r.UTLSClientHelloID != nil {
rt = NewUTLSRoundTripper(nil, r.UTLSClientHelloID)
} else {
rt = http.DefaultTransport
}
conn, err := NewHTTPPacketConn(rt, r.ResolverAddr, 8, t.effectivePacketQueueSize(), t.effectiveQueueOverflowMode())
if len(t.Resolvers) > 1 {
conn, err := NewMultiResolver(t.Resolvers, SelectionRoundRobin, t.effectivePacketQueueSize(), t.effectiveQueueOverflowMode())
if err != nil {
return err
}
t.resolverConn = conn
return nil

case ResolverTypeDOT:
t.remoteAddr = turbotunnel.DummyAddr{}
var dialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
if r.UTLSClientHelloID != nil {
id := r.UTLSClientHelloID
dialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return UTLSDialContext(ctx, network, addr, nil, id)
}
} else {
dialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return tls.DialWithDialer(&net.Dialer{}, network, addr, nil)
}
}
conn, err := NewTLSPacketConn(r.ResolverAddr, dialTLSContext, t.effectivePacketQueueSize(), t.effectiveQueueOverflowMode())
if err != nil {
return err
}
t.resolverConn = conn
return nil

default:
return fmt.Errorf("unsupported resolver type: %s", r.ResolverType)
}
conn, addr, err := GetResolverConnection(t.Resolvers[0], t.effectivePacketQueueSize(), t.effectiveQueueOverflowMode())
if err != nil {
return err
}
t.resolverConn = conn
t.remoteAddr = addr
return nil
}

// InitiateDNSPacketConn wraps the resolver connection with DNS encoding.
Expand Down Expand Up @@ -589,6 +530,24 @@ func (t *Tunnel) closeTransportLayers() {
t.forgedStats = nil
}

// MultiResolverStats returns per-resolver health and count snapshots when the
// active resolver transport is MultiResolver; otherwise it returns nil.
func (t *Tunnel) MultiResolverStats() []ResolverStat {
if mr, ok := t.resolverConn.(*MultiResolver); ok {
return mr.ResolverStats()
}
return nil
}

// MultiResolverValidInvalidCounts returns valid/invalid counters per resolver
// address when the active resolver transport is MultiResolver.
func (t *Tunnel) MultiResolverValidInvalidCounts() map[string][2]int64 {
if mr, ok := t.resolverConn.(*MultiResolver); ok {
return mr.ValidInvalidCounts()
}
return nil
}

// resetTransportLayers tears down existing transport layers and creates fresh
// ones. Used during reconnect to ensure a clean transport stack.
func (t *Tunnel) resetTransportLayers() error {
Expand Down Expand Up @@ -882,10 +841,10 @@ func NewOutbound(resolvers []Resolver, tunnelServers []TunnelServer) *Outbound {
// Start begins accepting connections on bind and forwarding them through the
// first resolver/server pair.
func (o *Outbound) Start(bind string) error {
resolver := o.Resolvers[0]

tunnelServer := o.TunnelServers[0]

tunnel, err := NewTunnel(resolver, tunnelServer)
tunnel, err := NewTunnel(o.Resolvers, tunnelServer)
if err != nil {
return fmt.Errorf("failed to create tunnel: %w", err)
}
Expand Down
Loading