diff --git a/api/runner/rpc.yml b/api/runner/rpc.yml index 5bea3cb20..dfd2bdb83 100644 --- a/api/runner/rpc.yml +++ b/api/runner/rpc.yml @@ -200,6 +200,33 @@ interfaces: type: string doc: Error message if issuance failed + - name: RefreshCertificate + index: 8 + public: true # Authenticated in the handler via the caller's existing runner mTLS cert + doc: | + Re-issue a runner's server certificate with updated SANs. Used when a + runner's listen address changes (e.g. a VM is recreated with a new IP + but a persistent disk keeps the old, now-stale certificate). The caller + must present its existing CA-signed runner certificate; the new cert + preserves that certificate's CommonName. + parameters: + - name: listen_addr + type: string + doc: The runner's current listen address whose host should be covered by the new cert + results: + - name: cert_pem + type: bytes + doc: Re-issued certificate in PEM format + - name: key_pem + type: bytes + doc: Private key for the re-issued certificate in PEM format + - name: ca_pem + type: bytes + doc: CA certificate in PEM format + - name: error + type: string + doc: Error message if the refresh failed + types: - type: InviteInfo doc: Information about a runner invite diff --git a/api/runner/runner_v1alpha/rpc.gen.go b/api/runner/runner_v1alpha/rpc.gen.go index 6d2198d96..0a6a341a9 100644 --- a/api/runner/runner_v1alpha/rpc.gen.go +++ b/api/runner/runner_v1alpha/rpc.gen.go @@ -1078,6 +1078,89 @@ func (v *RunnerRegistrationIssueWorkloadTokenResults) UnmarshalJSON(data []byte) return json.Unmarshal(data, &v.data) } +type runnerRegistrationRefreshCertificateArgsData struct { + ListenAddr *string `cbor:"0,keyasint,omitempty" json:"listen_addr,omitempty"` +} + +type RunnerRegistrationRefreshCertificateArgs struct { + call rpc.Call + data runnerRegistrationRefreshCertificateArgsData +} + +func (v *RunnerRegistrationRefreshCertificateArgs) HasListenAddr() bool { + return v.data.ListenAddr != nil +} + +func (v *RunnerRegistrationRefreshCertificateArgs) ListenAddr() string { + if v.data.ListenAddr == nil { + return "" + } + return *v.data.ListenAddr +} + +func (v *RunnerRegistrationRefreshCertificateArgs) MarshalCBOR() ([]byte, error) { + return cbor.Marshal(v.data) +} + +func (v *RunnerRegistrationRefreshCertificateArgs) UnmarshalCBOR(data []byte) error { + return cbor.Unmarshal(data, &v.data) +} + +func (v *RunnerRegistrationRefreshCertificateArgs) MarshalJSON() ([]byte, error) { + return json.Marshal(v.data) +} + +func (v *RunnerRegistrationRefreshCertificateArgs) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &v.data) +} + +type runnerRegistrationRefreshCertificateResultsData struct { + CertPem *[]byte `cbor:"0,keyasint,omitempty" json:"cert_pem,omitempty"` + KeyPem *[]byte `cbor:"1,keyasint,omitempty" json:"key_pem,omitempty"` + CaPem *[]byte `cbor:"2,keyasint,omitempty" json:"ca_pem,omitempty"` + Error *string `cbor:"3,keyasint,omitempty" json:"error,omitempty"` +} + +type RunnerRegistrationRefreshCertificateResults struct { + call rpc.Call + data runnerRegistrationRefreshCertificateResultsData +} + +func (v *RunnerRegistrationRefreshCertificateResults) SetCertPem(cert_pem []byte) { + x := slices.Clone(cert_pem) + v.data.CertPem = &x +} + +func (v *RunnerRegistrationRefreshCertificateResults) SetKeyPem(key_pem []byte) { + x := slices.Clone(key_pem) + v.data.KeyPem = &x +} + +func (v *RunnerRegistrationRefreshCertificateResults) SetCaPem(ca_pem []byte) { + x := slices.Clone(ca_pem) + v.data.CaPem = &x +} + +func (v *RunnerRegistrationRefreshCertificateResults) SetError(error string) { + v.data.Error = &error +} + +func (v *RunnerRegistrationRefreshCertificateResults) MarshalCBOR() ([]byte, error) { + return cbor.Marshal(v.data) +} + +func (v *RunnerRegistrationRefreshCertificateResults) UnmarshalCBOR(data []byte) error { + return cbor.Unmarshal(data, &v.data) +} + +func (v *RunnerRegistrationRefreshCertificateResults) MarshalJSON() ([]byte, error) { + return json.Marshal(v.data) +} + +func (v *RunnerRegistrationRefreshCertificateResults) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &v.data) +} + type RunnerRegistrationCreateInvite struct { rpc.Call args RunnerRegistrationCreateInviteArgs @@ -1286,6 +1369,32 @@ func (t *RunnerRegistrationIssueWorkloadToken) Results() *RunnerRegistrationIssu return results } +type RunnerRegistrationRefreshCertificate struct { + rpc.Call + args RunnerRegistrationRefreshCertificateArgs + results RunnerRegistrationRefreshCertificateResults +} + +func (t *RunnerRegistrationRefreshCertificate) Args() *RunnerRegistrationRefreshCertificateArgs { + args := &t.args + if args.call != nil { + return args + } + args.call = t.Call + t.Call.Args(args) + return args +} + +func (t *RunnerRegistrationRefreshCertificate) Results() *RunnerRegistrationRefreshCertificateResults { + results := &t.results + if results.call != nil { + return results + } + results.call = t.Call + t.Call.Results(results) + return results +} + type RunnerRegistration interface { CreateInvite(ctx context.Context, state *RunnerRegistrationCreateInvite) error Join(ctx context.Context, state *RunnerRegistrationJoin) error @@ -1295,6 +1404,7 @@ type RunnerRegistration interface { RemoveRunner(ctx context.Context, state *RunnerRegistrationRemoveRunner) error WorkloadIssuerInfo(ctx context.Context, state *RunnerRegistrationWorkloadIssuerInfo) error IssueWorkloadToken(ctx context.Context, state *RunnerRegistrationIssueWorkloadToken) error + RefreshCertificate(ctx context.Context, state *RunnerRegistrationRefreshCertificate) error } type reexportRunnerRegistration struct { @@ -1333,6 +1443,10 @@ func (reexportRunnerRegistration) IssueWorkloadToken(ctx context.Context, state panic("not implemented") } +func (reexportRunnerRegistration) RefreshCertificate(ctx context.Context, state *RunnerRegistrationRefreshCertificate) error { + panic("not implemented") +} + func (t reexportRunnerRegistration) CapabilityClient() rpc.Client { return t.client } @@ -1411,6 +1525,15 @@ func AdaptRunnerRegistration(t RunnerRegistration) *rpc.Interface { return t.IssueWorkloadToken(ctx, &RunnerRegistrationIssueWorkloadToken{Call: call}) }, }, + { + Name: "RefreshCertificate", + InterfaceName: "RunnerRegistration", + Index: 8, + Public: true, + Handler: func(ctx context.Context, call rpc.Call) error { + return t.RefreshCertificate(ctx, &RunnerRegistrationRefreshCertificate{Call: call}) + }, + }, } return rpc.NewInterface(methods, t) @@ -1864,3 +1987,66 @@ func (v RunnerRegistrationClient) IssueWorkloadToken(ctx context.Context, sandbo return &RunnerRegistrationClientIssueWorkloadTokenResults{client: v.Client, data: ret}, nil } + +type RunnerRegistrationClientRefreshCertificateResults struct { + client rpc.Client + data runnerRegistrationRefreshCertificateResultsData +} + +func (v *RunnerRegistrationClientRefreshCertificateResults) HasCertPem() bool { + return v.data.CertPem != nil +} + +func (v *RunnerRegistrationClientRefreshCertificateResults) CertPem() []byte { + if v.data.CertPem == nil { + return nil + } + return *v.data.CertPem +} + +func (v *RunnerRegistrationClientRefreshCertificateResults) HasKeyPem() bool { + return v.data.KeyPem != nil +} + +func (v *RunnerRegistrationClientRefreshCertificateResults) KeyPem() []byte { + if v.data.KeyPem == nil { + return nil + } + return *v.data.KeyPem +} + +func (v *RunnerRegistrationClientRefreshCertificateResults) HasCaPem() bool { + return v.data.CaPem != nil +} + +func (v *RunnerRegistrationClientRefreshCertificateResults) CaPem() []byte { + if v.data.CaPem == nil { + return nil + } + return *v.data.CaPem +} + +func (v *RunnerRegistrationClientRefreshCertificateResults) HasError() bool { + return v.data.Error != nil +} + +func (v *RunnerRegistrationClientRefreshCertificateResults) Error() string { + if v.data.Error == nil { + return "" + } + return *v.data.Error +} + +func (v RunnerRegistrationClient) RefreshCertificate(ctx context.Context, listen_addr string) (*RunnerRegistrationClientRefreshCertificateResults, error) { + args := RunnerRegistrationRefreshCertificateArgs{} + args.data.ListenAddr = &listen_addr + + var ret runnerRegistrationRefreshCertificateResultsData + + err := v.Call(ctx, "RefreshCertificate", &args, &ret) + if err != nil { + return nil, err + } + + return &RunnerRegistrationClientRefreshCertificateResults{client: v.Client, data: ret}, nil +} diff --git a/cli/commands/runner_start.go b/cli/commands/runner_start.go index fae66fad0..f35253529 100644 --- a/cli/commands/runner_start.go +++ b/cli/commands/runner_start.go @@ -4,6 +4,8 @@ package commands import ( "context" + "crypto/x509" + "encoding/pem" "fmt" "net" "net/netip" @@ -16,6 +18,7 @@ import ( containerd "github.com/containerd/containerd/v2/client" "golang.org/x/sync/errgroup" + "miren.dev/runtime/api/runner/runner_v1alpha" "miren.dev/runtime/clientconfig" containerdcomp "miren.dev/runtime/components/containerd" "miren.dev/runtime/components/netresolve" @@ -23,6 +26,7 @@ import ( "miren.dev/runtime/controllers/sandbox" "miren.dev/runtime/metrics" "miren.dev/runtime/observability" + "miren.dev/runtime/pkg/rpc" "miren.dev/runtime/pkg/runnerconfig" ) @@ -44,6 +48,30 @@ func RunnerStart(ctx *Context, opts struct { "etcd_endpoints", cfg.EtcdEndpoints, "network_backend", cfg.NetworkBackend) + // Determine listen address. If no explicit address is given, discover the + // machine's outbound IP (the one that would route to the coordinator) and + // advertise that so the coordinator knows how to reach this runner. + listenAddr := opts.ListenAddr + if listenAddr == "" { + port := "8444" + ip, err := discoverOutboundIP(cfg.CoordinatorAddress) + if err != nil { + return fmt.Errorf("could not discover outbound IP for listen address (use --listen to set manually): %w", err) + } + listenAddr = net.JoinHostPort(ip.String(), port) + ctx.Log.Info("discovered listen address", "addr", listenAddr) + } + + // The runner's certificate is persisted in the config (often on a disk that + // outlives the VM). If the listen address has changed since the cert was + // issued, the persisted cert no longer covers our IP and clients (e.g. + // sandbox exec) will fail TLS verification. Re-issue the cert before we use + // it to serve. Refreshing is fatal when the address isn't covered: serving a + // stale cert would leave the runner silently unreachable. + if err := ensureRunnerCertificate(ctx, cfg, opts.ConfigPath, listenAddr); err != nil { + return err + } + // Create clientconfig from saved certs for RPC authentication clientCfg := clientconfig.NewConfig() clientCfg.SetCluster("coordinator", &clientconfig.ClusterConfig{ @@ -137,20 +165,6 @@ func RunnerStart(ctx *Context, opts struct { // Create errgroup for background tasks eg, egCtx := errgroup.WithContext(sigCtx) - // Determine listen address. If no explicit address is given, discover the - // machine's outbound IP (the one that would route to the coordinator) and - // advertise that so the coordinator knows how to reach this runner. - listenAddr := opts.ListenAddr - if listenAddr == "" { - port := "8444" - ip, err := discoverOutboundIP(cfg.CoordinatorAddress) - if err != nil { - return fmt.Errorf("could not discover outbound IP for listen address (use --listen to set manually): %w", err) - } - listenAddr = net.JoinHostPort(ip.String(), port) - ctx.Log.Info("discovered listen address", "addr", listenAddr) - } - // Build runner configuration runnerCfg := runner.RunnerConfig{ Id: cfg.RunnerID, @@ -300,3 +314,107 @@ func RunnerStart(ctx *Context, opts struct { ctx.Log.Info("runner stopped") return nil } + +// ensureRunnerCertificate re-issues the runner's server certificate when the +// current listen address is not covered by the persisted certificate's SANs. +// This happens when a VM is recreated with a new IP but a persistent disk keeps +// the old config (cert + runner ID). The refreshed cert is written back to the +// config so it survives subsequent restarts. A required-but-failed refresh is +// fatal: serving a stale cert would leave the runner silently unreachable. +func ensureRunnerCertificate(ctx *Context, cfg *runnerconfig.Config, configPath, listenAddr string) error { + if cfg.ClientCert == "" { + return nil + } + + covered, err := certCoversListenAddr(cfg.ClientCert, listenAddr) + if err != nil { + return fmt.Errorf("failed to inspect runner certificate: %w", err) + } + if covered { + return nil + } + + ctx.Log.Info("runner certificate does not cover listen address; refreshing", + "listen_addr", listenAddr) + + cs, err := rpc.NewState(ctx, + rpc.WithLogger(ctx.Log), + rpc.WithBindAddr("[::]:0"), + rpc.WithCertPEMs([]byte(cfg.ClientCert), []byte(cfg.ClientKey)), + rpc.WithCertificateVerification([]byte(cfg.CACert)), + ) + if err != nil { + return fmt.Errorf("failed to create RPC state for certificate refresh: %w", err) + } + defer cs.Close() + + client, err := cs.Connect(cfg.CoordinatorAddress, rpc.ServiceRunner) + if err != nil { + return fmt.Errorf("failed to connect to coordinator for certificate refresh: %w", err) + } + defer client.Close() + + rc := runner_v1alpha.NewRunnerRegistrationClient(client) + res, err := rc.RefreshCertificate(ctx, listenAddr) + if err != nil { + return fmt.Errorf("certificate refresh request failed: %w", err) + } + if res.Error() != "" { + return fmt.Errorf("certificate refresh rejected by coordinator: %s", res.Error()) + } + if len(res.CertPem()) == 0 || len(res.KeyPem()) == 0 { + return fmt.Errorf("coordinator returned an empty certificate") + } + + cfg.ClientCert = string(res.CertPem()) + cfg.ClientKey = string(res.KeyPem()) + if len(res.CaPem()) > 0 { + cfg.CACert = string(res.CaPem()) + } + + if err := cfg.Save(configPath); err != nil { + return fmt.Errorf("failed to save refreshed certificate: %w", err) + } + + ctx.Log.Info("runner certificate refreshed", "listen_addr", listenAddr) + return nil +} + +// certCoversListenAddr reports whether the leaf certificate in certPEM carries a +// SAN matching the host of listenAddr: an IP SAN for an IP host, or a DNS SAN +// for a hostname. This mirrors how the coordinator builds the certificate's SANs +// from the listen address. +func certCoversListenAddr(certPEM, listenAddr string) (bool, error) { + host, _, err := net.SplitHostPort(listenAddr) + if err != nil { + host = listenAddr + } + if host == "" { + return true, nil + } + + block, _ := pem.Decode([]byte(certPEM)) + if block == nil { + return false, fmt.Errorf("no PEM block found in certificate") + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return false, fmt.Errorf("parsing certificate: %w", err) + } + + if ip := net.ParseIP(host); ip != nil { + for _, certIP := range cert.IPAddresses { + if certIP.Equal(ip) { + return true, nil + } + } + return false, nil + } + + for _, name := range cert.DNSNames { + if name == host { + return true, nil + } + } + return false, nil +} diff --git a/cli/commands/runner_start_test.go b/cli/commands/runner_start_test.go new file mode 100644 index 000000000..a63d9ee0a --- /dev/null +++ b/cli/commands/runner_start_test.go @@ -0,0 +1,65 @@ +//go:build linux + +package commands + +import ( + "net" + "testing" + "time" + + "miren.dev/runtime/pkg/caauth" +) + +func TestCertCoversListenAddr(t *testing.T) { + ca, err := caauth.New(caauth.Options{ + CommonName: "test-ca", + Organization: "test", + ValidFor: 24 * time.Hour, + }) + if err != nil { + t.Fatalf("failed to create CA: %v", err) + } + + cc, err := ca.IssueCertificate(caauth.Options{ + CommonName: "runner-abc12345", + Organization: "miren", + ValidFor: time.Hour, + IPs: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("10.0.0.45")}, + DNSNames: []string{"localhost", "runner.example.com"}, + }) + if err != nil { + t.Fatalf("failed to issue cert: %v", err) + } + certPEM := string(cc.CertPEM) + + tests := []struct { + name string + listenAddr string + want bool + }{ + {"covered IP", "10.0.0.45:8444", true}, + {"loopback IP", "127.0.0.1:8444", true}, + {"changed IP not covered", "10.0.0.47:8444", false}, + {"covered DNS host", "runner.example.com:8444", true}, + {"uncovered DNS host", "other.example.com:8444", false}, + {"bare IP without port", "10.0.0.45", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := certCoversListenAddr(certPEM, tt.listenAddr) + if err != nil { + t.Fatalf("certCoversListenAddr returned error: %v", err) + } + if got != tt.want { + t.Errorf("certCoversListenAddr(%q) = %v, want %v", tt.listenAddr, got, tt.want) + } + }) + } +} + +func TestCertCoversListenAddrInvalidPEM(t *testing.T) { + if _, err := certCoversListenAddr("not a pem", "10.0.0.1:8444"); err == nil { + t.Fatal("expected error for invalid PEM input") + } +} diff --git a/pkg/caauth/caauth.go b/pkg/caauth/caauth.go index 21f5770fb..888e3b0dd 100644 --- a/pkg/caauth/caauth.go +++ b/pkg/caauth/caauth.go @@ -291,12 +291,17 @@ func (ca *Authority) VerifyCertificate(certPEM []byte) error { return fmt.Errorf("parsing certificate: %w", err) } + return ca.VerifyCert(cert) +} + +// VerifyCert verifies that an already-parsed certificate was signed by this CA. +func (ca *Authority) VerifyCert(cert *x509.Certificate) error { // Create verification pool with CA cert roots := x509.NewCertPool() roots.AddCert(ca.cert) // Verify the certificate - _, err = cert.Verify(x509.VerifyOptions{ + _, err := cert.Verify(x509.VerifyOptions{ Roots: roots, CurrentTime: time.Now(), KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, diff --git a/pkg/rpc/server.go b/pkg/rpc/server.go index 4bbd918a4..b128d1450 100644 --- a/pkg/rpc/server.go +++ b/pkg/rpc/server.go @@ -968,8 +968,9 @@ func (s *Server) startCallStream(w http.ResponseWriter, r *http.Request) { if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { call.peer = r.TLS.PeerCertificates[0] - ctx = context.WithValue(ctx, connectionKey{}, &CurrentConnectionInfo{ - PeerSubject: r.TLS.PeerCertificates[0].Subject.String(), + ctx = ContextWithConnectionInfo(ctx, &CurrentConnectionInfo{ + PeerSubject: r.TLS.PeerCertificates[0].Subject.String(), + PeerCertificate: r.TLS.PeerCertificates[0], }) } diff --git a/pkg/rpc/state.go b/pkg/rpc/state.go index 042d381c3..baf500e38 100644 --- a/pkg/rpc/state.go +++ b/pkg/rpc/state.go @@ -431,6 +431,11 @@ type connectionKey struct{} type CurrentConnectionInfo struct { PeerSubject string + // PeerCertificate is the client certificate presented during the mTLS + // handshake, if any. The server is configured with tls.RequestClientCert, + // which requests but does not verify the cert, so handlers that rely on it + // for authorization must verify it (e.g. that it chains to the cluster CA). + PeerCertificate *x509.Certificate } func ConnectionInfo(ctx context.Context) *CurrentConnectionInfo { @@ -442,6 +447,13 @@ func ConnectionInfo(ctx context.Context) *CurrentConnectionInfo { return v.(*CurrentConnectionInfo) } +// ContextWithConnectionInfo returns a copy of ctx carrying the given connection +// info, as observed by ConnectionInfo. The server populates this from the mTLS +// handshake; tests use it to exercise handlers that authorize on the peer cert. +func ContextWithConnectionInfo(ctx context.Context, info *CurrentConnectionInfo) context.Context { + return context.WithValue(ctx, connectionKey{}, info) +} + func (s *State) startListener(ctx context.Context, so *stateOptions) error { err := s.setupServer(so) if err != nil { diff --git a/servers/runner/registration.go b/servers/runner/registration.go index a4f4ced0b..a864004cb 100644 --- a/servers/runner/registration.go +++ b/servers/runner/registration.go @@ -2,9 +2,11 @@ package runner import ( "context" + "crypto/x509" "fmt" "log/slog" "net" + "slices" "strings" "time" @@ -249,22 +251,7 @@ func (s *RegistrationServer) Join(ctx context.Context, req *runner_v1alpha.Runne // so the coordinator can connect to the runner's API by IP. certName := runnerCertName(runnerID) - ips := []net.IP{ - net.ParseIP("127.0.0.1"), - net.ParseIP("::1"), - } - dnsNames := []string{"localhost"} - - if listenAddr != "" { - host, _, err := net.SplitHostPort(listenAddr) - if err == nil && host != "" { - if ip := net.ParseIP(host); ip != nil { - ips = append(ips, ip) - } else if host != "localhost" { - dnsNames = append(dnsNames, host) - } - } - } + ips, dnsNames := buildRunnerSANs(listenAddr) cc, err := s.Authority.IssueCertificate(caauth.Options{ CommonName: certName, @@ -363,6 +350,132 @@ func (s *RegistrationServer) Join(ctx context.Context, req *runner_v1alpha.Runne return nil } +// buildRunnerSANs returns the IP and DNS subject alternative names a runner's +// server certificate should carry: always loopback (127.0.0.1, ::1, localhost) +// plus the host from the runner's advertised listen address (an IP becomes an +// IP SAN, a hostname becomes a DNS SAN). +func buildRunnerSANs(listenAddr string) ([]net.IP, []string) { + ips := []net.IP{ + net.ParseIP("127.0.0.1"), + net.ParseIP("::1"), + } + dnsNames := []string{"localhost"} + + if listenAddr != "" { + host, _, err := net.SplitHostPort(listenAddr) + if err == nil && host != "" { + if ip := net.ParseIP(host); ip != nil { + ips = append(ips, ip) + } else if host != "localhost" { + dnsNames = append(dnsNames, host) + } + } + } + + return ips, dnsNames +} + +// RefreshCertificate re-issues the calling runner's server certificate with SANs +// derived from its current listen address. A runner needs this when its listen +// IP changes but its persisted certificate (e.g. on a disk that outlives the VM) +// still carries the old IP. The method is public at the RPC layer but authorizes +// the caller here: the presented client certificate must chain to the cluster CA +// and be a runner certificate, and the re-issued certificate keeps that +// certificate's CommonName so a runner can only refresh its own identity. +func (s *RegistrationServer) RefreshCertificate(ctx context.Context, req *runner_v1alpha.RunnerRegistrationRefreshCertificate) error { + args := req.Args() + results := req.Results() + + info := rpc.ConnectionInfo(ctx) + var peer *x509.Certificate + if info != nil { + peer = info.PeerCertificate + } + + listenAddr := "" + if args.HasListenAddr() { + listenAddr = args.ListenAddr() + } + + cc, err := s.reissueRunnerCertificate(ctx, peer, listenAddr) + if err != nil { + results.SetError(err.Error()) + return nil + } + + results.SetCertPem(cc.CertPEM) + results.SetKeyPem(cc.KeyPEM) + results.SetCaPem(cc.CACert) + + return nil +} + +// reissueRunnerCertificate authorizes the caller solely by its presented client +// certificate and, if valid, issues a fresh runner server certificate with SANs +// derived from listenAddr. The new certificate keeps the caller's CommonName so +// a runner can only refresh its own identity. Authorization requires that the +// presented certificate is a CA-signed runner certificate and that a runner +// matching its identity is still registered (so a removed runner's still-valid +// certificate cannot perpetually renew itself). The runner identity is taken +// from the verified certificate, never from caller-supplied input. The returned +// error is safe to surface to the caller. +func (s *RegistrationServer) reissueRunnerCertificate(ctx context.Context, peer *x509.Certificate, listenAddr string) (*caauth.ClientCertificate, error) { + if peer == nil { + return nil, fmt.Errorf("a client certificate is required to refresh a certificate") + } + + if err := s.Authority.VerifyCert(peer); err != nil { + s.Log.Warn("RefreshCertificate rejected: peer cert not signed by cluster CA", + "error", err, "subject", peer.Subject.String()) + return nil, fmt.Errorf("client certificate is not trusted") + } + + commonName := peer.Subject.CommonName + runnerID, ok := strings.CutPrefix(commonName, "runner-") + if !ok || runnerID == "" || !slices.Contains(peer.Subject.Organization, "miren") { + s.Log.Warn("RefreshCertificate rejected: peer cert is not a runner certificate", + "subject", peer.Subject.String()) + return nil, fmt.Errorf("client certificate is not a runner certificate") + } + + // Confirm the runner is still registered. caauth has no revocation, so this + // is what prevents a removed runner's still-valid certificate from renewing + // itself indefinitely. The runner ID is taken from the verified certificate's + // CommonName (which embeds the full ID), so the caller cannot substitute + // another runner's identity. + node, _, err := s.findNodeByQuery(ctx, runnerID) + if err != nil { + s.Log.Error("RefreshCertificate failed to verify runner registration", + "error", err, "subject", peer.Subject.String()) + return nil, fmt.Errorf("failed to verify runner registration") + } + if node == nil { + s.Log.Warn("RefreshCertificate rejected: runner is not registered", + "subject", peer.Subject.String()) + return nil, fmt.Errorf("runner is not registered") + } + + ips, dnsNames := buildRunnerSANs(listenAddr) + + cc, err := s.Authority.IssueCertificate(caauth.Options{ + CommonName: commonName, + Organization: "miren", + ValidFor: 365 * 24 * time.Hour, + IPs: ips, + DNSNames: dnsNames, + }) + if err != nil { + s.Log.Error("Failed to re-issue certificate", "error", err, "common_name", commonName) + return nil, fmt.Errorf("failed to issue certificate") + } + + s.Log.Info("Re-issued runner certificate", + "common_name", commonName, + "listen_addr", listenAddr) + + return cc, nil +} + func (s *RegistrationServer) ListInvites(ctx context.Context, req *runner_v1alpha.RunnerRegistrationListInvites) error { results := req.Results() diff --git a/servers/runner/registration_test.go b/servers/runner/registration_test.go index 1e6c0b614..2c531f837 100644 --- a/servers/runner/registration_test.go +++ b/servers/runner/registration_test.go @@ -6,6 +6,7 @@ import ( "encoding/pem" "net" "testing" + "time" "miren.dev/runtime/api/runner/runner_v1alpha" "miren.dev/runtime/pkg/caauth" @@ -18,6 +19,8 @@ import ( type testEnv struct { client *runner_v1alpha.RunnerRegistrationClient store *entity.MockStore + server *RegistrationServer + ca *caauth.Authority } func newTestServer(t *testing.T) (*testEnv, func()) { @@ -28,6 +31,7 @@ func newTestServer(t *testing.T) (*testEnv, func()) { ca, err := caauth.New(caauth.Options{ CommonName: "test-ca", Organization: "test", + ValidFor: 24 * time.Hour, }) if err != nil { cleanup() @@ -44,7 +48,38 @@ func newTestServer(t *testing.T) (*testEnv, func()) { localClient := rpc.LocalClient(runner_v1alpha.AdaptRunnerRegistration(regServer)) client := runner_v1alpha.NewRunnerRegistrationClient(localClient) - return &testEnv{client: client, store: es.Store}, cleanup + return &testEnv{client: client, store: es.Store, server: regServer, ca: ca}, cleanup +} + +// issueLeafCert issues a certificate from the given authority and returns the +// parsed leaf certificate (the first PEM block; IssueCertificate appends the CA +// cert after it). +func issueLeafCert(t *testing.T, ca *caauth.Authority, commonName, org string, ip string) *x509.Certificate { + t.Helper() + + opts := caauth.Options{ + CommonName: commonName, + Organization: org, + ValidFor: time.Hour, + } + if ip != "" { + opts.IPs = []net.IP{net.ParseIP(ip)} + } + + cc, err := ca.IssueCertificate(opts) + if err != nil { + t.Fatalf("failed to issue cert: %v", err) + } + + block, _ := pem.Decode(cc.CertPEM) + if block == nil { + t.Fatal("failed to decode issued cert PEM") + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + t.Fatalf("failed to parse issued cert: %v", err) + } + return cert } // createInviteAndDecode creates a one-time invite and returns the secret. @@ -468,3 +503,218 @@ func TestRemoveRunnerByRunnerId(t *testing.T) { t.Errorf("expected 0 runners, got %d", len(listResult.Runners())) } } + +func TestBuildRunnerSANs(t *testing.T) { + t.Run("IP listen address adds IP SAN", func(t *testing.T) { + ips, dnsNames := buildRunnerSANs("10.0.0.7:8444") + + wantIPs := []string{"127.0.0.1", "::1", "10.0.0.7"} + for _, want := range wantIPs { + if !containsIP(ips, want) { + t.Errorf("missing IP SAN %s, got %v", want, ips) + } + } + if !containsStr(dnsNames, "localhost") { + t.Errorf("missing DNS SAN localhost, got %v", dnsNames) + } + if containsStr(dnsNames, "10.0.0.7") { + t.Errorf("IP should not appear as a DNS SAN, got %v", dnsNames) + } + }) + + t.Run("hostname listen address adds DNS SAN", func(t *testing.T) { + ips, dnsNames := buildRunnerSANs("runner.example.com:8444") + + if !containsStr(dnsNames, "runner.example.com") { + t.Errorf("missing DNS SAN runner.example.com, got %v", dnsNames) + } + if !containsIP(ips, "127.0.0.1") || !containsIP(ips, "::1") { + t.Errorf("missing loopback IP SANs, got %v", ips) + } + }) + + t.Run("empty listen address yields only loopback", func(t *testing.T) { + ips, dnsNames := buildRunnerSANs("") + if len(ips) != 2 { + t.Errorf("expected only loopback IPs, got %v", ips) + } + if len(dnsNames) != 1 || dnsNames[0] != "localhost" { + t.Errorf("expected only localhost DNS, got %v", dnsNames) + } + }) +} + +// joinRunner performs a Join and returns the issued leaf certificate and the +// assigned runner ID, so tests can exercise refresh as a real, registered runner. +func (e *testEnv) joinRunner(t *testing.T, ctx context.Context, listenAddr, name string) (*x509.Certificate, string) { + t.Helper() + + secret := e.createInviteAndDecode(t, ctx) + res, err := e.client.Join(ctx, secret, "", listenAddr, "v1", nil, name) + if err != nil { + t.Fatalf("Join failed: %v", err) + } + if res.HasError() { + t.Fatalf("Join returned error: %s", res.Error()) + } + + block, _ := pem.Decode(res.CertPem()) + if block == nil { + t.Fatal("failed to decode join cert PEM") + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + t.Fatalf("failed to parse join cert: %v", err) + } + return cert, res.RunnerId() +} + +func TestReissueRunnerCertificate(t *testing.T) { + ctx := context.Background() + env, cleanup := newTestServer(t) + defer cleanup() + + t.Run("happy path re-issues with new IP and preserves CN", func(t *testing.T) { + peer, _ := env.joinRunner(t, ctx, "10.0.0.1:8444", "happy-runner") + + cc, err := env.server.reissueRunnerCertificate(ctx, peer, "10.0.0.2:8444") + if err != nil { + t.Fatalf("reissueRunnerCertificate failed: %v", err) + } + + // The re-issued cert must be signed by the cluster CA. + if err := env.ca.VerifyCertificate(cc.CertPEM); err != nil { + t.Fatalf("re-issued cert not signed by CA: %v", err) + } + + block, _ := pem.Decode(cc.CertPEM) + if block == nil { + t.Fatal("failed to decode re-issued cert PEM") + } + newCert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + t.Fatalf("failed to parse re-issued cert: %v", err) + } + + if newCert.Subject.CommonName != peer.Subject.CommonName { + t.Errorf("CommonName = %q, want preserved %q", newCert.Subject.CommonName, peer.Subject.CommonName) + } + + var foundIPs []string + for _, ip := range newCert.IPAddresses { + foundIPs = append(foundIPs, ip.String()) + } + if !containsIP(newCert.IPAddresses, "10.0.0.2") { + t.Errorf("re-issued cert missing new IP SAN 10.0.0.2, got %v", foundIPs) + } + }) + + t.Run("nil peer is rejected", func(t *testing.T) { + _, err := env.server.reissueRunnerCertificate(ctx, nil, "10.0.0.2:8444") + if err == nil { + t.Fatal("expected error for nil peer certificate") + } + }) + + t.Run("cert from another CA is rejected", func(t *testing.T) { + otherCA, err := caauth.New(caauth.Options{CommonName: "other-ca", Organization: "other", ValidFor: 24 * time.Hour}) + if err != nil { + t.Fatalf("failed to create other CA: %v", err) + } + peer := issueLeafCert(t, otherCA, "runner-abc12345", "miren", "10.0.0.1") + + _, err = env.server.reissueRunnerCertificate(ctx, peer, "10.0.0.2:8444") + if err == nil { + t.Fatal("expected error for cert signed by a different CA") + } + }) + + t.Run("non-runner CN is rejected", func(t *testing.T) { + peer := issueLeafCert(t, env.ca, "operator-abc", "miren", "10.0.0.1") + + _, err := env.server.reissueRunnerCertificate(ctx, peer, "10.0.0.2:8444") + if err == nil { + t.Fatal("expected error for non-runner CommonName") + } + }) + + t.Run("wrong organization is rejected", func(t *testing.T) { + peer := issueLeafCert(t, env.ca, "runner-abc12345", "intruder", "10.0.0.1") + + _, err := env.server.reissueRunnerCertificate(ctx, peer, "10.0.0.2:8444") + if err == nil { + t.Fatal("expected error for cert with wrong organization") + } + }) + + t.Run("unregistered runner cert is rejected", func(t *testing.T) { + // A genuine CA-signed runner cert, but no matching Node exists (e.g. the + // runner was never registered or has been removed). The identity comes + // from the cert, so a caller cannot substitute another runner's ID. + peer := issueLeafCert(t, env.ca, "runner-deadbeef", "miren", "10.0.0.1") + + _, err := env.server.reissueRunnerCertificate(ctx, peer, "10.0.0.2:8444") + if err == nil { + t.Fatal("expected error for a runner cert with no registered node") + } + }) + + t.Run("removed runner cannot refresh", func(t *testing.T) { + peer, runnerID := env.joinRunner(t, ctx, "10.0.0.1:8444", "removed-runner") + + // Remove the runner, deleting its Node entity. + removeRes, err := env.client.RemoveRunner(ctx, runnerID, false) + if err != nil { + t.Fatalf("RemoveRunner failed: %v", err) + } + if removeRes.Error() != "" { + t.Fatalf("RemoveRunner returned error: %s", removeRes.Error()) + } + + // Its certificate is still cryptographically valid, but it is no longer + // registered, so refresh must be rejected. + _, err = env.server.reissueRunnerCertificate(ctx, peer, "10.0.0.2:8444") + if err == nil { + t.Fatal("expected error refreshing a removed runner's certificate") + } + }) + +} + +func TestRefreshCertificateRequiresClientCert(t *testing.T) { + ctx := context.Background() + env, cleanup := newTestServer(t) + defer cleanup() + + // The local RPC client does not carry a TLS peer certificate, so the + // handler must reject the request rather than minting a cert. + res, err := env.client.RefreshCertificate(ctx, "10.0.0.2:8444") + if err != nil { + t.Fatalf("RefreshCertificate RPC failed: %v", err) + } + if res.Error() == "" { + t.Fatal("expected RefreshCertificate to reject a call without a client certificate") + } + if len(res.CertPem()) != 0 { + t.Error("expected no certificate when the call is rejected") + } +} + +func containsIP(ips []net.IP, want string) bool { + w := net.ParseIP(want) + for _, ip := range ips { + if ip.Equal(w) { + return true + } + } + return false +} + +func containsStr(ss []string, want string) bool { + for _, s := range ss { + if s == want { + return true + } + } + return false +}