diff --git a/main.go b/main.go index 888f2fe..88302d0 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "log" + "net" "os" osuser "os/user" "path/filepath" @@ -32,15 +33,17 @@ func main() { // Parse flags var ( - sshUser = flag.String("l", currentUsername(), "SSH username") - sshPort = flag.String("p", "22", "SSH port") - keyPath = flag.String("i", defaultKeyPath(), "SSH private key path") - tsnetDir = flag.String("tsnet-dir", defaultTsnetDir(), "Tailscale state directory") - controlURL = flag.String("control-url", "", "Tailscale control server URL") - verbose = flag.Bool("v", false, "Verbose output") - insecure = flag.Bool("insecure", false, "Skip host key verification (insecure)") - scpMode = flag.Bool("scp", false, "SCP mode: ts-ssh -scp source dest") - showVersion = flag.Bool("version", false, "Show version") + sshUser = flag.String("l", currentUsername(), "SSH username") + sshPort = flag.String("p", "22", "SSH port") + keyPath = flag.String("i", defaultKeyPath(), "SSH private key path") + tsnetDir = flag.String("tsnet-dir", defaultTsnetDir(), "Tailscale state directory") + controlURL = flag.String("control-url", "", "Tailscale control server URL") + verbose = flag.Bool("v", false, "Verbose output") + insecure = flag.Bool("insecure", false, "Skip host key verification (insecure)") + scpMode = flag.Bool("scp", false, "SCP mode: ts-ssh -scp source dest") + showVersion = flag.Bool("version", false, "Show version") + disablePTY = flag.Bool("T", false, "Disable pseudo-terminal allocation") + dynamicForward = flag.String("D", "", "SOCKS5 dynamic port forwarding on [bind_address:]port") ) flag.Usage = usage @@ -85,7 +88,7 @@ func main() { remoteCmd = args[1:] } - if err := runSSH(target, remoteCmd, *sshUser, *sshPort, *keyPath, *tsnetDir, *controlURL, *insecure, *verbose, logger); err != nil { + if err := runSSH(target, remoteCmd, *sshUser, *sshPort, *keyPath, *tsnetDir, *controlURL, *insecure, *disablePTY, *dynamicForward, *verbose, logger); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } @@ -106,7 +109,7 @@ func usage() { } // runSSH handles the SSH connection -func runSSH(target string, remoteCmd []string, defaultUser, defaultPort, keyPath, tsnetDir, controlURL string, insecure, verbose bool, logger *log.Logger) error { +func runSSH(target string, remoteCmd []string, defaultUser, defaultPort, keyPath, tsnetDir, controlURL string, insecure, disablePTY bool, dynamicForward string, verbose bool, logger *log.Logger) error { // Parse target: [user@]host[:port] sshUser, host, port, err := parseSSHTarget(target, defaultUser, defaultPort) if err != nil { @@ -137,12 +140,19 @@ func runSSH(target string, remoteCmd []string, defaultUser, defaultPort, keyPath } defer client.Close() + // Setup dynamic port forwarding if requested + if dynamicForward != "" { + if err := setupDynamicForward(client, dynamicForward, verbose, logger); err != nil { + return fmt.Errorf("failed to setup dynamic forwarding: %w", err) + } + } + // Execute command or start interactive session if len(remoteCmd) > 0 { return execRemoteCommand(client, remoteCmd, logger) } - return interactiveSession(client, logger) + return interactiveSession(client, disablePTY, logger) } // runSCP handles SCP file transfer @@ -366,7 +376,7 @@ func execRemoteCommand(client *ssh.Client, cmd []string, logger *log.Logger) err } // interactiveSession starts an interactive SSH session -func interactiveSession(client *ssh.Client, logger *log.Logger) error { +func interactiveSession(client *ssh.Client, disablePTY bool, logger *log.Logger) error { session, err := client.NewSession() if err != nil { return fmt.Errorf("failed to create session: %w", err) @@ -381,9 +391,9 @@ func interactiveSession(client *ssh.Client, logger *log.Logger) error { session.Stdout = os.Stdout session.Stderr = os.Stderr - // Setup PTY if we're in a terminal + // Setup PTY if we're in a terminal and PTY is not disabled fd := int(os.Stdin.Fd()) - if term.IsTerminal(fd) { + if !disablePTY && term.IsTerminal(fd) { // Get terminal size width, height, err := term.GetSize(fd) if err != nil { @@ -455,3 +465,191 @@ func extractURL(msg string) string { } return msg } + +// setupDynamicForward sets up SOCKS5 dynamic port forwarding +func setupDynamicForward(client *ssh.Client, forwardSpec string, verbose bool, logger *log.Logger) error { + // Parse bind address and port from forwardSpec + // Format can be: "port" or "bind_address:port" + bindAddr := "localhost" + port := forwardSpec + + if strings.Contains(forwardSpec, ":") { + parts := strings.Split(forwardSpec, ":") + if len(parts) != 2 { + return fmt.Errorf("invalid dynamic forward specification: %s", forwardSpec) + } + bindAddr = parts[0] + port = parts[1] + } + + // Validate port + if err := security.ValidatePort(port); err != nil { + return fmt.Errorf("invalid port for dynamic forwarding: %w", err) + } + + listenAddr := net.JoinHostPort(bindAddr, port) + + // Start listening on local port + listener, err := net.Listen("tcp", listenAddr) + if err != nil { + return fmt.Errorf("failed to listen on %s: %w", listenAddr, err) + } + + if verbose { + logger.Printf("SOCKS5 dynamic forwarding listening on %s\n", listenAddr) + } + + // Handle incoming SOCKS5 connections in background + go func() { + defer listener.Close() + for { + localConn, err := listener.Accept() + if err != nil { + if verbose { + logger.Printf("Error accepting connection: %v\n", err) + } + return + } + go handleSOCKS5(client, localConn, verbose, logger) + } + }() + + return nil +} + +// handleSOCKS5 handles a SOCKS5 connection +func handleSOCKS5(client *ssh.Client, localConn net.Conn, verbose bool, logger *log.Logger) { + defer localConn.Close() + + // SOCKS5 handshake + buf := make([]byte, 256) + + // Read version and methods + n, err := localConn.Read(buf) + if err != nil || n < 2 { + if verbose { + logger.Printf("SOCKS5 handshake failed: %v\n", err) + } + return + } + + // Check SOCKS version + if buf[0] != 0x05 { + if verbose { + logger.Printf("Not SOCKS5 protocol: version=%d\n", buf[0]) + } + return + } + + // Send "no authentication required" response + if _, err := localConn.Write([]byte{0x05, 0x00}); err != nil { + if verbose { + logger.Printf("Failed to send auth response: %v\n", err) + } + return + } + + // Read connection request + n, err = localConn.Read(buf) + if err != nil || n < 7 { + if verbose { + logger.Printf("Failed to read connection request: %v\n", err) + } + return + } + + // Check version and command + if buf[0] != 0x05 || buf[1] != 0x01 { + if verbose { + logger.Printf("Invalid SOCKS5 request: version=%d, cmd=%d\n", buf[0], buf[1]) + } + // Send failure response + localConn.Write([]byte{0x05, 0x07, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + return + } + + // Parse address + addrType := buf[3] + var host string + var port uint16 + + switch addrType { + case 0x01: // IPv4 + if n < 10 { + if verbose { + logger.Printf("Invalid IPv4 address length\n") + } + localConn.Write([]byte{0x05, 0x07, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + return + } + host = fmt.Sprintf("%d.%d.%d.%d", buf[4], buf[5], buf[6], buf[7]) + port = uint16(buf[8])<<8 | uint16(buf[9]) + case 0x03: // Domain name + addrLen := int(buf[4]) + if n < 5+addrLen+2 { + if verbose { + logger.Printf("Invalid domain name length\n") + } + localConn.Write([]byte{0x05, 0x07, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + return + } + host = string(buf[5 : 5+addrLen]) + port = uint16(buf[5+addrLen])<<8 | uint16(buf[5+addrLen+1]) + case 0x04: // IPv6 + if n < 22 { + if verbose { + logger.Printf("Invalid IPv6 address length\n") + } + localConn.Write([]byte{0x05, 0x07, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + return + } + host = net.IP(buf[4:20]).String() + port = uint16(buf[20])<<8 | uint16(buf[21]) + default: + if verbose { + logger.Printf("Unsupported address type: %d\n", addrType) + } + localConn.Write([]byte{0x05, 0x08, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + return + } + + targetAddr := fmt.Sprintf("%s:%d", host, port) + if verbose { + logger.Printf("SOCKS5 forwarding to: %s\n", targetAddr) + } + + // Dial through SSH + remoteConn, err := client.Dial("tcp", targetAddr) + if err != nil { + if verbose { + logger.Printf("Failed to dial %s: %v\n", targetAddr, err) + } + // Send connection refused + localConn.Write([]byte{0x05, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) + return + } + defer remoteConn.Close() + + // Send success response + if _, err := localConn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}); err != nil { + if verbose { + logger.Printf("Failed to send success response: %v\n", err) + } + return + } + + // Bidirectional copy + done := make(chan struct{}, 2) + + go func() { + io.Copy(remoteConn, localConn) + done <- struct{}{} + }() + + go func() { + io.Copy(localConn, remoteConn) + done <- struct{}{} + }() + + <-done +}