Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 213 additions & 15 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"log"
"net"
"os"
osuser "os/user"
"path/filepath"
Expand All @@ -32,15 +33,17 @@ func main() {

// Parse flags
var (
sshUser = flag.String("l", currentUsername(), "SSH username")
sshPort = flag.String("p", "22", "SSH port")
keyPath = flag.String("i", defaultKeyPath(), "SSH private key path")
tsnetDir = flag.String("tsnet-dir", defaultTsnetDir(), "Tailscale state directory")
controlURL = flag.String("control-url", "", "Tailscale control server URL")
verbose = flag.Bool("v", false, "Verbose output")
insecure = flag.Bool("insecure", false, "Skip host key verification (insecure)")
scpMode = flag.Bool("scp", false, "SCP mode: ts-ssh -scp source dest")
showVersion = flag.Bool("version", false, "Show version")
sshUser = flag.String("l", currentUsername(), "SSH username")
sshPort = flag.String("p", "22", "SSH port")
keyPath = flag.String("i", defaultKeyPath(), "SSH private key path")
tsnetDir = flag.String("tsnet-dir", defaultTsnetDir(), "Tailscale state directory")
controlURL = flag.String("control-url", "", "Tailscale control server URL")
verbose = flag.Bool("v", false, "Verbose output")
insecure = flag.Bool("insecure", false, "Skip host key verification (insecure)")
scpMode = flag.Bool("scp", false, "SCP mode: ts-ssh -scp source dest")
showVersion = flag.Bool("version", false, "Show version")
disablePTY = flag.Bool("T", false, "Disable pseudo-terminal allocation")
dynamicForward = flag.String("D", "", "SOCKS5 dynamic port forwarding on [bind_address:]port")
)

flag.Usage = usage
Expand Down Expand Up @@ -85,7 +88,7 @@ func main() {
remoteCmd = args[1:]
}

if err := runSSH(target, remoteCmd, *sshUser, *sshPort, *keyPath, *tsnetDir, *controlURL, *insecure, *verbose, logger); err != nil {
if err := runSSH(target, remoteCmd, *sshUser, *sshPort, *keyPath, *tsnetDir, *controlURL, *insecure, *disablePTY, *dynamicForward, *verbose, logger); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
Expand All @@ -106,7 +109,7 @@ func usage() {
}

// runSSH handles the SSH connection
func runSSH(target string, remoteCmd []string, defaultUser, defaultPort, keyPath, tsnetDir, controlURL string, insecure, verbose bool, logger *log.Logger) error {
func runSSH(target string, remoteCmd []string, defaultUser, defaultPort, keyPath, tsnetDir, controlURL string, insecure, disablePTY bool, dynamicForward string, verbose bool, logger *log.Logger) error {
// Parse target: [user@]host[:port]
sshUser, host, port, err := parseSSHTarget(target, defaultUser, defaultPort)
if err != nil {
Expand Down Expand Up @@ -137,12 +140,19 @@ func runSSH(target string, remoteCmd []string, defaultUser, defaultPort, keyPath
}
defer client.Close()

// Setup dynamic port forwarding if requested
if dynamicForward != "" {
if err := setupDynamicForward(client, dynamicForward, verbose, logger); err != nil {
return fmt.Errorf("failed to setup dynamic forwarding: %w", err)
}
}

// Execute command or start interactive session
if len(remoteCmd) > 0 {
return execRemoteCommand(client, remoteCmd, logger)
}

return interactiveSession(client, logger)
return interactiveSession(client, disablePTY, logger)
}

// runSCP handles SCP file transfer
Expand Down Expand Up @@ -366,7 +376,7 @@ func execRemoteCommand(client *ssh.Client, cmd []string, logger *log.Logger) err
}

// interactiveSession starts an interactive SSH session
func interactiveSession(client *ssh.Client, logger *log.Logger) error {
func interactiveSession(client *ssh.Client, disablePTY bool, logger *log.Logger) error {
session, err := client.NewSession()
if err != nil {
return fmt.Errorf("failed to create session: %w", err)
Expand All @@ -381,9 +391,9 @@ func interactiveSession(client *ssh.Client, logger *log.Logger) error {
session.Stdout = os.Stdout
session.Stderr = os.Stderr

// Setup PTY if we're in a terminal
// Setup PTY if we're in a terminal and PTY is not disabled
fd := int(os.Stdin.Fd())
if term.IsTerminal(fd) {
if !disablePTY && term.IsTerminal(fd) {
// Get terminal size
width, height, err := term.GetSize(fd)
if err != nil {
Expand Down Expand Up @@ -455,3 +465,191 @@ func extractURL(msg string) string {
}
return msg
}

// setupDynamicForward sets up SOCKS5 dynamic port forwarding
func setupDynamicForward(client *ssh.Client, forwardSpec string, verbose bool, logger *log.Logger) error {
// Parse bind address and port from forwardSpec
// Format can be: "port" or "bind_address:port"
bindAddr := "localhost"
port := forwardSpec

if strings.Contains(forwardSpec, ":") {
parts := strings.Split(forwardSpec, ":")
if len(parts) != 2 {
return fmt.Errorf("invalid dynamic forward specification: %s", forwardSpec)
}
bindAddr = parts[0]
port = parts[1]
}

// Validate port
if err := security.ValidatePort(port); err != nil {
return fmt.Errorf("invalid port for dynamic forwarding: %w", err)
}

listenAddr := net.JoinHostPort(bindAddr, port)

// Start listening on local port
listener, err := net.Listen("tcp", listenAddr)
if err != nil {
return fmt.Errorf("failed to listen on %s: %w", listenAddr, err)
}

if verbose {
logger.Printf("SOCKS5 dynamic forwarding listening on %s\n", listenAddr)
}

// Handle incoming SOCKS5 connections in background
go func() {
defer listener.Close()
for {
localConn, err := listener.Accept()
if err != nil {
if verbose {
logger.Printf("Error accepting connection: %v\n", err)
}
return
}
go handleSOCKS5(client, localConn, verbose, logger)
}
}()

return nil
}

// handleSOCKS5 handles a SOCKS5 connection
func handleSOCKS5(client *ssh.Client, localConn net.Conn, verbose bool, logger *log.Logger) {
defer localConn.Close()

// SOCKS5 handshake
buf := make([]byte, 256)

// Read version and methods
n, err := localConn.Read(buf)
if err != nil || n < 2 {
if verbose {
logger.Printf("SOCKS5 handshake failed: %v\n", err)
}
return
}

// Check SOCKS version
if buf[0] != 0x05 {
if verbose {
logger.Printf("Not SOCKS5 protocol: version=%d\n", buf[0])
}
return
}

// Send "no authentication required" response
if _, err := localConn.Write([]byte{0x05, 0x00}); err != nil {
if verbose {
logger.Printf("Failed to send auth response: %v\n", err)
}
return
}

// Read connection request
n, err = localConn.Read(buf)
if err != nil || n < 7 {
if verbose {
logger.Printf("Failed to read connection request: %v\n", err)
}
return
}

// Check version and command
if buf[0] != 0x05 || buf[1] != 0x01 {
if verbose {
logger.Printf("Invalid SOCKS5 request: version=%d, cmd=%d\n", buf[0], buf[1])
}
// Send failure response
localConn.Write([]byte{0x05, 0x07, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
return
}

// Parse address
addrType := buf[3]
var host string
var port uint16

switch addrType {
case 0x01: // IPv4
if n < 10 {
if verbose {
logger.Printf("Invalid IPv4 address length\n")
}
localConn.Write([]byte{0x05, 0x07, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
return
}
host = fmt.Sprintf("%d.%d.%d.%d", buf[4], buf[5], buf[6], buf[7])
port = uint16(buf[8])<<8 | uint16(buf[9])
case 0x03: // Domain name
addrLen := int(buf[4])
if n < 5+addrLen+2 {
if verbose {
logger.Printf("Invalid domain name length\n")
}
localConn.Write([]byte{0x05, 0x07, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
return
}
host = string(buf[5 : 5+addrLen])
port = uint16(buf[5+addrLen])<<8 | uint16(buf[5+addrLen+1])
case 0x04: // IPv6
if n < 22 {
if verbose {
logger.Printf("Invalid IPv6 address length\n")
}
localConn.Write([]byte{0x05, 0x07, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
return
}
host = net.IP(buf[4:20]).String()
port = uint16(buf[20])<<8 | uint16(buf[21])
default:
if verbose {
logger.Printf("Unsupported address type: %d\n", addrType)
}
localConn.Write([]byte{0x05, 0x08, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
return
}

targetAddr := fmt.Sprintf("%s:%d", host, port)
if verbose {
logger.Printf("SOCKS5 forwarding to: %s\n", targetAddr)
}

// Dial through SSH
remoteConn, err := client.Dial("tcp", targetAddr)
if err != nil {
if verbose {
logger.Printf("Failed to dial %s: %v\n", targetAddr, err)
}
// Send connection refused
localConn.Write([]byte{0x05, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
return
}
defer remoteConn.Close()

// Send success response
if _, err := localConn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}); err != nil {
if verbose {
logger.Printf("Failed to send success response: %v\n", err)
}
return
}

// Bidirectional copy
done := make(chan struct{}, 2)

go func() {
io.Copy(remoteConn, localConn)
done <- struct{}{}
}()

go func() {
io.Copy(localConn, remoteConn)
done <- struct{}{}
}()

<-done
}