Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"os"
Expand Down Expand Up @@ -62,6 +63,8 @@ type server struct {
instanceProxyTransport http.RoundTripper
nameGeneratorFactory func(context.Context, string, string) (func() string, error)
activityRecorder func(context.Context, string, string, time.Time) error
findRunningPodFunc func(context.Context, string, string, string) (*corev1.Pod, error)
openSSHPortForwardFunc func(context.Context, *corev1.Pod, uint32) (net.Conn, io.Closer, error)
}

func main() {
Expand Down
242 changes: 235 additions & 7 deletions api/ssh_gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,46 @@ package main

import (
"context"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"

sshserver "github.com/gliderlabs/ssh"
gossh "golang.org/x/crypto/ssh"
corev1 "k8s.io/api/core/v1"
"k8s.io/client-go/kubernetes/scheme"
"k8s.io/client-go/tools/portforward"
"k8s.io/client-go/tools/remotecommand"
transportspdy "k8s.io/client-go/transport/spdy"

spritzv1 "spritz.sh/operator/api/v1"
)

const sshPrincipalDelimiter = ":"

type sshPortForwardActivityStartedKey struct{}

type sshLocalForwardChannelData struct {
DestAddr string
DestPort uint32

OriginAddr string
OriginPort uint32
}

type closeFunc func() error

func (f closeFunc) Close() error {
return f()
}

func formatSSHPrincipal(prefix, namespace, name string) string {
return strings.Join([]string{prefix, namespace, name}, sshPrincipalDelimiter)
}
Expand All @@ -43,12 +66,7 @@ func (s *server) startSSHGateway(ctx context.Context) error {
return nil
}

server := &sshserver.Server{
Addr: cfg.listenAddr,
Handler: s.handleSSHSession,
PublicKeyHandler: s.handleSSHAuth,
Version: "spritz",
}
server := s.newSSHGatewayServer()
server.AddHostKey(cfg.hostSigner)

errCh := make(chan error, 1)
Expand All @@ -70,6 +88,21 @@ func (s *server) startSSHGateway(ctx context.Context) error {
}
}

func (s *server) newSSHGatewayServer() *sshserver.Server {
cfg := s.sshGateway
return &sshserver.Server{
Addr: cfg.listenAddr,
Handler: s.handleSSHSession,
PublicKeyHandler: s.handleSSHAuth,
Version: "spritz",
ChannelHandlers: map[string]sshserver.ChannelHandler{
"session": sshserver.DefaultSessionHandler,
"direct-tcpip": s.handleSSHPortForward,
},
LocalPortForwardingCallback: s.allowSSHPortForwardDestination,
}
}

