From 6371d96236a8ee02ec059c9030cfd1c7ac2a2171 Mon Sep 17 00:00:00 2001 From: Evan Phoenix Date: Fri, 5 Jun 2026 09:51:35 -0700 Subject: [PATCH] Refresh runner certificate on start when listen IP changes Distributed runner VMs are recreated with a persistent disk mounted at /var/lib/miren, which keeps runner/config.yaml (the runner ID and the certificate issued at join). On recreate the VM gets a new internal IP: the runner reattaches to its existing Node and rewrites ApiAddress to the new IP, but it keeps serving the certificate minted for the old IP. Clients (sandbox exec, app run) then fail TLS verification: tls: failed to verify certificate: x509: certificate is valid for ..., 10.128.0.45, not 10.128.0.47 Add a RefreshCertificate RPC so a runner can re-issue its certificate with SANs covering its current listen address. On start, the runner inspects its persisted certificate and, only if the discovered listen address is not covered, calls RefreshCertificate, writes the new cert back to config, and serves it. A required-but-failed refresh is fatal: serving a stale cert would leave the runner silently unreachable. Authorization is derived entirely from the caller's verified certificate, never from caller-supplied input: - the peer cert must chain to the cluster CA and be currently valid - it must be a runner cert (CN "runner-*", org "miren") - exactly one registered Node must match the CN's runner-ID prefix The last check ties refresh to live membership (caauth has no revocation), so a removed runner's still-valid cert cannot perpetually renew itself; it fails closed on zero or ambiguous matches. Supporting changes: - pkg/rpc: expose the verified peer certificate to handlers via CurrentConnectionInfo, plus ContextWithConnectionInfo for tests - pkg/caauth: add Authority.VerifyCert for an already-parsed certificate - factor the runner cert CommonName and SAN construction into shared helpers reused by Join --- api/runner/rpc.yml | 27 +++ api/runner/runner_v1alpha/rpc.gen.go | 186 +++++++++++++++++ cli/commands/runner_start.go | 146 ++++++++++++-- cli/commands/runner_start_test.go | 65 ++++++ pkg/caauth/caauth.go | 7 +- pkg/rpc/server.go | 5 +- pkg/rpc/state.go | 12 ++ servers/runner/registration.go | 191 ++++++++++++++++-- servers/runner/registration_test.go | 287 ++++++++++++++++++++++++++- 9 files changed, 887 insertions(+), 39 deletions(-) create mode 100644 cli/commands/runner_start_test.go diff --git a/api/runner/rpc.yml b/api/runner/rpc.yml index 4e667c920..7e23af79b 100644 --- a/api/runner/rpc.yml +++ b/api/runner/rpc.yml @@ -159,6 +159,33 @@ interfaces: type: string doc: Error message if removal failed + - name: RefreshCertificate + index: 6 + public: true # Authenticated in the handler via the caller's existing runner mTLS cert + doc: | + Re-issue a runner's server certificate with updated SANs. Used when a + runner's listen address changes (e.g. a VM is recreated with a new IP + but a persistent disk keeps the old, now-stale certificate). The caller + must present its existing CA-signed runner certificate; the new cert + preserves that certificate's CommonName. + parameters: + - name: listen_addr + type: string + doc: The runner's current listen address whose host should be covered by the new cert + results: + - name: cert_pem + type: bytes + doc: Re-issued certificate in PEM format + - name: key_pem + type: bytes + doc: Private key for the re-issued certificate in PEM format + - name: ca_pem + type: bytes + doc: CA certificate in PEM format + - name: error + type: string + doc: Error message if the refresh failed + types: - type: InviteInfo doc: Information about a runner invite diff --git a/api/runner/runner_v1alpha/rpc.gen.go b/api/runner/runner_v1alpha/rpc.gen.go index 4e7c875c0..c527a6c7b 100644 --- a/api/runner/runner_v1alpha/rpc.gen.go +++ b/api/runner/runner_v1alpha/rpc.gen.go @@ -927,6 +927,89 @@ func (v *RunnerRegistrationRemoveRunnerResults) UnmarshalJSON(data []byte) error return json.Unmarshal(data, &v.data) } +type runnerRegistrationRefreshCertificateArgsData struct { + ListenAddr *string `cbor:"0,keyasint,omitempty" json:"listen_addr,omitempty"` +} + +type RunnerRegistrationRefreshCertificateArgs struct { + call rpc.Call + data runnerRegistrationRefreshCertificateArgsData +} + +func (v *RunnerRegistrationRefreshCertificateArgs) HasListenAddr() bool { + return v.data.ListenAddr != nil +} + +func (v *RunnerRegistrationRefreshCertificateArgs) ListenAddr() string { + if v.data.ListenAddr == nil { + return "" + } + return *v.data.ListenAddr +} + +func (v *RunnerRegistrationRefreshCertificateArgs) MarshalCBOR() ([]byte, error) { + return cbor.Marshal(v.data) +} + +func (v *RunnerRegistrationRefreshCertificateArgs) UnmarshalCBOR(data []byte) error { + return cbor.Unmarshal(data, &v.data) +} + +func (v *RunnerRegistrationRefreshCertificateArgs) MarshalJSON() ([]byte, error) { + return json.Marshal(v.data) +} + +func (v *RunnerRegistrationRefreshCertificateArgs) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &v.data) +} + +type runnerRegistrationRefreshCertificateResultsData struct { + CertPem *[]byte `cbor:"0,keyasint,omitempty" json:"cert_pem,omitempty"` + KeyPem *[]byte `cbor:"1,keyasint,omitempty" json:"key_pem,omitempty"` + CaPem *[]byte `cbor:"2,keyasint,omitempty" json:"ca_pem,omitempty"` + Error *string `cbor:"3,keyasint,omitempty" json:"error,omitempty"` +} + +type RunnerRegistrationRefreshCertificateResults struct { + call rpc.Call + data runnerRegistrationRefreshCertificateResultsData +} + +func (v *RunnerRegistrationRefreshCertificateResults) SetCertPem(cert_pem []byte) { + x := slices.Clone(cert_pem) + v.data.CertPem = &x +} + +func (v *RunnerRegistrationRefreshCertificateResults) SetKeyPem(key_pem []byte) { + x := slices.Clone(key_pem) + v.data.KeyPem = &x +} + +func (v *RunnerRegistrationRefreshCertificateResults) SetCaPem(ca_pem []byte) { + x := slices.Clone(ca_pem) + v.data.CaPem = &x +} + +func (v *RunnerRegistrationRefreshCertificateResults) SetError(error string) { + v.data.Error = &error +} + +func (v *RunnerRegistrationRefreshCertificateResults) MarshalCBOR() ([]byte, error) { + return cbor.Marshal(v.data) +} + +func (v *RunnerRegistrationRefreshCertificateResults) UnmarshalCBOR(data []byte) error { + return cbor.Unmarshal(data, &v.data) +} + +func (v *RunnerRegistrationRefreshCertificateResults) MarshalJSON() ([]byte, error) { + return json.Marshal(v.data) +} + +func (v *RunnerRegistrationRefreshCertificateResults) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &v.data) +} + type RunnerRegistrationCreateInvite struct { rpc.Call args RunnerRegistrationCreateInviteArgs @@ -1083,6 +1166,32 @@ func (t *RunnerRegistrationRemoveRunner) Results() *RunnerRegistrationRemoveRunn return results } +type RunnerRegistrationRefreshCertificate struct { + rpc.Call + args RunnerRegistrationRefreshCertificateArgs + results RunnerRegistrationRefreshCertificateResults +} + +func (t *RunnerRegistrationRefreshCertificate) Args() *RunnerRegistrationRefreshCertificateArgs { + args := &t.args + if args.call != nil { + return args + } + args.call = t.Call + t.Call.Args(args) + return args +} + +func (t *RunnerRegistrationRefreshCertificate) Results() *RunnerRegistrationRefreshCertificateResults { + results := &t.results + if results.call != nil { + return results + } + results.call = t.Call + t.Call.Results(results) + return results +} + type RunnerRegistration interface { CreateInvite(ctx context.Context, state *RunnerRegistrationCreateInvite) error Join(ctx context.Context, state *RunnerRegistrationJoin) error @@ -1090,6 +1199,7 @@ type RunnerRegistration interface { RevokeInvite(ctx context.Context, state *RunnerRegistrationRevokeInvite) error ListRunners(ctx context.Context, state *RunnerRegistrationListRunners) error RemoveRunner(ctx context.Context, state *RunnerRegistrationRemoveRunner) error + RefreshCertificate(ctx context.Context, state *RunnerRegistrationRefreshCertificate) error } type reexportRunnerRegistration struct { @@ -1120,6 +1230,10 @@ func (reexportRunnerRegistration) RemoveRunner(ctx context.Context, state *Runne panic("not implemented") } +func (reexportRunnerRegistration) RefreshCertificate(ctx context.Context, state *RunnerRegistrationRefreshCertificate) error { + panic("not implemented") +} + func (t reexportRunnerRegistration) CapabilityClient() rpc.Client { return t.client } @@ -1180,6 +1294,15 @@ func AdaptRunnerRegistration(t RunnerRegistration) *rpc.Interface { return t.RemoveRunner(ctx, &RunnerRegistrationRemoveRunner{Call: call}) }, }, + { + Name: "RefreshCertificate", + InterfaceName: "RunnerRegistration", + Index: 6, + Public: true, + Handler: func(ctx context.Context, call rpc.Call) error { + return t.RefreshCertificate(ctx, &RunnerRegistrationRefreshCertificate{Call: call}) + }, + }, } return rpc.NewInterface(methods, t) @@ -1549,3 +1672,66 @@ func (v RunnerRegistrationClient) RemoveRunner(ctx context.Context, query string return &RunnerRegistrationClientRemoveRunnerResults{client: v.Client, data: ret}, nil } + +type RunnerRegistrationClientRefreshCertificateResults struct { + client rpc.Client + data runnerRegistrationRefreshCertificateResultsData +} + +func (v *RunnerRegistrationClientRefreshCertificateResults) HasCertPem() bool { + return v.data.CertPem != nil +} + +func (v *RunnerRegistrationClientRefreshCertificateResults) CertPem() []byte { + if v.data.CertPem == nil { + return nil + } + return *v.data.CertPem +} + +func (v *RunnerRegistrationClientRefreshCertificateResults) HasKeyPem() bool { + return v.data.KeyPem != nil +} + +func (v *RunnerRegistrationClientRefreshCertificateResults) KeyPem() []byte { + if v.data.KeyPem == nil { + return nil + } + return *v.data.KeyPem +} + +func (v *RunnerRegistrationClientRefreshCertificateResults) HasCaPem() bool { + return v.data.CaPem != nil +} + +func (v *RunnerRegistrationClientRefreshCertificateResults) CaPem() []byte { + if v.data.CaPem == nil { + return nil + } + return *v.data.CaPem +} + +func (v *RunnerRegistrationClientRefreshCertificateResults) HasError() bool { + return v.data.Error != nil +} + +func (v *RunnerRegistrationClientRefreshCertificateResults) Error() string { + if v.data.Error == nil { + return "" + } + return *v.data.Error +} + +func (v RunnerRegistrationClient) RefreshCertificate(ctx context.Context, listen_addr string) (*RunnerRegistrationClientRefreshCertificateResults, error) { + args := RunnerRegistrationRefreshCertificateArgs{} + args.data.ListenAddr = &listen_addr + + var ret runnerRegistrationRefreshCertificateResultsData + + err := v.Call(ctx, "RefreshCertificate", &args, &ret) + if err != nil { + return nil, err + } + + return &RunnerRegistrationClientRefreshCertificateResults{client: v.Client, data: ret}, nil +} diff --git a/cli/commands/runner_start.go b/cli/commands/runner_start.go index fae66fad0..f35253529 100644 --- a/cli/commands/runner_start.go +++ b/cli/commands/runner_start.go @@ -4,6 +4,8 @@ package commands import ( "context" + "crypto/x509" + "encoding/pem" "fmt" "net" "net/netip" @@ -16,6 +18,7 @@ import ( containerd "github.com/containerd/containerd/v2/client" "golang.org/x/sync/errgroup" + "miren.dev/runtime/api/runner/runner_v1alpha" "miren.dev/runtime/clientconfig" containerdcomp "miren.dev/runtime/components/containerd" "miren.dev/runtime/components/netresolve" @@ -23,6 +26,7 @@ import ( "miren.dev/runtime/controllers/sandbox" "miren.dev/runtime/metrics" "miren.dev/runtime/observability" + "miren.dev/runtime/pkg/rpc" "miren.dev/runtime/pkg/runnerconfig" ) @@ -44,6 +48,30 @@ func RunnerStart(ctx *Context, opts struct { "etcd_endpoints", cfg.EtcdEndpoints, "network_backend", cfg.NetworkBackend) + // Determine listen address. If no explicit address is given, discover the + // machine's outbound IP (the one that would route to the coordinator) and + // advertise that so the coordinator knows how to reach this runner. + listenAddr := opts.ListenAddr + if listenAddr == "" { + port := "8444" + ip, err := discoverOutboundIP(cfg.CoordinatorAddress) + if err != nil { + return fmt.Errorf("could not discover outbound IP for listen address (use --listen to set manually): %w", err) + } + listenAddr = net.JoinHostPort(ip.String(), port) + ctx.Log.Info("discovered listen address", "addr", listenAddr) + } + + // The runner's certificate is persisted in the config (often on a disk that + // outlives the VM). If the listen address has changed since the cert was + // issued, the persisted cert no longer covers our IP and clients (e.g. + // sandbox exec) will fail TLS verification. Re-issue the cert before we use + // it to serve. Refreshing is fatal when the address isn't covered: serving a + // stale cert would leave the runner silently unreachable. + if err := ensureRunnerCertificate(ctx, cfg, opts.ConfigPath, listenAddr); err != nil { + return err + } + // Create clientconfig from saved certs for RPC authentication clientCfg := clientconfig.NewConfig() clientCfg.SetCluster("coordinator", &clientconfig.ClusterConfig{ @@ -137,20 +165,6 @@ func RunnerStart(ctx *Context, opts struct { // Create errgroup for background tasks eg, egCtx := errgroup.WithContext(sigCtx) - // Determine listen address. If no explicit address is given, discover the - // machine's outbound IP (the one that would route to the coordinator) and - // advertise that so the coordinator knows how to reach this runner. - listenAddr := opts.ListenAddr - if listenAddr == "" { - port := "8444" - ip, err := discoverOutboundIP(cfg.CoordinatorAddress) - if err != nil { - return fmt.Errorf("could not discover outbound IP for listen address (use --listen to set manually): %w", err) - } - listenAddr = net.JoinHostPort(ip.String(), port) - ctx.Log.Info("discovered listen address", "addr", listenAddr) - } - // Build runner configuration runnerCfg := runner.RunnerConfig{ Id: cfg.RunnerID, @@ -300,3 +314,107 @@ func RunnerStart(ctx *Context, opts struct { ctx.Log.Info("runner stopped") return nil } + +// ensureRunnerCertificate re-issues the runner's server certificate when the +// current listen address is not covered by the persisted certificate's SANs. +// This happens when a VM is recreated with a new IP but a persistent disk keeps +// the old config (cert + runner ID). The refreshed cert is written back to the +// config so it survives subsequent restarts. A required-but-failed refresh is +// fatal: serving a stale cert would leave the runner silently unreachable. +func ensureRunnerCertificate(ctx *Context, cfg *runnerconfig.Config, configPath, listenAddr string) error { + if cfg.ClientCert == "" { + return nil + } + + covered, err := certCoversListenAddr(cfg.ClientCert, listenAddr) + if err != nil { + return fmt.Errorf("failed to inspect runner certificate: %w", err) + } + if covered { + return nil + } + + ctx.Log.Info("runner certificate does not cover listen address; refreshing", + "listen_addr", listenAddr) + + cs, err := rpc.NewState(ctx, + rpc.WithLogger(ctx.Log), + rpc.WithBindAddr("[::]:0"), + rpc.WithCertPEMs([]byte(cfg.ClientCert), []byte(cfg.ClientKey)), + rpc.WithCertificateVerification([]byte(cfg.CACert)), + ) + if err != nil { + return fmt.Errorf("failed to create RPC state for certificate refresh: %w", err) + } + defer cs.Close() + + client, err := cs.Connect(cfg.CoordinatorAddress, rpc.ServiceRunner) + if err != nil { + return fmt.Errorf("failed to connect to coordinator for certificate refresh: %w", err) + } + defer client.Close() + + rc := runner_v1alpha.NewRunnerRegistrationClient(client) + res, err := rc.RefreshCertificate(ctx, listenAddr) + if err != nil { + return fmt.Errorf("certificate refresh request failed: %w", err) + } + if res.Error() != "" { + return fmt.Errorf("certificate refresh rejected by coordinator: %s", res.Error()) + } + if len(res.CertPem()) == 0 || len(res.KeyPem()) == 0 { + return fmt.Errorf("coordinator returned an empty certificate") + } + + cfg.ClientCert = string(res.CertPem()) + cfg.ClientKey = string(res.KeyPem()) + if len(res.CaPem()) > 0 { + cfg.CACert = string(res.CaPem()) + } + + if err := cfg.Save(configPath); err != nil { + return fmt.Errorf("failed to save refreshed certificate: %w", err) + } + + ctx.Log.Info("runner certificate refreshed", "listen_addr", listenAddr) + return nil +} + +// certCoversListenAddr reports whether the leaf certificate in certPEM carries a +// SAN matching the host of listenAddr: an IP SAN for an IP host, or a DNS SAN +// for a hostname. This mirrors how the coordinator builds the certificate's SANs +// from the listen address. +func certCoversListenAddr(certPEM, listenAddr string) (bool, error) { + host, _, err := net.SplitHostPort(listenAddr) + if err != nil { + host = listenAddr + } + if host == "" { + return true, nil + } + + block, _ := pem.Decode([]byte(certPEM)) + if block == nil { + return false, fmt.Errorf("no PEM block found in certificate") + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return false, fmt.Errorf("parsing certificate: %w", err) + } + + if ip := net.ParseIP(host); ip != nil { + for _, certIP := range cert.IPAddresses { + if certIP.Equal(ip) { + return true, nil + } + } + return false, nil + } + + for _, name := range cert.DNSNames { + if name == host { + return true, nil + } + } + return false, nil +} diff --git a/cli/commands/runner_start_test.go b/cli/commands/runner_start_test.go new file mode 100644 index 000000000..a63d9ee0a --- /dev/null +++ b/cli/commands/runner_start_test.go @@ -0,0 +1,65 @@ +//go:build linux + +package commands + +import ( + "net" + "testing" + "time" + + "miren.dev/runtime/pkg/caauth" +) + +func TestCertCoversListenAddr(t *testing.T) { + ca, err := caauth.New(caauth.Options{ + CommonName: "test-ca", + Organization: "test", + ValidFor: 24 * time.Hour, + }) + if err != nil { + t.Fatalf("failed to create CA: %v", err) + } + + cc, err := ca.IssueCertificate(caauth.Options{ + CommonName: "runner-abc12345", + Organization: "miren", + ValidFor: time.Hour, + IPs: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("10.0.0.45")}, + DNSNames: []string{"localhost", "runner.example.com"}, + }) + if err != nil { + t.Fatalf("failed to issue cert: %v", err) + } + certPEM := string(cc.CertPEM) + + tests := []struct { + name string + listenAddr string + want bool + }{ + {"covered IP", "10.0.0.45:8444", true}, + {"loopback IP", "127.0.0.1:8444", true}, + {"changed IP not covered", "10.0.0.47:8444", false}, + {"covered DNS host", "runner.example.com:8444", true}, + {"uncovered DNS host", "other.example.com:8444", false}, + {"bare IP without port", "10.0.0.45", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := certCoversListenAddr(certPEM, tt.listenAddr) + if err != nil { + t.Fatalf("certCoversListenAddr returned error: %v", err) + } + if got != tt.want { + t.Errorf("certCoversListenAddr(%q) = %v, want %v", tt.listenAddr, got, tt.want) + } + }) + } +} + +func TestCertCoversListenAddrInvalidPEM(t *testing.T) { + if _, err := certCoversListenAddr("not a pem", "10.0.0.1:8444"); err == nil { + t.Fatal("expected error for invalid PEM input") + } +} diff --git a/pkg/caauth/caauth.go b/pkg/caauth/caauth.go index 21f5770fb..888e3b0dd 100644 --- a/pkg/caauth/caauth.go +++ b/pkg/caauth/caauth.go @@ -291,12 +291,17 @@ func (ca *Authority) VerifyCertificate(certPEM []byte) error { return fmt.Errorf("parsing certificate: %w", err) } + return ca.VerifyCert(cert) +} + +// VerifyCert verifies that an already-parsed certificate was signed by this CA. +func (ca *Authority) VerifyCert(cert *x509.Certificate) error { // Create verification pool with CA cert roots := x509.NewCertPool() roots.AddCert(ca.cert) // Verify the certificate - _, err = cert.Verify(x509.VerifyOptions{ + _, err := cert.Verify(x509.VerifyOptions{ Roots: roots, CurrentTime: time.Now(), KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, diff --git a/pkg/rpc/server.go b/pkg/rpc/server.go index 4bbd918a4..b128d1450 100644 --- a/pkg/rpc/server.go +++ b/pkg/rpc/server.go @@ -968,8 +968,9 @@ func (s *Server) startCallStream(w http.ResponseWriter, r *http.Request) { if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { call.peer = r.TLS.PeerCertificates[0] - ctx = context.WithValue(ctx, connectionKey{}, &CurrentConnectionInfo{ - PeerSubject: r.TLS.PeerCertificates[0].Subject.String(), + ctx = ContextWithConnectionInfo(ctx, &CurrentConnectionInfo{ + PeerSubject: r.TLS.PeerCertificates[0].Subject.String(), + PeerCertificate: r.TLS.PeerCertificates[0], }) } diff --git a/pkg/rpc/state.go b/pkg/rpc/state.go index 042d381c3..baf500e38 100644 --- a/pkg/rpc/state.go +++ b/pkg/rpc/state.go @@ -431,6 +431,11 @@ type connectionKey struct{} type CurrentConnectionInfo struct { PeerSubject string + // PeerCertificate is the client certificate presented during the mTLS + // handshake, if any. The server is configured with tls.RequestClientCert, + // which requests but does not verify the cert, so handlers that rely on it + // for authorization must verify it (e.g. that it chains to the cluster CA). + PeerCertificate *x509.Certificate } func ConnectionInfo(ctx context.Context) *CurrentConnectionInfo { @@ -442,6 +447,13 @@ func ConnectionInfo(ctx context.Context) *CurrentConnectionInfo { return v.(*CurrentConnectionInfo) } +// ContextWithConnectionInfo returns a copy of ctx carrying the given connection +// info, as observed by ConnectionInfo. The server populates this from the mTLS +// handshake; tests use it to exercise handlers that authorize on the peer cert. +func ContextWithConnectionInfo(ctx context.Context, info *CurrentConnectionInfo) context.Context { + return context.WithValue(ctx, connectionKey{}, info) +} + func (s *State) startListener(ctx context.Context, so *stateOptions) error { err := s.setupServer(so) if err != nil { diff --git a/servers/runner/registration.go b/servers/runner/registration.go index 28d803050..26205eabc 100644 --- a/servers/runner/registration.go +++ b/servers/runner/registration.go @@ -2,9 +2,11 @@ package runner import ( "context" + "crypto/x509" "fmt" "log/slog" "net" + "slices" "strings" "time" @@ -20,6 +22,7 @@ import ( "miren.dev/runtime/pkg/entity" "miren.dev/runtime/pkg/entity/types" "miren.dev/runtime/pkg/joincode" + "miren.dev/runtime/pkg/rpc" "miren.dev/runtime/pkg/rpc/standard" ) @@ -240,28 +243,9 @@ func (s *RegistrationServer) Join(ctx context.Context, req *runner_v1alpha.Runne // Now that invite is claimed, issue the certificate with proper SANs // so the coordinator can connect to the runner's API by IP. - runnerIDPrefix := runnerID - if len(runnerIDPrefix) > 8 { - runnerIDPrefix = runnerIDPrefix[:8] - } - certName := fmt.Sprintf("runner-%s", runnerIDPrefix) + certName := runnerCertName(runnerID) - ips := []net.IP{ - net.ParseIP("127.0.0.1"), - net.ParseIP("::1"), - } - dnsNames := []string{"localhost"} - - if listenAddr != "" { - host, _, err := net.SplitHostPort(listenAddr) - if err == nil && host != "" { - if ip := net.ParseIP(host); ip != nil { - ips = append(ips, ip) - } else if host != "localhost" { - dnsNames = append(dnsNames, host) - } - } - } + ips, dnsNames := buildRunnerSANs(listenAddr) cc, err := s.Authority.IssueCertificate(caauth.Options{ CommonName: certName, @@ -360,6 +344,151 @@ func (s *RegistrationServer) Join(ctx context.Context, req *runner_v1alpha.Runne return nil } +// runnerCertName returns the CommonName the coordinator assigns to a runner's +// certificate. It is derived solely from the (UUID) runner ID, never from any +// client-supplied name, so the "runner-" prefix is a trustworthy marker that a +// certificate was issued to a runner. +func runnerCertName(runnerID string) string { + prefix := runnerID + if len(prefix) > 8 { + prefix = prefix[:8] + } + return fmt.Sprintf("runner-%s", prefix) +} + +// buildRunnerSANs returns the IP and DNS subject alternative names a runner's +// server certificate should carry: always loopback (127.0.0.1, ::1, localhost) +// plus the host from the runner's advertised listen address (an IP becomes an +// IP SAN, a hostname becomes a DNS SAN). +func buildRunnerSANs(listenAddr string) ([]net.IP, []string) { + ips := []net.IP{ + net.ParseIP("127.0.0.1"), + net.ParseIP("::1"), + } + dnsNames := []string{"localhost"} + + if listenAddr != "" { + host, _, err := net.SplitHostPort(listenAddr) + if err == nil && host != "" { + if ip := net.ParseIP(host); ip != nil { + ips = append(ips, ip) + } else if host != "localhost" { + dnsNames = append(dnsNames, host) + } + } + } + + return ips, dnsNames +} + +// RefreshCertificate re-issues the calling runner's server certificate with SANs +// derived from its current listen address. A runner needs this when its listen +// IP changes but its persisted certificate (e.g. on a disk that outlives the VM) +// still carries the old IP. The method is public at the RPC layer but authorizes +// the caller here: the presented client certificate must chain to the cluster CA +// and be a runner certificate, and the re-issued certificate keeps that +// certificate's CommonName so a runner can only refresh its own identity. +func (s *RegistrationServer) RefreshCertificate(ctx context.Context, req *runner_v1alpha.RunnerRegistrationRefreshCertificate) error { + args := req.Args() + results := req.Results() + + info := rpc.ConnectionInfo(ctx) + var peer *x509.Certificate + if info != nil { + peer = info.PeerCertificate + } + + listenAddr := "" + if args.HasListenAddr() { + listenAddr = args.ListenAddr() + } + + cc, err := s.reissueRunnerCertificate(ctx, peer, listenAddr) + if err != nil { + results.SetError(err.Error()) + return nil + } + + results.SetCertPem(cc.CertPEM) + results.SetKeyPem(cc.KeyPEM) + results.SetCaPem(cc.CACert) + + return nil +} + +// reissueRunnerCertificate authorizes the caller solely by its presented client +// certificate and, if valid, issues a fresh runner server certificate with SANs +// derived from listenAddr. The new certificate keeps the caller's CommonName so +// a runner can only refresh its own identity. Authorization requires that the +// presented certificate is a CA-signed runner certificate and that a runner +// matching its identity is still registered (so a removed runner's still-valid +// certificate cannot perpetually renew itself). The runner identity is taken +// from the verified certificate, never from caller-supplied input. The returned +// error is safe to surface to the caller. +func (s *RegistrationServer) reissueRunnerCertificate(ctx context.Context, peer *x509.Certificate, listenAddr string) (*caauth.ClientCertificate, error) { + if peer == nil { + return nil, fmt.Errorf("a client certificate is required to refresh a certificate") + } + + if err := s.Authority.VerifyCert(peer); err != nil { + s.Log.Warn("RefreshCertificate rejected: peer cert not signed by cluster CA", + "error", err, "subject", peer.Subject.String()) + return nil, fmt.Errorf("client certificate is not trusted") + } + + commonName := peer.Subject.CommonName + idPrefix, ok := strings.CutPrefix(commonName, "runner-") + if !ok || idPrefix == "" || !slices.Contains(peer.Subject.Organization, "miren") { + s.Log.Warn("RefreshCertificate rejected: peer cert is not a runner certificate", + "subject", peer.Subject.String()) + return nil, fmt.Errorf("client certificate is not a runner certificate") + } + + // Confirm a runner matching the certificate's identity is still registered. + // caauth has no revocation, so this is what prevents a removed runner's + // still-valid certificate from renewing itself indefinitely. The certificate + // CommonName carries only the runner ID prefix, so we match registered nodes + // by that prefix and fail closed unless exactly one matches. + matches, err := s.findNodesByRunnerIDPrefix(ctx, idPrefix) + if err != nil { + s.Log.Error("RefreshCertificate failed to verify runner registration", + "error", err, "subject", peer.Subject.String()) + return nil, fmt.Errorf("failed to verify runner registration") + } + switch len(matches) { + case 0: + s.Log.Warn("RefreshCertificate rejected: runner is not registered", + "subject", peer.Subject.String()) + return nil, fmt.Errorf("runner is not registered") + case 1: + // authorized + default: + s.Log.Warn("RefreshCertificate rejected: runner identity is ambiguous", + "subject", peer.Subject.String(), "matches", len(matches)) + return nil, fmt.Errorf("runner identity is ambiguous") + } + + ips, dnsNames := buildRunnerSANs(listenAddr) + + cc, err := s.Authority.IssueCertificate(caauth.Options{ + CommonName: commonName, + Organization: "miren", + ValidFor: 365 * 24 * time.Hour, + IPs: ips, + DNSNames: dnsNames, + }) + if err != nil { + s.Log.Error("Failed to re-issue certificate", "error", err, "common_name", commonName) + return nil, fmt.Errorf("failed to issue certificate") + } + + s.Log.Info("Re-issued runner certificate", + "common_name", commonName, + "listen_addr", listenAddr) + + return cc, nil +} + func (s *RegistrationServer) ListInvites(ctx context.Context, req *runner_v1alpha.RunnerRegistrationListInvites) error { results := req.Results() @@ -655,6 +784,26 @@ func (s *RegistrationServer) findNodeByQuery(ctx context.Context, query string) } } +// findNodesByRunnerIDPrefix returns all registered nodes whose runner ID begins +// with the given prefix. Used to map a runner certificate's CommonName (which +// carries only the runner ID prefix) back to its registration. +func (s *RegistrationServer) findNodesByRunnerIDPrefix(ctx context.Context, prefix string) ([]compute_v1alpha.Node, error) { + listResp, err := s.EAC.List(ctx, entity.Ref(entity.EntityKind, compute_v1alpha.KindNode)) + if err != nil { + return nil, err + } + + var matches []compute_v1alpha.Node + for _, e := range listResp.Values() { + var node compute_v1alpha.Node + decodeEntity(e, &node) + if node.RunnerId != "" && strings.HasPrefix(node.RunnerId, prefix) { + matches = append(matches, node) + } + } + return matches, nil +} + func (s *RegistrationServer) countNodeSchedules(ctx context.Context, nodeID entity.Id) (int, error) { listResp, err := s.EAC.List(ctx, compute_v1alpha.Index(compute_v1alpha.KindSandbox, nodeID)) if err != nil { diff --git a/servers/runner/registration_test.go b/servers/runner/registration_test.go index 1e6c0b614..962500d84 100644 --- a/servers/runner/registration_test.go +++ b/servers/runner/registration_test.go @@ -6,6 +6,7 @@ import ( "encoding/pem" "net" "testing" + "time" "miren.dev/runtime/api/runner/runner_v1alpha" "miren.dev/runtime/pkg/caauth" @@ -18,6 +19,8 @@ import ( type testEnv struct { client *runner_v1alpha.RunnerRegistrationClient store *entity.MockStore + server *RegistrationServer + ca *caauth.Authority } func newTestServer(t *testing.T) (*testEnv, func()) { @@ -28,6 +31,7 @@ func newTestServer(t *testing.T) (*testEnv, func()) { ca, err := caauth.New(caauth.Options{ CommonName: "test-ca", Organization: "test", + ValidFor: 24 * time.Hour, }) if err != nil { cleanup() @@ -44,7 +48,38 @@ func newTestServer(t *testing.T) (*testEnv, func()) { localClient := rpc.LocalClient(runner_v1alpha.AdaptRunnerRegistration(regServer)) client := runner_v1alpha.NewRunnerRegistrationClient(localClient) - return &testEnv{client: client, store: es.Store}, cleanup + return &testEnv{client: client, store: es.Store, server: regServer, ca: ca}, cleanup +} + +// issueLeafCert issues a certificate from the given authority and returns the +// parsed leaf certificate (the first PEM block; IssueCertificate appends the CA +// cert after it). +func issueLeafCert(t *testing.T, ca *caauth.Authority, commonName, org string, ip string) *x509.Certificate { + t.Helper() + + opts := caauth.Options{ + CommonName: commonName, + Organization: org, + ValidFor: time.Hour, + } + if ip != "" { + opts.IPs = []net.IP{net.ParseIP(ip)} + } + + cc, err := ca.IssueCertificate(opts) + if err != nil { + t.Fatalf("failed to issue cert: %v", err) + } + + block, _ := pem.Decode(cc.CertPEM) + if block == nil { + t.Fatal("failed to decode issued cert PEM") + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + t.Fatalf("failed to parse issued cert: %v", err) + } + return cert } // createInviteAndDecode creates a one-time invite and returns the secret. @@ -468,3 +503,253 @@ func TestRemoveRunnerByRunnerId(t *testing.T) { t.Errorf("expected 0 runners, got %d", len(listResult.Runners())) } } + +func TestBuildRunnerSANs(t *testing.T) { + t.Run("IP listen address adds IP SAN", func(t *testing.T) { + ips, dnsNames := buildRunnerSANs("10.0.0.7:8444") + + wantIPs := []string{"127.0.0.1", "::1", "10.0.0.7"} + for _, want := range wantIPs { + if !containsIP(ips, want) { + t.Errorf("missing IP SAN %s, got %v", want, ips) + } + } + if !containsStr(dnsNames, "localhost") { + t.Errorf("missing DNS SAN localhost, got %v", dnsNames) + } + if containsStr(dnsNames, "10.0.0.7") { + t.Errorf("IP should not appear as a DNS SAN, got %v", dnsNames) + } + }) + + t.Run("hostname listen address adds DNS SAN", func(t *testing.T) { + ips, dnsNames := buildRunnerSANs("runner.example.com:8444") + + if !containsStr(dnsNames, "runner.example.com") { + t.Errorf("missing DNS SAN runner.example.com, got %v", dnsNames) + } + if !containsIP(ips, "127.0.0.1") || !containsIP(ips, "::1") { + t.Errorf("missing loopback IP SANs, got %v", ips) + } + }) + + t.Run("empty listen address yields only loopback", func(t *testing.T) { + ips, dnsNames := buildRunnerSANs("") + if len(ips) != 2 { + t.Errorf("expected only loopback IPs, got %v", ips) + } + if len(dnsNames) != 1 || dnsNames[0] != "localhost" { + t.Errorf("expected only localhost DNS, got %v", dnsNames) + } + }) +} + +// joinRunner performs a Join and returns the issued leaf certificate and the +// assigned runner ID, so tests can exercise refresh as a real, registered runner. +func (e *testEnv) joinRunner(t *testing.T, ctx context.Context, listenAddr, name string) (*x509.Certificate, string) { + t.Helper() + + secret := e.createInviteAndDecode(t, ctx) + res, err := e.client.Join(ctx, secret, "", listenAddr, "v1", nil, name) + if err != nil { + t.Fatalf("Join failed: %v", err) + } + if res.HasError() { + t.Fatalf("Join returned error: %s", res.Error()) + } + + block, _ := pem.Decode(res.CertPem()) + if block == nil { + t.Fatal("failed to decode join cert PEM") + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + t.Fatalf("failed to parse join cert: %v", err) + } + return cert, res.RunnerId() +} + +func TestReissueRunnerCertificate(t *testing.T) { + ctx := context.Background() + env, cleanup := newTestServer(t) + defer cleanup() + + t.Run("happy path re-issues with new IP and preserves CN", func(t *testing.T) { + peer, _ := env.joinRunner(t, ctx, "10.0.0.1:8444", "happy-runner") + + cc, err := env.server.reissueRunnerCertificate(ctx, peer, "10.0.0.2:8444") + if err != nil { + t.Fatalf("reissueRunnerCertificate failed: %v", err) + } + + // The re-issued cert must be signed by the cluster CA. + if err := env.ca.VerifyCertificate(cc.CertPEM); err != nil { + t.Fatalf("re-issued cert not signed by CA: %v", err) + } + + block, _ := pem.Decode(cc.CertPEM) + if block == nil { + t.Fatal("failed to decode re-issued cert PEM") + } + newCert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + t.Fatalf("failed to parse re-issued cert: %v", err) + } + + if newCert.Subject.CommonName != peer.Subject.CommonName { + t.Errorf("CommonName = %q, want preserved %q", newCert.Subject.CommonName, peer.Subject.CommonName) + } + + var foundIPs []string + for _, ip := range newCert.IPAddresses { + foundIPs = append(foundIPs, ip.String()) + } + if !containsIP(newCert.IPAddresses, "10.0.0.2") { + t.Errorf("re-issued cert missing new IP SAN 10.0.0.2, got %v", foundIPs) + } + }) + + t.Run("nil peer is rejected", func(t *testing.T) { + _, err := env.server.reissueRunnerCertificate(ctx, nil, "10.0.0.2:8444") + if err == nil { + t.Fatal("expected error for nil peer certificate") + } + }) + + t.Run("cert from another CA is rejected", func(t *testing.T) { + otherCA, err := caauth.New(caauth.Options{CommonName: "other-ca", Organization: "other", ValidFor: 24 * time.Hour}) + if err != nil { + t.Fatalf("failed to create other CA: %v", err) + } + peer := issueLeafCert(t, otherCA, "runner-abc12345", "miren", "10.0.0.1") + + _, err = env.server.reissueRunnerCertificate(ctx, peer, "10.0.0.2:8444") + if err == nil { + t.Fatal("expected error for cert signed by a different CA") + } + }) + + t.Run("non-runner CN is rejected", func(t *testing.T) { + peer := issueLeafCert(t, env.ca, "operator-abc", "miren", "10.0.0.1") + + _, err := env.server.reissueRunnerCertificate(ctx, peer, "10.0.0.2:8444") + if err == nil { + t.Fatal("expected error for non-runner CommonName") + } + }) + + t.Run("wrong organization is rejected", func(t *testing.T) { + peer := issueLeafCert(t, env.ca, "runner-abc12345", "intruder", "10.0.0.1") + + _, err := env.server.reissueRunnerCertificate(ctx, peer, "10.0.0.2:8444") + if err == nil { + t.Fatal("expected error for cert with wrong organization") + } + }) + + t.Run("unregistered runner cert is rejected", func(t *testing.T) { + // A genuine CA-signed runner cert, but no matching Node exists (e.g. the + // runner was never registered or has been removed). The identity comes + // from the cert, so a caller cannot substitute another runner's ID. + peer := issueLeafCert(t, env.ca, "runner-deadbeef", "miren", "10.0.0.1") + + _, err := env.server.reissueRunnerCertificate(ctx, peer, "10.0.0.2:8444") + if err == nil { + t.Fatal("expected error for a runner cert with no registered node") + } + }) + + t.Run("removed runner cannot refresh", func(t *testing.T) { + peer, runnerID := env.joinRunner(t, ctx, "10.0.0.1:8444", "removed-runner") + + // Remove the runner, deleting its Node entity. + removeRes, err := env.client.RemoveRunner(ctx, runnerID, false) + if err != nil { + t.Fatalf("RemoveRunner failed: %v", err) + } + if removeRes.Error() != "" { + t.Fatalf("RemoveRunner returned error: %s", removeRes.Error()) + } + + // Its certificate is still cryptographically valid, but it is no longer + // registered, so refresh must be rejected. + _, err = env.server.reissueRunnerCertificate(ctx, peer, "10.0.0.2:8444") + if err == nil { + t.Fatal("expected error refreshing a removed runner's certificate") + } + }) + + t.Run("ambiguous runner identity is rejected", func(t *testing.T) { + // Two runners whose IDs share the same 8-char prefix produce the same + // certificate CommonName. Refresh must fail closed rather than guess. + secret1 := env.createInviteAndDecode(t, ctx) + res1, err := env.client.Join(ctx, secret1, "abcd1234-0000-0000-0000-000000000001", "10.0.0.1:8444", "v1", nil, "amb-1") + if err != nil { + t.Fatalf("Join 1 failed: %v", err) + } + if res1.HasError() { + t.Fatalf("Join 1 error: %s", res1.Error()) + } + + secret2 := env.createInviteAndDecode(t, ctx) + res2, err := env.client.Join(ctx, secret2, "abcd1234-0000-0000-0000-000000000002", "10.0.0.1:8444", "v1", nil, "amb-2") + if err != nil { + t.Fatalf("Join 2 failed: %v", err) + } + if res2.HasError() { + t.Fatalf("Join 2 error: %s", res2.Error()) + } + + block, _ := pem.Decode(res1.CertPem()) + if block == nil { + t.Fatal("failed to decode join cert PEM") + } + peer, err := x509.ParseCertificate(block.Bytes) + if err != nil { + t.Fatalf("failed to parse join cert: %v", err) + } + + _, err = env.server.reissueRunnerCertificate(ctx, peer, "10.0.0.2:8444") + if err == nil { + t.Fatal("expected error for ambiguous runner identity") + } + }) +} + +func TestRefreshCertificateRequiresClientCert(t *testing.T) { + ctx := context.Background() + env, cleanup := newTestServer(t) + defer cleanup() + + // The local RPC client does not carry a TLS peer certificate, so the + // handler must reject the request rather than minting a cert. + res, err := env.client.RefreshCertificate(ctx, "10.0.0.2:8444") + if err != nil { + t.Fatalf("RefreshCertificate RPC failed: %v", err) + } + if res.Error() == "" { + t.Fatal("expected RefreshCertificate to reject a call without a client certificate") + } + if len(res.CertPem()) != 0 { + t.Error("expected no certificate when the call is rejected") + } +} + +func containsIP(ips []net.IP, want string) bool { + w := net.ParseIP(want) + for _, ip := range ips { + if ip.Equal(w) { + return true + } + } + return false +} + +func containsStr(ss []string, want string) bool { + for _, s := range ss { + if s == want { + return true + } + } + return false +}