func (s *server) handleSSHAuth(ctx sshserver.Context, key sshserver.PublicKey) bool {
cert, ok := key.(*gossh.Certificate)
if !ok {
Expand Down Expand Up @@ -114,7 +147,7 @@ func (s *server) handleSSHSession(sess sshserver.Session) {
_ = sess.Exit(1)
return
}
s.startSSHActivityLoop(sess.Context(), spritz)
s.ensureSSHActivityLoop(sess.Context(), spritz)

pty, winCh, hasPty := sess.Pty()
sizeQueue := newTerminalSizeQueue()
Expand All @@ -135,6 +168,201 @@ func (s *server) handleSSHSession(sess sshserver.Session) {
_ = sess.Exit(0)
}

func (s *server) handleSSHPortForward(srv *sshserver.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx sshserver.Context) {
var request sshLocalForwardChannelData
if err := gossh.Unmarshal(newChan.ExtraData(), &request); err != nil {
newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
return
}
if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, request.DestAddr, request.DestPort) {
newChan.Reject(gossh.Prohibited, "port forwarding is disabled")
return
}

namespace, name, ok := parseSSHPrincipal(s.sshGateway.principalPrefix, ctx.User())
if !ok {
log.Printf("spritz ssh: invalid forward principal value=%s", ctx.User())
newChan.Reject(gossh.Prohibited, "invalid ssh principal")
return
}

spritz := &spritzv1.Spritz{}
if err := s.client.Get(ctx, clientKey(namespace, name), spritz); err != nil {
log.Printf("spritz ssh: forward spritz not found name=%s namespace=%s err=%v", name, namespace, err)
newChan.Reject(gossh.ConnectionFailed, "spritz not ready")
return
}

pod, err := s.findSSHGatewayPod(ctx, namespace, name, s.sshGateway.containerName)
if err != nil {
log.Printf("spritz ssh: forward pod not ready name=%s namespace=%s err=%v", name, namespace, err)
newChan.Reject(gossh.ConnectionFailed, "spritz not ready")
return
}
s.ensureSSHActivityLoop(ctx, spritz)

upstream, cleanup, err := s.openSSHPortForward(ctx, pod, request.DestPort)
if err != nil {
log.Printf("spritz ssh: forward open failed name=%s namespace=%s port=%d err=%v", name, namespace, request.DestPort, err)
newChan.Reject(gossh.ConnectionFailed, "port forward unavailable")
return
}

channel, requests, err := newChan.Accept()
if err != nil {
_ = upstream.Close()
_ = cleanup.Close()
return
}
go gossh.DiscardRequests(requests)

var once sync.Once
closeAll := func() {
once.Do(func() {
_ = channel.Close()
_ = upstream.Close()
_ = cleanup.Close()
})
}

go func() {
defer closeAll()
_, _ = io.Copy(channel, upstream)
}()
go func() {
defer closeAll()
_, _ = io.Copy(upstream, channel)
}()
}

func (s *server) allowSSHPortForwardDestination(ctx sshserver.Context, destinationHost string, destinationPort uint32) bool {
if !isLoopbackSSHForwardHost(destinationHost) {
log.Printf("spritz ssh: rejected forward user=%s host=%s port=%d", ctx.User(), destinationHost, destinationPort)
return false
}
return true
}

func isLoopbackSSHForwardHost(host string) bool {
normalized := strings.TrimSpace(host)
normalized = strings.TrimPrefix(normalized, "[")
normalized = strings.TrimSuffix(normalized, "]")
if normalized == "" {
return false
}
if strings.EqualFold(normalized, "localhost") {
return true
}
ip := net.ParseIP(normalized)
return ip != nil && ip.IsLoopback()
}

func (s *server) ensureSSHActivityLoop(ctx sshserver.Context, spritz *spritzv1.Spritz) {
if s == nil || spritz == nil {
return
}
ctx.Lock()
defer ctx.Unlock()
if started, ok := ctx.Value(sshPortForwardActivityStartedKey{}).(bool); ok && started {
return
}
ctx.SetValue(sshPortForwardActivityStartedKey{}, true)
s.startSSHActivityLoop(ctx, spritz)
}

func (s *server) findSSHGatewayPod(ctx context.Context, namespace, name, container string) (*corev1.Pod, error) {
if s.findRunningPodFunc != nil {
return s.findRunningPodFunc(ctx, namespace, name, container)
}
return s.findRunningPod(ctx, namespace, name, container)
}

func (s *server) openSSHPortForward(ctx context.Context, pod *corev1.Pod, remotePort uint32) (net.Conn, io.Closer, error) {
if s.openSSHPortForwardFunc != nil {
return s.openSSHPortForwardFunc(ctx, pod, remotePort)
}
if s.clientset == nil || s.restConfig == nil {
return nil, nil, errors.New("ssh port forwarding is not configured")
}

req := s.clientset.CoreV1().RESTClient().
Post().
Resource("pods").
Name(pod.Name).
Namespace(pod.Namespace).
SubResource("portforward")
transport, upgrader, err := transportspdy.RoundTripperFor(s.restConfig)
if err != nil {
return nil, nil, err
}
dialer := transportspdy.NewDialer(upgrader, &http.Client{Transport: transport}, http.MethodPost, req.URL())
stopCh := make(chan struct{})
readyCh := make(chan struct{})
errCh := make(chan error, 1)
forwarder, err := portforward.NewOnAddresses(
dialer,
[]string{"127.0.0.1"},
[]string{fmt.Sprintf("0:%d", remotePort)},
stopCh,
readyCh,
io.Discard,
io.Discard,
)
if err != nil {
close(stopCh)
return nil, nil, err
}

go func() {
errCh <- forwarder.ForwardPorts()
}()

select {
case <-readyCh:
case err := <-errCh:
close(stopCh)
return nil, nil, err
case <-ctx.Done():
close(stopCh)
return nil, nil, ctx.Err()
}

ports, err := forwarder.GetPorts()
if err != nil {
close(stopCh)
return nil, nil, err
}
if len(ports) != 1 {
close(stopCh)
return nil, nil, fmt.Errorf("unexpected forwarded port count: %d", len(ports))
}

localAddress := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(ports[0].Local)))
upstream, err := (&net.Dialer{}).DialContext(ctx, "tcp", localAddress)
if err != nil {
close(stopCh)
return nil, nil, err
}

var once sync.Once
cleanup := closeFunc(func() error {
once.Do(func() {
close(stopCh)
})
return nil
})

go func() {
err := <-errCh
if err == nil || errors.Is(err, portforward.ErrLostConnectionToPod) || errors.Is(err, context.Canceled) {
return
}
log.Printf("spritz ssh: port-forward ended pod=%s namespace=%s remote_port=%d err=%v", pod.Name, pod.Namespace, remotePort, err)
}()

return upstream, cleanup, nil
}

func sshActivityRefreshInterval(spec spritzv1.SpritzSpec, fallback time.Duration) time.Duration {
interval := fallback
if interval <= 0 {
Expand Down
Loading
Loading