From 1fec065efb4a0c62af45f60e10acf2a6b651f267 Mon Sep 17 00:00:00 2001 From: Weilu Jia Date: Fri, 15 May 2026 17:20:30 -0700 Subject: [PATCH 1/7] feat(orchestrator): Add draining to orchestrator --- packages/orchestrator/pkg/cfg/model.go | 1 + packages/orchestrator/pkg/factories/run.go | 33 +++- .../orchestrator/pkg/sandbox/fc/process.go | 34 +++- .../pkg/sandbox/network/network.go | 36 +++-- packages/orchestrator/pkg/sandbox/sandbox.go | 18 +-- packages/orchestrator/pkg/server/main.go | 146 +++++++++++++++++- packages/orchestrator/pkg/server/sandboxes.go | 22 ++- packages/orchestrator/pkg/server/utils.go | 54 +++++++ 8 files changed, 305 insertions(+), 39 deletions(-) diff --git a/packages/orchestrator/pkg/cfg/model.go b/packages/orchestrator/pkg/cfg/model.go index ad75bf168e..7a16be1c6f 100644 --- a/packages/orchestrator/pkg/cfg/model.go +++ b/packages/orchestrator/pkg/cfg/model.go @@ -103,6 +103,7 @@ type Config struct { RedisPoolSize int `env:"REDIS_POOL_SIZE" envDefault:"5"` RedisMinIdleConns int `env:"REDIS_MIN_IDLE_CONNS" envDefault:"2"` NBDPoolSize int `env:"NBD_POOL_SIZE" envDefault:"64"` + SandboxDrainTimeout time.Duration `env:"SANDBOX_DRAIN_TIMEOUT" envDefault:"48h"` Services []string `env:"ORCHESTRATOR_SERVICES" envDefault:"orchestrator"` PersistentVolumeMounts map[string]string `env:"PERSISTENT_VOLUME_MOUNTS"` } diff --git a/packages/orchestrator/pkg/factories/run.go b/packages/orchestrator/pkg/factories/run.go index 0a963000df..1c064390b1 100644 --- a/packages/orchestrator/pkg/factories/run.go +++ b/packages/orchestrator/pkg/factories/run.go @@ -575,8 +575,8 @@ func run(config cfg.Config, opts Options) (success bool) { if err != nil { logger.L().Fatal(ctx, "failed to create orchestrator server", zap.Error(err)) } - closers = append(closers, closer{"orchestrator server", func(context.Context) error { - return orchestratorService.Close() + closers = append(closers, closer{"orchestrator server", func(closeCtx context.Context) error { + return orchestratorService.Close(closeCtx) }}) // template manager sandbox logger @@ -782,6 +782,35 @@ func run(config cfg.Config, opts Options) (success bool) { } } + if orchestratorService != nil { + var drainCtx context.Context + var cancelDrain context.CancelFunc + if !config.ForceStop && config.SandboxDrainTimeout > 0 { + drainCtx, cancelDrain = context.WithTimeout(closeCtx, config.SandboxDrainTimeout) + } else { + drainCtx, cancelDrain = context.WithCancel(closeCtx) + cancelDrain() + } + + logger.L().Info(ctx, "Starting sandbox drain phase", + zap.Duration("timeout", config.SandboxDrainTimeout), + zap.Bool("forced", config.ForceStop), + zap.Int("sandbox_count", sandboxes.Count()), + ) + + err := orchestratorService.DrainSandboxes(drainCtx) + cancelDrain() + if err != nil { + logger.L().Warn(ctx, "sandbox drain phase did not complete gracefully; forcing sandbox shutdown", zap.Error(err)) + + forceErr := orchestratorService.ForceStopSandboxes(context.WithoutCancel(ctx)) + if forceErr != nil { + logger.L().Error(ctx, "forced sandbox shutdown failed", zap.Error(forceErr)) + success = false + } + } + } + slices.Reverse(closers) for _, closer := range closers { clog := globalLogger.With(zap.String("service", closer.name), zap.Bool("forced", config.ForceStop)) diff --git a/packages/orchestrator/pkg/sandbox/fc/process.go b/packages/orchestrator/pkg/sandbox/fc/process.go index 0074176a48..d83c3c15e3 100644 --- a/packages/orchestrator/pkg/sandbox/fc/process.go +++ b/packages/orchestrator/pkg/sandbox/fc/process.go @@ -673,7 +673,7 @@ func (p *Process) Stop(ctx context.Context) error { // this function should never fail b/c a previous context was canceled. ctx = context.WithoutCancel(ctx) - err := p.cmd.Process.Signal(syscall.SIGTERM) + err := signalProcessGroup(p.cmd.Process.Pid, syscall.SIGTERM) if err != nil { if errors.Is(err, os.ErrProcessDone) { logger.L().Info(ctx, "fc process already exited", logger.WithSandboxID(p.files.SandboxID)) @@ -688,19 +688,26 @@ func (p *Process) Stop(ctx context.Context) error { select { // Wait 10 sec for the FC process to exit, if it doesn't, send SIGKILL. case <-time.After(10 * time.Second): + select { + case <-p.Exit.Done(): + return + default: + } + // Check process status right before Kill — the pre-SIGTERM status // captured above is 10s stale and no longer useful here. status, stateErr := getProcessStatus(p.cmd.Process.Pid) if errors.Is(stateErr, process.ErrorProcessNotRunning) { - // Process already exited, no need to send SIGKILL. + logger.L().Info(ctx, "fc parent process exited before SIGKILL; skipping process group signal", logger.WithSandboxID(p.files.SandboxID)) + return } else if stateErr != nil { logger.L().Warn(ctx, "failed to get fc process status before SIGKILL", zap.Error(stateErr), logger.WithSandboxID(p.files.SandboxID)) } - err := p.cmd.Process.Kill() + err := signalProcessGroup(p.cmd.Process.Pid, syscall.SIGKILL) if err == nil { - logger.L().Info(ctx, "sent SIGKILL to fc process because it was not responding to SIGTERM for 10 seconds", + logger.L().Info(ctx, "sent SIGKILL to fc process group because it was not responding to SIGTERM for 10 seconds", zap.Strings("status", status), logger.WithSandboxID(p.files.SandboxID), ) @@ -718,6 +725,25 @@ func (p *Process) Stop(ctx context.Context) error { return nil } +func signalProcessGroup(pid int, signal syscall.Signal) error { + if pid <= 0 { + return os.ErrProcessDone + } + + // Firecracker is launched with Setsid, so the process PID is also the process + // group ID. Signal the group so unshare/bash/ip descendants cannot keep the VM + // mount namespace or Firecracker process alive after shutdown. + if err := syscall.Kill(-pid, signal); err != nil { + if errors.Is(err, syscall.ESRCH) { + return os.ErrProcessDone + } + + return err + } + + return nil +} + func (p *Process) Pause(ctx context.Context) error { ctx, childSpan := tracer.Start(ctx, "pause-fc") defer childSpan.End() diff --git a/packages/orchestrator/pkg/sandbox/network/network.go b/packages/orchestrator/pkg/sandbox/network/network.go index 7e39467f83..2164470318 100644 --- a/packages/orchestrator/pkg/sandbox/network/network.go +++ b/packages/orchestrator/pkg/sandbox/network/network.go @@ -328,24 +328,26 @@ func (s *Slot) RemoveNetwork() error { } } - // Delete NFS proxy redirect rule - err = tables.Delete("nat", "PREROUTING", - "--in-interface", s.VethName(), "--protocol", "tcp", - "--destination", s.config.OrchestratorInSandboxIPAddress, "--dport", "2049", - "--jump", "REDIRECT", "--to-port", strconv.Itoa(int(s.config.NFSProxyPort)), - ) - if err != nil { - errs = append(errs, fmt.Errorf("error deleting sandbox NFS proxy redirect rule: %w", err)) - } + if tables != nil { + // Delete NFS proxy redirect rule + err = tables.Delete("nat", "PREROUTING", + "--in-interface", s.VethName(), "--protocol", "tcp", + "--destination", s.config.OrchestratorInSandboxIPAddress, "--dport", "2049", + "--jump", "REDIRECT", "--to-port", strconv.Itoa(int(s.config.NFSProxyPort)), + ) + if err != nil { + errs = append(errs, fmt.Errorf("error deleting sandbox NFS proxy redirect rule: %w", err)) + } - // Delete portmapper redirect rule - err = tables.Delete("nat", "PREROUTING", - "--in-interface", s.VethName(), "--protocol", "tcp", - "--destination", s.config.OrchestratorInSandboxIPAddress, "--dport", "111", - "--jump", "REDIRECT", "--to-port", strconv.Itoa(int(s.config.PortmapperPort)), - ) - if err != nil { - errs = append(errs, fmt.Errorf("error deleting sandbox portmapper redirect rule: %w", err)) + // Delete portmapper redirect rule + err = tables.Delete("nat", "PREROUTING", + "--in-interface", s.VethName(), "--protocol", "tcp", + "--destination", s.config.OrchestratorInSandboxIPAddress, "--dport", "111", + "--jump", "REDIRECT", "--to-port", strconv.Itoa(int(s.config.PortmapperPort)), + ) + if err != nil { + errs = append(errs, fmt.Errorf("error deleting sandbox portmapper redirect rule: %w", err)) + } } err = netns.DeleteNamed(s.NamespaceID()) diff --git a/packages/orchestrator/pkg/sandbox/sandbox.go b/packages/orchestrator/pkg/sandbox/sandbox.go index 1f97046b1c..e26f723186 100644 --- a/packages/orchestrator/pkg/sandbox/sandbox.go +++ b/packages/orchestrator/pkg/sandbox/sandbox.go @@ -989,9 +989,11 @@ func (s *Sandbox) doStop(ctx context.Context) error { errs = append(errs, fmt.Errorf("failed to stop FC: %w", fcStopErr)) } - // The process exited, we can continue with the rest of the cleanup. - // We could use select with ctx.Done() to wait for cancellation, but if the process is not exited the whole cleanup will be in a bad state and will result in unexpected behavior. - <-s.process.Exit.Done() + // The process should exit before the rest of cleanup, but memory shutdown + // must still run if the wait context is canceled so UFFD can exit. + if waitErr := s.process.Exit.WaitWithContext(ctx); waitErr != nil { + errs = append(errs, fmt.Errorf("failed waiting for FC exit: %w", waitErr)) + } uffdStopErr := s.Resources.memory.Stop() if uffdStopErr != nil { @@ -1320,15 +1322,7 @@ func getNetworkSlot( ctx, span := tracer.Start(ctx, "clean network-slot") defer span.End() - // We can run this cleanup asynchronously, as it is not important for the sandbox lifecycle - go func(ctx context.Context) { - returnErr := networkPool.Return(ctx, slot, networkReleased, network.ReturnDelay) - if returnErr != nil { - logger.L().Error(ctx, "failed to return network slot", zap.Error(returnErr)) - } - }(context.WithoutCancel(ctx)) - - return nil + return networkPool.Return(ctx, slot, networkReleased, network.ReturnDelay) }) return slot, nil diff --git a/packages/orchestrator/pkg/server/main.go b/packages/orchestrator/pkg/server/main.go index 5d967a3465..1b042fb595 100644 --- a/packages/orchestrator/pkg/server/main.go +++ b/packages/orchestrator/pkg/server/main.go @@ -4,6 +4,7 @@ package server import ( "context" + "errors" "fmt" "sync" "time" @@ -38,6 +39,8 @@ const uploadedBuildsTTL = 1 * time.Hour // MaxStartingInstancesPerNode feature flag and resize the semaphore. const startingSandboxesLimitRefreshInterval = 30 * time.Second +const sandboxDrainLogInterval = 5 * time.Second + type Server struct { orchestrator.UnimplementedSandboxServiceServer orchestrator.UnimplementedChunkServiceServer @@ -60,6 +63,9 @@ type Server struct { done chan struct{} closeOnce sync.Once + + sandboxStartMu sync.RWMutex + sandboxLifecycleWG sync.WaitGroup } type ServiceConfig struct { @@ -149,16 +155,152 @@ func New(ctx context.Context, cfg ServiceConfig) (*Server, error) { return server, nil } -func (s *Server) Close() error { +func (s *Server) Close(ctx context.Context) error { + s.startDraining(ctx) + s.uploadedBuilds.Stop() + + return nil +} + +func (s *Server) startDraining(ctx context.Context) { s.closeOnce.Do(func() { + logger.L().Info(ctx, "orchestrator server entering sandbox drain mode", + zap.Int("live_sandboxes", s.sandboxFactory.Sandboxes.Count()), + ) close(s.done) }) +} - s.uploadedBuilds.Stop() +func (s *Server) DrainSandboxes(ctx context.Context) error { + s.startDraining(ctx) + if err := s.waitSandboxStarts(ctx); err != nil { + return err + } + + live := s.sandboxFactory.Sandboxes.Count() + logger.L().Info(ctx, "starting graceful sandbox drain", zap.Int("live_sandboxes", live)) + if live == 0 { + logger.L().Info(ctx, "graceful sandbox drain complete", zap.Int("live_sandboxes", live)) + + return s.waitSandboxLifecycles(ctx) + } + + ticker := time.NewTicker(sandboxDrainLogInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + remaining := s.sandboxFactory.Sandboxes.Count() + logger.L().Warn(ctx, "graceful sandbox drain timed out", + zap.Int("remaining_sandboxes", remaining), + zap.Error(ctx.Err()), + ) + + return ctx.Err() + case <-ticker.C: + remaining := s.sandboxFactory.Sandboxes.Count() + logger.L().Info(ctx, "waiting for sandbox drain", + zap.Int("remaining_sandboxes", remaining), + ) + if remaining == 0 { + logger.L().Info(ctx, "graceful sandbox drain complete", zap.Int("live_sandboxes", remaining)) + + return s.waitSandboxLifecycles(ctx) + } + } + } +} + +func (s *Server) ForceStopSandboxes(ctx context.Context) error { + s.startDraining(ctx) + if err := s.waitSandboxStarts(ctx); err != nil { + return err + } + + sandboxes := s.sandboxFactory.Sandboxes.Items() + logger.L().Warn(ctx, "starting forced sandbox shutdown", zap.Int("sandbox_count", len(sandboxes))) + if len(sandboxes) == 0 { + return s.waitSandboxLifecycles(ctx) + } + + var wg sync.WaitGroup + errCh := make(chan error, len(sandboxes)) + + for _, sbx := range sandboxes { + wg.Go(func() { + sbxLog := logger.L().With( + logger.WithSandboxID(sbx.Runtime.SandboxID), + logger.WithLifecycleID(sbx.LifecycleID), + logger.WithSandboxIP(sbx.Slot.HostIPString()), + ) + sbxLog.Warn(ctx, "force stopping sandbox during orchestrator shutdown") + + marked := s.sandboxFactory.Sandboxes.MarkStopping(ctx, sbx.Runtime.SandboxID, sbx.LifecycleID) + if !marked { + sbxLog.Info(ctx, "sandbox was already removed from live map before force stop") + } + + if err := sbx.Stop(ctx); err != nil { + errCh <- fmt.Errorf("stop sandbox %s/%s: %w", sbx.Runtime.SandboxID, sbx.LifecycleID, err) + sbxLog.Error(ctx, "failed to force stop sandbox", zap.Error(err)) + + return + } + + if err := sbx.Close(ctx); err != nil { + errCh <- fmt.Errorf("cleanup sandbox %s/%s: %w", sbx.Runtime.SandboxID, sbx.LifecycleID, err) + sbxLog.Error(ctx, "failed to cleanup sandbox after force stop", zap.Error(err)) + + return + } + + if err := s.proxy.RemoveFromPool(sbx.LifecycleID); err != nil { + sbxLog.Warn(ctx, "failed to remove sandbox from proxy pool after force stop", zap.Error(err)) + } + + sbxLog.Info(ctx, "forced sandbox shutdown complete") + }) + } + + wg.Wait() + close(errCh) + + var errs []error + for err := range errCh { + errs = append(errs, err) + } + + if err := s.waitSandboxLifecycles(ctx); err != nil { + errs = append(errs, err) + } + + if err := errors.Join(errs...); err != nil { + logger.L().Error(ctx, "forced sandbox shutdown finished with errors", zap.Error(err)) + + return err + } + + logger.L().Info(ctx, "forced sandbox shutdown complete") return nil } +func (s *Server) waitSandboxLifecycles(ctx context.Context) error { + done := make(chan struct{}) + go func() { + s.sandboxLifecycleWG.Wait() + close(done) + }() + + select { + case <-ctx.Done(): + return fmt.Errorf("waiting for sandbox lifecycle cleanup: %w", ctx.Err()) + case <-done: + return nil + } +} + func (s *Server) refreshStartingSandboxesLimit(ctx context.Context) { ticker := time.NewTicker(startingSandboxesLimitRefreshInterval) defer ticker.Stop() diff --git a/packages/orchestrator/pkg/server/sandboxes.go b/packages/orchestrator/pkg/server/sandboxes.go index f0ad1901cc..6489827ff8 100644 --- a/packages/orchestrator/pkg/server/sandboxes.go +++ b/packages/orchestrator/pkg/server/sandboxes.go @@ -89,6 +89,11 @@ func (s *Server) Create(ctx context.Context, req *orchestrator.SandboxCreateRequ telemetry.WithEnvdVersion(req.GetSandbox().GetEnvdVersion()), ) + if err := s.enterSandboxStart(ctx, "sandbox-create"); err != nil { + return nil, err + } + defer s.leaveSandboxStart() + // setup launch darkly ctx = featureflags.AddToContext( ctx, @@ -203,6 +208,10 @@ func (s *Server) Create(ctx context.Context, req *orchestrator.SandboxCreateRequ SandboxType: sandbox.SandboxTypeSandbox, } + if err := s.rejectIfDraining(ctx, "sandbox-create-before-start"); err != nil { + return nil, err + } + sbx, err := s.sandboxFactory.ResumeSandbox( ctx, template, @@ -561,6 +570,11 @@ func (s *Server) Checkpoint(ctx context.Context, in *orchestrator.SandboxCheckpo Build(), ) + if err := s.enterSandboxStart(ctx, "sandbox-checkpoint"); err != nil { + return nil, err + } + defer s.leaveSandboxStart() + sbx, ok := s.sandboxFactory.Sandboxes.Get(in.GetSandboxId()) if !ok { telemetry.ReportCriticalError(ctx, "sandbox not found", nil, telemetry.WithSandboxID(in.GetSandboxId())) @@ -613,6 +627,10 @@ func (s *Server) Checkpoint(ctx context.Context, in *orchestrator.SandboxCheckpo // the API, routing catalog, and analytics) but with a fresh LifecycleID // so the old sandbox's cleanup goroutine won't // accidentally evict the resumed sandbox from the map. + if err := s.rejectIfDraining(ctx, "sandbox-checkpoint-before-resume"); err != nil { + return nil, err + } + resumedSbx, err := s.sandboxFactory.ResumeSandbox( ctx, template, @@ -834,7 +852,7 @@ func (s *Server) uploadSnapshotAsync(ctx context.Context, sbx *sandbox.Sandbox, // setupSandboxLifecycle sets up the cleanup goroutine for a sandbox. func (s *Server) setupSandboxLifecycle(ctx context.Context, sbx *sandbox.Sandbox) { - go func() { + s.sandboxLifecycleWG.Go(func() { ctx, childSpan := tracer.Start(context.WithoutCancel(ctx), "stop sandbox-lifecycle", trace.WithNewRoot()) defer childSpan.End() @@ -854,7 +872,7 @@ func (s *Server) setupSandboxLifecycle(ctx context.Context, sbx *sandbox.Sandbox } sbxlogger.E(sbx).Info(ctx, "Sandbox stopped") - }() + }) } // stopSandboxAsync stops the sandbox in a background goroutine. diff --git a/packages/orchestrator/pkg/server/utils.go b/packages/orchestrator/pkg/server/utils.go index 3ef314b3e6..dcb13735f1 100644 --- a/packages/orchestrator/pkg/server/utils.go +++ b/packages/orchestrator/pkg/server/utils.go @@ -4,14 +4,68 @@ package server import ( "context" + "fmt" + "go.uber.org/zap" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/e2b-dev/infra/packages/shared/pkg/logger" "github.com/e2b-dev/infra/packages/shared/pkg/telemetry" ) +func (s *Server) rejectIfDraining(ctx context.Context, operation string) error { + select { + case <-s.done: + logger.L().Info(ctx, "rejecting sandbox operation during orchestrator drain", zap.String("operation", operation)) + + return status.Error(codes.Unavailable, "orchestrator is draining") + default: + return nil + } +} + +func (s *Server) enterSandboxStart(ctx context.Context, operation string) error { + s.sandboxStartMu.RLock() + if err := s.rejectIfDraining(ctx, operation); err != nil { + s.sandboxStartMu.RUnlock() + + return err + } + + return nil +} + +func (s *Server) leaveSandboxStart() { + s.sandboxStartMu.RUnlock() +} + +func (s *Server) waitSandboxStarts(ctx context.Context) error { + logger.L().Info(ctx, "waiting for in-flight sandbox start operations to finish") + + done := make(chan struct{}) + go func() { + s.sandboxStartMu.Lock() + logger.L().Info(ctx, "in-flight sandbox start gate acquired") + s.sandboxStartMu.Unlock() + close(done) + }() + + select { + case <-ctx.Done(): + return fmt.Errorf("waiting for in-flight sandbox start operations: %w", ctx.Err()) + case <-done: + logger.L().Info(ctx, "in-flight sandbox start operations finished") + + return nil + } +} + func (s *Server) waitForAcquire(ctx context.Context) error { + if err := s.rejectIfDraining(ctx, "wait-for-acquire"); err != nil { + return err + } + ctx, cancel := context.WithTimeout(ctx, acquireTimeout) defer cancel() From ab8ab6d0a2ab00d12c472a71cf62fc45251ef5c9 Mon Sep 17 00:00:00 2001 From: Weilu Jia Date: Mon, 18 May 2026 11:10:25 -0700 Subject: [PATCH 2/7] fix(orchestrator): track sandbox lifecycles for shutdown --- packages/orchestrator/pkg/sandbox/map.go | 99 +++++++++++++++++-- packages/orchestrator/pkg/sandbox/map_test.go | 65 ++++++++++++ packages/orchestrator/pkg/server/main.go | 2 +- packages/orchestrator/pkg/server/sandboxes.go | 7 +- 4 files changed, 162 insertions(+), 11 deletions(-) create mode 100644 packages/orchestrator/pkg/sandbox/map_test.go diff --git a/packages/orchestrator/pkg/sandbox/map.go b/packages/orchestrator/pkg/sandbox/map.go index b14aa209c0..127db037f8 100644 --- a/packages/orchestrator/pkg/sandbox/map.go +++ b/packages/orchestrator/pkg/sandbox/map.go @@ -9,6 +9,8 @@ import ( "net" "sync" + "go.uber.org/zap" + "github.com/e2b-dev/infra/packages/shared/pkg/logger" "github.com/e2b-dev/infra/packages/shared/pkg/smap" ) @@ -25,14 +27,27 @@ type MapSubscriber interface { OnNetworkRelease(ctx context.Context, sbx *Sandbox) } -// Map holds sandboxes that are live (running) together with a IP-to-sandbox index -// The two maps are managed independently. +type SandboxState string + +const ( + SandboxStateRunning SandboxState = "running" + SandboxStateStopping SandboxState = "stopping" +) + +type lifecycleEntry struct { + sandbox *Sandbox + state SandboxState +} + +// Map holds sandboxes that are live (running), known active lifecycles, +// together with a IP-to-sandbox index. The indexes are managed independently. // // AssignNetwork/NetworkReleased manage the IP map, // MarkRunning/MarkStopping manage the live set. type Map struct { - live *smap.Map[*Sandbox] - network *smap.Map[*Sandbox] + live *smap.Map[*Sandbox] + lifecycles *smap.Map[lifecycleEntry] + network *smap.Map[*Sandbox] subs []MapSubscriber subsLock sync.RWMutex @@ -40,11 +55,16 @@ type Map struct { func NewSandboxesMap() *Map { return &Map{ - live: smap.New[*Sandbox](), - network: smap.New[*Sandbox](), + live: smap.New[*Sandbox](), + lifecycles: smap.New[lifecycleEntry](), + network: smap.New[*Sandbox](), } } +func sandboxLifecycleKey(sandboxID, lifecycleID string) string { + return fmt.Sprintf("%s/%s", sandboxID, lifecycleID) +} + func (m *Map) Subscribe(subscriber MapSubscriber) { m.subsLock.Lock() defer m.subsLock.Unlock() @@ -73,6 +93,35 @@ func (m *Map) Get(sandboxID string) (*Sandbox, bool) { return m.live.Get(sandboxID) } +func (m *Map) LifecycleItems() []*Sandbox { + entries := m.lifecycles.Items() + sandboxes := make([]*Sandbox, 0, len(entries)) + for _, entry := range entries { + sandboxes = append(sandboxes, entry.sandbox) + } + + return sandboxes +} + +func (m *Map) LifecycleItemsByState(states ...SandboxState) []*Sandbox { + stateSet := make(map[SandboxState]struct{}, len(states)) + for _, state := range states { + stateSet[state] = struct{}{} + } + + entries := m.lifecycles.Items() + sandboxes := make([]*Sandbox, 0, len(entries)) + for _, entry := range entries { + if _, ok := stateSet[entry.state]; !ok { + continue + } + + sandboxes = append(sandboxes, entry.sandbox) + } + + return sandboxes +} + // GetByHostPort looks up a sandbox by its host IP address parsed from hostPort. func (m *Map) GetByHostPort(hostPort string) (*Sandbox, error) { reqIP, _, err := net.SplitHostPort(hostPort) @@ -100,8 +149,24 @@ func (m *Map) AssignNetwork(ctx context.Context, sbx *Sandbox) { ) } +func (m *Map) TrackLifecycle(ctx context.Context, sbx *Sandbox, state SandboxState) { + m.lifecycles.Insert(sandboxLifecycleKey(sbx.Runtime.SandboxID, sbx.LifecycleID), lifecycleEntry{ + sandbox: sbx, + state: state, + }) + + logger.L().Info(ctx, "sandbox lifecycle tracked", + logger.WithSandboxID(sbx.Runtime.SandboxID), + logger.WithLifecycleID(sbx.LifecycleID), + logger.WithSandboxIP(sbx.Slot.HostIPString()), + zap.String("state", string(state)), + ) +} + // MarkRunning makes the sandbox visible to Get/Items/Count and notifies OnInsert subscribers. func (m *Map) MarkRunning(ctx context.Context, sbx *Sandbox) { + m.markLifecycleState(sbx.Runtime.SandboxID, sbx.LifecycleID, SandboxStateRunning) + if !m.live.InsertIfAbsent(sbx.Runtime.SandboxID, sbx) { return } @@ -126,6 +191,7 @@ func (m *Map) MarkRunning(ctx context.Context, sbx *Sandbox) { // Returns true if the sandbox was successfully removed. func (m *Map) MarkStopping(ctx context.Context, sandboxID, lifecycleID string) bool { stopped := false + m.markLifecycleState(sandboxID, lifecycleID, SandboxStateStopping) m.live.RemoveCb(sandboxID, func(_ string, sbx *Sandbox, exists bool) bool { if !exists { @@ -150,6 +216,27 @@ func (m *Map) MarkStopping(ctx context.Context, sandboxID, lifecycleID string) b return stopped } +func (m *Map) MarkStopped(ctx context.Context, sbx *Sandbox) { + m.lifecycles.Remove(sandboxLifecycleKey(sbx.Runtime.SandboxID, sbx.LifecycleID)) + + logger.L().Info(ctx, "sandbox lifecycle stopped", + logger.WithSandboxID(sbx.Runtime.SandboxID), + logger.WithLifecycleID(sbx.LifecycleID), + logger.WithSandboxIP(sbx.Slot.HostIPString()), + ) +} + +func (m *Map) markLifecycleState(sandboxID, lifecycleID string, state SandboxState) { + key := sandboxLifecycleKey(sandboxID, lifecycleID) + entry, ok := m.lifecycles.Get(key) + if !ok { + return + } + + entry.state = state + m.lifecycles.Insert(key, entry) +} + // NetworkReleased unregisters a sandbox's IP and notifies OnNetworkRelease // subscribers after a successful removal. func (m *Map) NetworkReleased(ctx context.Context, ip string) { diff --git a/packages/orchestrator/pkg/sandbox/map_test.go b/packages/orchestrator/pkg/sandbox/map_test.go new file mode 100644 index 0000000000..4f5ddb5065 --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/map_test.go @@ -0,0 +1,65 @@ +//go:build linux + +package sandbox + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/network" +) + +func TestMapLifecycleItemsRemainAfterMarkStopping(t *testing.T) { + t.Parallel() + + sandboxes := NewSandboxesMap() + sbx := testMapSandbox(t, "sandbox-1", "lifecycle-1") + + sandboxes.TrackLifecycle(t.Context(), sbx, SandboxStateRunning) + sandboxes.MarkRunning(t.Context(), sbx) + require.Len(t, sandboxes.Items(), 1) + require.Len(t, sandboxes.LifecycleItemsByState(SandboxStateRunning), 1) + + marked := sandboxes.MarkStopping(t.Context(), sbx.Runtime.SandboxID, sbx.LifecycleID) + require.True(t, marked) + require.Empty(t, sandboxes.Items()) + require.Len(t, sandboxes.LifecycleItems(), 1) + require.Len(t, sandboxes.LifecycleItemsByState(SandboxStateStopping), 1) + + sandboxes.MarkStopped(t.Context(), sbx) + require.Empty(t, sandboxes.LifecycleItems()) +} + +func TestMapLifecycleItemsAllowDuplicateSandboxIDs(t *testing.T) { + t.Parallel() + + sandboxes := NewSandboxesMap() + oldSbx := testMapSandbox(t, "sandbox-1", "lifecycle-old") + newSbx := testMapSandbox(t, "sandbox-1", "lifecycle-new") + + sandboxes.TrackLifecycle(t.Context(), oldSbx, SandboxStateStopping) + sandboxes.TrackLifecycle(t.Context(), newSbx, SandboxStateRunning) + + require.Len(t, sandboxes.LifecycleItems(), 2) + require.Len(t, sandboxes.LifecycleItemsByState(SandboxStateStopping), 1) + require.Len(t, sandboxes.LifecycleItemsByState(SandboxStateRunning), 1) +} + +func testMapSandbox(t *testing.T, sandboxID, lifecycleID string) *Sandbox { + t.Helper() + + slot, err := network.NewSlot("test", 1, network.Config{}, network.NoopEgressProxy{}) + require.NoError(t, err) + + return &Sandbox{ + LifecycleID: lifecycleID, + Metadata: &Metadata{ + Config: NewConfig(Config{}), + Runtime: RuntimeMetadata{ + SandboxID: sandboxID, + }, + }, + Resources: &Resources{Slot: slot}, + } +} diff --git a/packages/orchestrator/pkg/server/main.go b/packages/orchestrator/pkg/server/main.go index 1b042fb595..23c6858ea9 100644 --- a/packages/orchestrator/pkg/server/main.go +++ b/packages/orchestrator/pkg/server/main.go @@ -218,7 +218,7 @@ func (s *Server) ForceStopSandboxes(ctx context.Context) error { return err } - sandboxes := s.sandboxFactory.Sandboxes.Items() + sandboxes := s.sandboxFactory.Sandboxes.LifecycleItems() logger.L().Warn(ctx, "starting forced sandbox shutdown", zap.Int("sandbox_count", len(sandboxes))) if len(sandboxes) == 0 { return s.waitSandboxLifecycles(ctx) diff --git a/packages/orchestrator/pkg/server/sandboxes.go b/packages/orchestrator/pkg/server/sandboxes.go index 6489827ff8..4af239cd94 100644 --- a/packages/orchestrator/pkg/server/sandboxes.go +++ b/packages/orchestrator/pkg/server/sandboxes.go @@ -627,10 +627,6 @@ func (s *Server) Checkpoint(ctx context.Context, in *orchestrator.SandboxCheckpo // the API, routing catalog, and analytics) but with a fresh LifecycleID // so the old sandbox's cleanup goroutine won't // accidentally evict the resumed sandbox from the map. - if err := s.rejectIfDraining(ctx, "sandbox-checkpoint-before-resume"); err != nil { - return nil, err - } - resumedSbx, err := s.sandboxFactory.ResumeSandbox( ctx, template, @@ -852,7 +848,10 @@ func (s *Server) uploadSnapshotAsync(ctx context.Context, sbx *sandbox.Sandbox, // setupSandboxLifecycle sets up the cleanup goroutine for a sandbox. func (s *Server) setupSandboxLifecycle(ctx context.Context, sbx *sandbox.Sandbox) { + s.sandboxFactory.Sandboxes.TrackLifecycle(ctx, sbx, sandbox.SandboxStateRunning) s.sandboxLifecycleWG.Go(func() { + defer s.sandboxFactory.Sandboxes.MarkStopped(context.WithoutCancel(ctx), sbx) + ctx, childSpan := tracer.Start(context.WithoutCancel(ctx), "stop sandbox-lifecycle", trace.WithNewRoot()) defer childSpan.End() From ed0101af176ed1ce4cbb3c3ef0395bf721c84297 Mon Sep 17 00:00:00 2001 From: Weilu Jia Date: Tue, 19 May 2026 11:31:29 -0700 Subject: [PATCH 3/7] fix(orchestrator): ignore expected shutdown errors --- packages/orchestrator/pkg/factories/run.go | 24 ++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/packages/orchestrator/pkg/factories/run.go b/packages/orchestrator/pkg/factories/run.go index 1c064390b1..e2e51b568d 100644 --- a/packages/orchestrator/pkg/factories/run.go +++ b/packages/orchestrator/pkg/factories/run.go @@ -119,6 +119,16 @@ func (e serviceDoneError) Error() string { return fmt.Sprintf("service %s finished", e.name) } +func isServiceDoneError(err error) bool { + var sde serviceDoneError + + return errors.As(err, &sde) +} + +func isIgnorableSyncError(err error) bool { + return errors.Is(err, syscall.EINVAL) +} + // Run starts the orchestrator, blocking until shutdown. // Returns true on clean shutdown. func Run(opts Options) bool { @@ -234,7 +244,7 @@ func run(config cfg.Config, opts Options) (success bool) { // there's a panic. defer func(g *errgroup.Group) { err := g.Wait() - if err != nil { + if err != nil && !isServiceDoneError(err) { log.Printf("error while shutting down: %v", err) success = false } @@ -275,7 +285,7 @@ func run(config cfg.Config, opts Options) (success bool) { })) defer func(l logger.Logger) { err := l.Sync() - if err != nil { + if err != nil && !isIgnorableSyncError(err) { log.Printf("error while shutting down logger: %v", err) success = false } @@ -293,7 +303,7 @@ func run(config cfg.Config, opts Options) (success bool) { ) defer func(l logger.Logger) { err := l.Sync() - if err != nil { + if err != nil && !isIgnorableSyncError(err) { log.Printf("error while shutting down sandbox logger: %v", err) success = false } @@ -311,7 +321,7 @@ func run(config cfg.Config, opts Options) (success bool) { ) defer func(l logger.Logger) { err := l.Sync() - if err != nil { + if err != nil && !isIgnorableSyncError(err) { log.Printf("error while shutting down sandbox logger: %v", err) success = false } @@ -591,8 +601,7 @@ func run(config cfg.Config, opts Options) (success bool) { ) closers = append(closers, closer{ "template manager sandbox logger", func(context.Context) error { - // Sync returns EINVAL when path is /dev/stdout (for example) - if err := tmplSbxLoggerExternal.Sync(); err != nil && !errors.Is(err, syscall.EINVAL) { + if err := tmplSbxLoggerExternal.Sync(); err != nil && !isIgnorableSyncError(err) { return err } @@ -822,8 +831,7 @@ func run(config cfg.Config, opts Options) (success bool) { } logger.L().Info(ctx, "Waiting for services to finish") - var sde serviceDoneError - if err := g.Wait(); err != nil && !errors.As(err, &sde) { + if err := g.Wait(); err != nil && !isServiceDoneError(err) { logger.L().Error(ctx, "service group error", zap.Error(err)) success = false } From 5618fe8741cffe005b966b250693459e0eb3de1d Mon Sep 17 00:00:00 2001 From: Weilu Jia Date: Tue, 19 May 2026 13:55:33 -0700 Subject: [PATCH 4/7] feat(orchestrator): add lifecycle manager --- .../orchestrator/pkg/lifecycle/manager.go | 138 ++++++++++++++ .../pkg/lifecycle/manager_test.go | 170 ++++++++++++++++++ 2 files changed, 308 insertions(+) create mode 100644 packages/orchestrator/pkg/lifecycle/manager.go create mode 100644 packages/orchestrator/pkg/lifecycle/manager_test.go diff --git a/packages/orchestrator/pkg/lifecycle/manager.go b/packages/orchestrator/pkg/lifecycle/manager.go new file mode 100644 index 0000000000..f40f221a13 --- /dev/null +++ b/packages/orchestrator/pkg/lifecycle/manager.go @@ -0,0 +1,138 @@ +package lifecycle + +import ( + "context" + "errors" + "fmt" +) + +type Hook func(context.Context) error + +type Unit struct { + Name string + After []string + Start Hook + Stop Hook +} + +type Manager struct { + units []Unit + names map[string]struct{} +} + +func NewManager() *Manager { + return &Manager{names: make(map[string]struct{})} +} + +func (m *Manager) Register(units ...Unit) error { + if m.names == nil { + m.names = make(map[string]struct{}) + } + + for _, unit := range units { + if unit.Name == "" { + return errors.New("lifecycle unit name is required") + } + + if _, ok := m.names[unit.Name]; ok { + return fmt.Errorf("duplicate lifecycle unit %q", unit.Name) + } + + m.names[unit.Name] = struct{}{} + unit.After = append([]string(nil), unit.After...) + m.units = append(m.units, unit) + } + + return nil +} + +func (m *Manager) Start(ctx context.Context) error { + order, err := m.startOrder() + if err != nil { + return err + } + + for _, unit := range order { + if unit.Start == nil { + continue + } + + if err := unit.Start(ctx); err != nil { + return fmt.Errorf("start %q: %w", unit.Name, err) + } + } + + return nil +} + +func (m *Manager) Stop(ctx context.Context) error { + order, err := m.startOrder() + if err != nil { + return err + } + + var errs []error + for i := len(order) - 1; i >= 0; i-- { + unit := order[i] + if unit.Stop == nil { + continue + } + + if err := unit.Stop(ctx); err != nil { + errs = append(errs, fmt.Errorf("stop %q: %w", unit.Name, err)) + } + } + + return errors.Join(errs...) +} + +func (m *Manager) startOrder() ([]Unit, error) { + units := make(map[string]Unit, len(m.units)) + for _, unit := range m.units { + units[unit.Name] = unit + } + + for _, unit := range m.units { + for _, dependency := range unit.After { + if _, ok := units[dependency]; !ok { + return nil, fmt.Errorf("lifecycle unit %q depends on unknown unit %q", unit.Name, dependency) + } + } + } + + permanent := make(map[string]struct{}, len(m.units)) + temporary := make(map[string]struct{}, len(m.units)) + order := make([]Unit, 0, len(m.units)) + + var visit func(Unit) error + visit = func(unit Unit) error { + if _, ok := permanent[unit.Name]; ok { + return nil + } + + if _, ok := temporary[unit.Name]; ok { + return fmt.Errorf("lifecycle dependency cycle includes unit %q", unit.Name) + } + + temporary[unit.Name] = struct{}{} + for _, dependency := range unit.After { + if err := visit(units[dependency]); err != nil { + return err + } + } + + delete(temporary, unit.Name) + permanent[unit.Name] = struct{}{} + order = append(order, unit) + + return nil + } + + for _, unit := range m.units { + if err := visit(unit); err != nil { + return nil, err + } + } + + return order, nil +} diff --git a/packages/orchestrator/pkg/lifecycle/manager_test.go b/packages/orchestrator/pkg/lifecycle/manager_test.go new file mode 100644 index 0000000000..2aea8060c0 --- /dev/null +++ b/packages/orchestrator/pkg/lifecycle/manager_test.go @@ -0,0 +1,170 @@ +package lifecycle + +import ( + "context" + "errors" + "reflect" + "strings" + "testing" +) + +func TestStartRunsHooksInDependencyOrder(t *testing.T) { + t.Parallel() + + ctx := context.Background() + manager := NewManager() + var calls []string + + err := manager.Register( + Unit{Name: "api", After: []string{"database"}, Start: recordHook(&calls, "api")}, + Unit{Name: "cache"}, + Unit{Name: "database", Start: recordHook(&calls, "database")}, + ) + if err != nil { + t.Fatalf("register units: %v", err) + } + + if err := manager.Start(ctx); err != nil { + t.Fatalf("start units: %v", err) + } + + want := []string{"database", "api"} + if !reflect.DeepEqual(calls, want) { + t.Fatalf("calls = %v, want %v", calls, want) + } +} + +func TestStopRunsHooksInReverseDependencyOrder(t *testing.T) { + t.Parallel() + + ctx := context.Background() + manager := NewManager() + var calls []string + + err := manager.Register( + Unit{Name: "api", After: []string{"database"}, Stop: recordHook(&calls, "api")}, + Unit{Name: "database", Stop: recordHook(&calls, "database")}, + Unit{Name: "metrics"}, + ) + if err != nil { + t.Fatalf("register units: %v", err) + } + + if err := manager.Stop(ctx); err != nil { + t.Fatalf("stop units: %v", err) + } + + want := []string{"api", "database"} + if !reflect.DeepEqual(calls, want) { + t.Fatalf("calls = %v, want %v", calls, want) + } +} + +func TestRegisterRejectsDuplicateName(t *testing.T) { + t.Parallel() + + manager := NewManager() + err := manager.Register(Unit{Name: "api"}, Unit{Name: "api"}) + if err == nil { + t.Fatal("expected duplicate name error") + } + + if !strings.Contains(err.Error(), "duplicate") || !strings.Contains(err.Error(), "api") { + t.Fatalf("error = %q, want duplicate api context", err.Error()) + } +} + +func TestStartRejectsUnknownDependency(t *testing.T) { + t.Parallel() + + manager := NewManager() + err := manager.Register(Unit{Name: "api", After: []string{"database"}}) + if err != nil { + t.Fatalf("register units: %v", err) + } + + err = manager.Start(context.Background()) + if err == nil { + t.Fatal("expected unknown dependency error") + } + + if !strings.Contains(err.Error(), "api") || !strings.Contains(err.Error(), "database") { + t.Fatalf("error = %q, want dependent and missing unit names", err.Error()) + } +} + +func TestStartRejectsDependencyCycle(t *testing.T) { + t.Parallel() + + manager := NewManager() + err := manager.Register( + Unit{Name: "api", After: []string{"worker"}}, + Unit{Name: "worker", After: []string{"api"}}, + ) + if err != nil { + t.Fatalf("register units: %v", err) + } + + err = manager.Start(context.Background()) + if err == nil { + t.Fatal("expected cycle error") + } + + if !strings.Contains(err.Error(), "cycle") { + t.Fatalf("error = %q, want cycle context", err.Error()) + } +} + +func TestStopContinuesAfterErrors(t *testing.T) { + t.Parallel() + + ctx := context.Background() + manager := NewManager() + firstErr := errors.New("first failed") + secondErr := errors.New("second failed") + var calls []string + + err := manager.Register( + Unit{Name: "first", Stop: func(context.Context) error { + calls = append(calls, "first") + + return firstErr + }}, + Unit{Name: "second", After: []string{"first"}, Stop: func(context.Context) error { + calls = append(calls, "second") + + return secondErr + }}, + Unit{Name: "third", After: []string{"second"}, Stop: recordHook(&calls, "third")}, + ) + if err != nil { + t.Fatalf("register units: %v", err) + } + + err = manager.Stop(ctx) + if err == nil { + t.Fatal("expected joined stop error") + } + + want := []string{"third", "second", "first"} + if !reflect.DeepEqual(calls, want) { + t.Fatalf("calls = %v, want %v", calls, want) + } + + if !errors.Is(err, firstErr) || !errors.Is(err, secondErr) { + t.Fatalf("error = %v, want joined first and second errors", err) + } + + message := err.Error() + if !strings.Contains(message, "stop \"first\"") || !strings.Contains(message, "stop \"second\"") { + t.Fatalf("error = %q, want component context", message) + } +} + +func recordHook(calls *[]string, name string) Hook { + return func(context.Context) error { + *calls = append(*calls, name) + + return nil + } +} From 3c3f9cc9c9212c33908f618ea5b5d2a3f1295922 Mon Sep 17 00:00:00 2001 From: Weilu Jia Date: Tue, 19 May 2026 14:04:39 -0700 Subject: [PATCH 5/7] refactor(orchestrator): use lifecycle stop ordering --- packages/orchestrator/pkg/factories/run.go | 137 ++++++++++++--------- 1 file changed, 81 insertions(+), 56 deletions(-) diff --git a/packages/orchestrator/pkg/factories/run.go b/packages/orchestrator/pkg/factories/run.go index e2e51b568d..db39f029b4 100644 --- a/packages/orchestrator/pkg/factories/run.go +++ b/packages/orchestrator/pkg/factories/run.go @@ -35,6 +35,7 @@ import ( "github.com/e2b-dev/infra/packages/orchestrator/pkg/events" e2bhealthcheck "github.com/e2b-dev/infra/packages/orchestrator/pkg/healthcheck" "github.com/e2b-dev/infra/packages/orchestrator/pkg/hyperloopserver" + "github.com/e2b-dev/infra/packages/orchestrator/pkg/lifecycle" "github.com/e2b-dev/infra/packages/orchestrator/pkg/localupload" "github.com/e2b-dev/infra/packages/orchestrator/pkg/metrics" "github.com/e2b-dev/infra/packages/orchestrator/pkg/nfsproxy" @@ -106,11 +107,6 @@ type Options struct { EgressFactory EgressFactory } -type closer struct { - name string - close func(ctx context.Context) error -} - type serviceDoneError struct { name string } @@ -356,7 +352,30 @@ func run(config cfg.Config, opts Options) (success bool) { }) } - var closers []closer + shutdown := lifecycle.NewManager() + registerUnit := func(unit lifecycle.Unit) { + if err := shutdown.Register(unit); err != nil { + logger.L().Fatal(ctx, "failed to register lifecycle unit", zap.String("unit", unit.Name), zap.Error(err)) + } + } + registerClose := func(name string, after []string, closeFn func(context.Context) error) { + registerUnit(lifecycle.Unit{ + Name: name, + After: after, + Stop: func(closeCtx context.Context) error { + clog := globalLogger.With(zap.String("service", name), zap.Bool("forced", config.ForceStop)) + clog.Info(ctx, "closing") + + if err := closeFn(closeCtx); err != nil { + clog.Error(ctx, "error during shutdown", zap.Error(err)) + + return err + } + + return nil + }, + }) + } // The sandbox map is shared between the server and the proxy // to propagate information about sandbox routing. @@ -367,7 +386,7 @@ func run(config cfg.Config, opts Options) (success bool) { if err != nil { logger.L().Fatal(ctx, "failed to create feature flags client", zap.Error(err)) } - closers = append(closers, closer{"feature flags", featureFlags.Close}) + registerClose("feature flags", nil, featureFlags.Close) featureFlags.SetDeploymentName(config.DomainName) @@ -376,7 +395,7 @@ func run(config cfg.Config, opts Options) (success bool) { if err != nil { logger.L().Fatal(ctx, "failed to create limiter", zap.Error(err)) } - closers = append(closers, closer{"limiter", limiter.Close}) + registerClose("limiter", []string{"feature flags"}, limiter.Close) persistence, err := storage.GetStorageProvider(ctx, storage.TemplateStorageConfig.WithLimiter(limiter)) if err != nil { @@ -399,9 +418,9 @@ func run(config cfg.Config, opts Options) (success bool) { if err != nil && !errors.Is(err, sharedFactories.ErrRedisDisabled) { logger.L().Fatal(ctx, "Could not connect to Redis", zap.Error(err)) } else if err == nil { - closers = append(closers, closer{"redis client", func(context.Context) error { + registerClose("redis client", nil, func(context.Context) error { return sharedFactories.CloseCleanly(redisClient) - }}) + }) } peerRegistry := peerclient.NopRegistry() @@ -416,11 +435,11 @@ func run(config cfg.Config, opts Options) (success bool) { logger.L().Fatal(ctx, "failed to create template cache", zap.Error(err)) } templateCache.Start(ctx) - closers = append(closers, closer{"template cache", func(context.Context) error { + registerClose("template cache", []string{"feature flags", "limiter"}, func(context.Context) error { templateCache.Stop() return nil - }}) + }) sbxEventsDeliveryTargets := make([]event.Delivery[event.SandboxEvent], 0) @@ -432,9 +451,9 @@ func run(config cfg.Config, opts Options) (success bool) { if err != nil { logger.L().Fatal(ctx, "failed to create clickhouse driver", zap.Error(err)) } - closers = append(closers, closer{"clickhouse connection", func(context.Context) error { + registerClose("clickhouse connection", nil, func(context.Context) error { return clickhouseConn.Close() - }}) + }) sbxEventsDeliveryClickhouse, err := clickhouseevents.NewDefaultClickhouseSandboxEventsDelivery(ctx, clickhouseConn, featureFlags) if err != nil { @@ -442,7 +461,7 @@ func run(config cfg.Config, opts Options) (success bool) { } sbxEventsDeliveryTargets = append(sbxEventsDeliveryTargets, sbxEventsDeliveryClickhouse) - closers = append(closers, closer{"sandbox events delivery for clickhouse", sbxEventsDeliveryClickhouse.Close}) + registerClose("sandbox events delivery for clickhouse", []string{"clickhouse connection", "feature flags"}, sbxEventsDeliveryClickhouse.Close) hostStatsDeliveryClickhouse, err := clickhousehoststats.NewDefaultClickhouseHostStatsDelivery(ctx, clickhouseConn, featureFlags) if err != nil { @@ -450,7 +469,7 @@ func run(config cfg.Config, opts Options) (success bool) { } hostStatsDelivery = hostStatsDeliveryClickhouse - closers = append(closers, closer{"sandbox host stats delivery", hostStatsDeliveryClickhouse.Close}) + registerClose("sandbox host stats delivery", []string{"clickhouse connection", "feature flags"}, hostStatsDeliveryClickhouse.Close) } // cgroup manager for resource accounting @@ -469,7 +488,7 @@ func run(config cfg.Config, opts Options) (success bool) { if redisClient != nil { sbxEventsDeliveryRedis := event.NewRedisStreamsDelivery[event.SandboxEvent](redisClient, event.SandboxEventsStreamName) sbxEventsDeliveryTargets = append(sbxEventsDeliveryTargets, sbxEventsDeliveryRedis) - closers = append(closers, closer{"sandbox events delivery for redis", sbxEventsDeliveryRedis.Close}) + registerClose("sandbox events delivery for redis", []string{"redis client"}, sbxEventsDeliveryRedis.Close) } // sandbox observer @@ -477,7 +496,7 @@ func run(config cfg.Config, opts Options) (success bool) { if err != nil { logger.L().Fatal(ctx, "failed to create sandbox observer", zap.Error(err)) } - closers = append(closers, closer{"sandbox observer", sandboxObserver.Close}) + registerClose("sandbox observer", nil, sandboxObserver.Close) // host metrics — samples CPU in the background so GetCPUMetrics is a // non-blocking cache read on the request path. @@ -485,7 +504,7 @@ func run(config cfg.Config, opts Options) (success bool) { startService("host metrics poller", func() error { return hostMetrics.Start() }) - closers = append(closers, closer{"host metrics poller", hostMetrics.Close}) + registerClose("host metrics poller", nil, hostMetrics.Close) // sandbox proxy sandboxProxy, err := proxy.NewSandboxProxy(tel.MeterProvider, config.ProxyPort, sandboxes, featureFlags) @@ -500,7 +519,7 @@ func run(config cfg.Config, opts Options) (success bool) { return err }) - closers = append(closers, closer{"sandbox proxy", sandboxProxy.Close}) + registerClose("sandbox proxy", []string{"feature flags"}, sandboxProxy.Close) // egress proxy — built by the edition-specific factory deps := &Deps{ @@ -525,7 +544,9 @@ func run(config cfg.Config, opts Options) (success bool) { }) } if egressSetup.Close != nil { - closers = append(closers, closer{"egress proxy", egressSetup.Close}) + registerClose("egress proxy", []string{"feature flags"}, egressSetup.Close) + } else { + registerUnit(lifecycle.Unit{Name: "egress proxy", After: []string{"feature flags"}}) } // device pool @@ -538,7 +559,7 @@ func run(config cfg.Config, opts Options) (success bool) { return nil }) - closers = append(closers, closer{"device pool", devicePool.Close}) + registerClose("device pool", nil, devicePool.Close) // network pool slotStorage, err := newStorage(ctx, nodeID, config.NetworkConfig, egressSetup.Proxy) @@ -551,7 +572,7 @@ func run(config cfg.Config, opts Options) (success bool) { return nil }) - closers = append(closers, closer{"network pool", networkPool.Close}) + registerClose("network pool", []string{"device pool", "egress proxy"}, networkPool.Close) // sandbox factory sandboxFactory := sandbox.NewFactory(config.BuilderConfig, networkPool, devicePool, featureFlags, hostStatsDelivery, cgroupManager, egressSetup.Proxy, sandboxes) @@ -561,11 +582,11 @@ func run(config cfg.Config, opts Options) (success bool) { volumeService := volumes.New(config, builder) uploads := sandbox.NewUploads(templateCache, persistence, peerResolver, redisClient) - closers = append(closers, closer{"pending uploads", func(context.Context) error { + registerClose("pending uploads", []string{"template cache"}, func(context.Context) error { uploads.Stop() return nil - }}) + }) orchestratorService, err := server.New(ctx, server.ServiceConfig{ Config: config, @@ -585,9 +606,9 @@ func run(config cfg.Config, opts Options) (success bool) { if err != nil { logger.L().Fatal(ctx, "failed to create orchestrator server", zap.Error(err)) } - closers = append(closers, closer{"orchestrator server", func(closeCtx context.Context) error { + registerClose("orchestrator server", []string{"network pool", "device pool", "template cache", "pending uploads", "sandbox proxy"}, func(closeCtx context.Context) error { return orchestratorService.Close(closeCtx) - }}) + }) // template manager sandbox logger tmplSbxLoggerExternal := sbxlogger.NewLogger( @@ -599,23 +620,29 @@ func run(config cfg.Config, opts Options) (success bool) { CollectorAddress: env.LogsCollectorAddress(), }, ) - closers = append(closers, closer{ - "template manager sandbox logger", func(context.Context) error { + registerClose( + "template manager sandbox logger", nil, func(context.Context) error { if err := tmplSbxLoggerExternal.Sync(); err != nil && !isIgnorableSyncError(err) { return err } return nil }, - }) + ) // nfs proxy server if len(config.PersistentVolumeMounts) > 0 { - nfsClosers, err := startNFSProxy(ctx, config, builder, startService, sandboxes) + nfsUnits, err := startNFSProxy(ctx, config, builder, startService, sandboxes) if err != nil { logger.L().Fatal(ctx, "failed to start nfs proxy", zap.Error(err)) } - closers = append(closers, nfsClosers...) + for _, unit := range nfsUnits { + if unit.Stop == nil { + registerUnit(unit) + } else { + registerClose(unit.Name, unit.After, unit.Stop) + } + } } // hyperloop server @@ -631,7 +658,7 @@ func run(config cfg.Config, opts Options) (success bool) { return err }) - closers = append(closers, closer{"hyperloop server", hyperloopSrv.Shutdown}) + registerClose("hyperloop server", nil, hyperloopSrv.Shutdown) grpcServer := e2bgrpc.NewGRPCServer(tel, e2bgrpc.WithSandboxResumeMetrics()) orchestrator.RegisterSandboxServiceServer(grpcServer, orchestratorService) @@ -669,7 +696,7 @@ func run(config cfg.Config, opts Options) (success bool) { templatemanager.RegisterTemplateServiceServer(grpcServer, tmpl) - closers = append(closers, closer{"template server", tmpl.Close}) + registerClose("template server", []string{"orchestrator server", "network pool", "device pool", "template cache", "sandbox proxy"}, tmpl.Close) } infoService := service.NewInfoService(serviceInfo, sandboxes, hostMetrics) @@ -698,12 +725,12 @@ func run(config cfg.Config, opts Options) (success bool) { return err }) - closers = append(closers, closer{"cmux server", func(context.Context) error { + registerClose("cmux server", nil, func(context.Context) error { logger.L().Info(ctx, "Shutting down cmux server") cmuxServer.Close() return nil - }}) + }) pprofServer := telemetry.NewPprofServer() // We handle the pprof in a separate goroutine to prevent any interaction with the main server. @@ -714,7 +741,7 @@ func run(config cfg.Config, opts Options) (success bool) { logger.L().Error(ctx, "pprof server encountered error", zap.Error(err)) } }() - closers = append(closers, closer{"pprof server", pprofServer.Shutdown}) + registerClose("pprof server", nil, pprofServer.Shutdown) // http server healthcheck, err := e2bhealthcheck.NewHealthcheck(serviceInfo) @@ -743,18 +770,18 @@ func run(config cfg.Config, opts Options) (success bool) { return err } }) - closers = append(closers, closer{"http server", httpServer.Shutdown}) + registerClose("http server", []string{"cmux server"}, httpServer.Shutdown) // grpc server startService("grpc server", func() error { return grpcServer.Serve(grpcListener) }) - closers = append(closers, closer{"grpc server", func(context.Context) error { + registerClose("grpc server", []string{"cmux server", "orchestrator server"}, func(context.Context) error { logger.L().Info(ctx, "Shutting down grpc server") grpcServer.GracefulStop() return nil - }}) + }) // Wait for the shutdown signal or if some service fails select { @@ -791,7 +818,7 @@ func run(config cfg.Config, opts Options) (success bool) { } } - if orchestratorService != nil { + registerClose("sandbox drain", []string{"orchestrator server", "network pool", "device pool"}, func(context.Context) error { var drainCtx context.Context var cancelDrain context.CancelFunc if !config.ForceStop && config.SandboxDrainTimeout > 0 { @@ -818,16 +845,13 @@ func run(config cfg.Config, opts Options) (success bool) { success = false } } - } - slices.Reverse(closers) - for _, closer := range closers { - clog := globalLogger.With(zap.String("service", closer.name), zap.Bool("forced", config.ForceStop)) - clog.Info(ctx, "closing") - if err := closer.close(closeCtx); err != nil { - clog.Error(ctx, "error during shutdown", zap.Error(err)) - success = false - } + return nil + }) + + if err := shutdown.Stop(closeCtx); err != nil { + logger.L().Error(ctx, "error during lifecycle shutdown", zap.Error(err)) + success = false } logger.L().Info(ctx, "Waiting for services to finish") @@ -845,8 +869,8 @@ func startNFSProxy( builder *chrooted.Builder, startService func(name string, f func() error), sandboxes *sandbox.Map, -) ([]closer, error) { - var closers []closer +) ([]lifecycle.Unit, error) { + var units []lifecycle.Unit // portmapper listener var pmConfig net.ListenConfig @@ -861,7 +885,7 @@ func startNFSProxy( startService("portmapper server", func() error { return pm.Serve(ctx, pmLis) }) - closers = append(closers, closer{"portmapper server", func(_ context.Context) error { return pmLis.Close() }}) + units = append(units, lifecycle.Unit{Name: "portmapper server", Stop: func(_ context.Context) error { return pmLis.Close() }}) // nfs proxy listener var nfsConfig net.ListenConfig @@ -885,13 +909,14 @@ func startNFSProxy( startService("nfs proxy", func() error { return nfsServer.Serve(lis) }) - closers = append(closers, closer{ - "nfs proxy server", func(_ context.Context) error { + units = append(units, lifecycle.Unit{ + Name: "nfs proxy server", + Stop: func(_ context.Context) error { return lis.Close() }, }) - return closers, nil + return units, nil } func setupBuildStorage(ctx context.Context, limiter *limit.Limiter, orchConfig cfg.Config) (storage.StorageProvider, *localupload.Handler, error) { From ca1012c0861bb9d5a3f0b96815716b3c84405826 Mon Sep 17 00:00:00 2001 From: Weilu Jia Date: Tue, 19 May 2026 16:39:11 -0700 Subject: [PATCH 6/7] fix(orchestrator): clean partial network setup --- .../pkg/sandbox/network/network.go | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/packages/orchestrator/pkg/sandbox/network/network.go b/packages/orchestrator/pkg/sandbox/network/network.go index 2164470318..137827193e 100644 --- a/packages/orchestrator/pkg/sandbox/network/network.go +++ b/packages/orchestrator/pkg/sandbox/network/network.go @@ -18,7 +18,7 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/logger" ) -func (s *Slot) CreateNetwork(ctx context.Context) error { +func (s *Slot) CreateNetwork(ctx context.Context) (retErr error) { // Prevent thread changes so we can safely manipulate with namespaces runtime.LockOSThread() defer runtime.UnlockOSThread() @@ -29,10 +29,19 @@ func (s *Slot) CreateNetwork(ctx context.Context) error { return fmt.Errorf("cannot get current (host) namespace: %w", err) } + cleanupNeeded := false defer func() { - err = netns.Set(hostNS) - if err != nil { - logger.L().Error(ctx, "error resetting network namespace back to the host namespace", zap.Error(err)) + restoreErr := netns.Set(hostNS) + if restoreErr != nil { + logger.L().Error(ctx, "error resetting network namespace back to the host namespace", zap.Error(restoreErr)) + } + + if retErr != nil && cleanupNeeded { + if restoreErr != nil { + retErr = errors.Join(retErr, fmt.Errorf("error resetting network namespace back to the host namespace before cleanup: %w", restoreErr)) + } else if cleanupErr := s.RemoveNetwork(); cleanupErr != nil { + retErr = errors.Join(retErr, fmt.Errorf("error cleaning up partially created network: %w", cleanupErr)) + } } err = hostNS.Close() @@ -46,6 +55,7 @@ func (s *Slot) CreateNetwork(ctx context.Context) error { if err != nil { return fmt.Errorf("cannot create new namespace: %w", err) } + cleanupNeeded = true defer ns.Close() From 6443f22c3b64df064f435365e48a0f53d0194a72 Mon Sep 17 00:00:00 2001 From: Weilu Jia Date: Tue, 19 May 2026 18:25:11 -0700 Subject: [PATCH 7/7] fix(orchestrator): harden network teardown --- .../pkg/sandbox/network/firewall.go | 12 ++- .../pkg/sandbox/network/network.go | 99 +++++++++++++------ .../pkg/sandbox/network/network_test.go | 49 +++++++++ 3 files changed, 128 insertions(+), 32 deletions(-) create mode 100644 packages/orchestrator/pkg/sandbox/network/network_test.go diff --git a/packages/orchestrator/pkg/sandbox/network/firewall.go b/packages/orchestrator/pkg/sandbox/network/firewall.go index 5889667fdc..08d9d5f35b 100644 --- a/packages/orchestrator/pkg/sandbox/network/firewall.go +++ b/packages/orchestrator/pkg/sandbox/network/firewall.go @@ -3,6 +3,7 @@ package network import ( + "errors" "fmt" "net/netip" "slices" @@ -109,7 +110,16 @@ func NewFirewall(tapIf string, orchestratorInternalIP string, extraAllowedCIDRs } func (fw *Firewall) Close() error { - return fw.conn.CloseLasting() + fw.conn.DelTable(&nftables.Table{ + Name: tableName, + Family: nftables.TableFamilyINet, + }) + deleteErr := fw.conn.Flush() + if errors.Is(deleteErr, unix.ENOENT) { + deleteErr = nil + } + + return errors.Join(deleteErr, fw.conn.CloseLasting()) } // tapIfaceMatch returns expressions that match packets from the tap interface. diff --git a/packages/orchestrator/pkg/sandbox/network/network.go b/packages/orchestrator/pkg/sandbox/network/network.go index 137827193e..14a386f9b7 100644 --- a/packages/orchestrator/pkg/sandbox/network/network.go +++ b/packages/orchestrator/pkg/sandbox/network/network.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "net" + "os" "runtime" "strconv" @@ -14,10 +15,66 @@ import ( "github.com/vishvananda/netlink" "github.com/vishvananda/netns" "go.uber.org/zap" + "golang.org/x/sys/unix" "github.com/e2b-dev/infra/packages/shared/pkg/logger" ) +type notExistError interface { + IsNotExist() bool +} + +type multiUnwrapError interface { + Unwrap() []error +} + +func ignoreExpectedAbsent(err error, isExpected func(error) bool) bool { + if err == nil { + return true + } + + var joined multiUnwrapError + if errors.As(err, &joined) { + for _, child := range joined.Unwrap() { + if !ignoreExpectedAbsent(child, isExpected) { + return false + } + } + + return true + } + + return isExpected(err) +} + +func isIPTablesNotExist(err error) bool { + var notExist notExistError + + return errors.As(err, ¬Exist) && notExist.IsNotExist() +} + +func isRouteNotExist(err error) bool { + return errors.Is(err, unix.ESRCH) || errors.Is(err, unix.ENOENT) +} + +func isLinkNotExist(err error) bool { + var linkNotFound netlink.LinkNotFoundError + + return errors.As(err, &linkNotFound) || errors.Is(err, unix.ENODEV) || errors.Is(err, unix.ENOENT) +} + +func isNamespaceNotExist(err error) bool { + return os.IsNotExist(err) || errors.Is(err, unix.ENOENT) +} + +func appendUnlessExpectedAbsentf(errs *[]error, err error, isExpected func(error) bool, format string) { + if ignoreExpectedAbsent(err, isExpected) { + return + } + + *errs = append(*errs, fmt.Errorf(format, err)) +} + func (s *Slot) CreateNetwork(ctx context.Context) (retErr error) { // Prevent thread changes so we can safely manipulate with namespaces runtime.LockOSThread() @@ -283,20 +340,14 @@ func (s *Slot) RemoveNetwork() error { } else { // Delete host forwarding rules err = tables.Delete("filter", "FORWARD", "-i", s.VethName(), "-o", defaultGateway, "-j", "ACCEPT") - if err != nil { - errs = append(errs, fmt.Errorf("error deleting host forwarding rule to default gateway: %w", err)) - } + appendUnlessExpectedAbsentf(&errs, err, isIPTablesNotExist, "error deleting host forwarding rule to default gateway: %w") err = tables.Delete("filter", "FORWARD", "-i", defaultGateway, "-o", s.VethName(), "-j", "ACCEPT") - if err != nil { - errs = append(errs, fmt.Errorf("error deleting host forwarding rule from default gateway: %w", err)) - } + appendUnlessExpectedAbsentf(&errs, err, isIPTablesNotExist, "error deleting host forwarding rule from default gateway: %w") // Delete host postrouting rules err = tables.Delete("nat", "POSTROUTING", "-s", s.HostCIDR(), "-o", defaultGateway, "-j", "MASQUERADE") - if err != nil { - errs = append(errs, fmt.Errorf("error deleting host postrouting rule: %w", err)) - } + appendUnlessExpectedAbsentf(&errs, err, isIPTablesNotExist, "error deleting host postrouting rule: %w") // Delete hyperloop proxy redirect rule err = tables.Delete( @@ -304,15 +355,11 @@ func (s *Slot) RemoveNetwork() error { "-p", "tcp", "-d", s.config.OrchestratorInSandboxIPAddress, "--dport", "80", "-j", "REDIRECT", "--to-port", s.hyperloopPort, ) - if err != nil { - errs = append(errs, fmt.Errorf("error deleting sandbox hyperloop proxy redirect rule: %w", err)) - } + appendUnlessExpectedAbsentf(&errs, err, isIPTablesNotExist, "error deleting sandbox hyperloop proxy redirect rule: %w") // Delete changes made by egress proxy err = s.egressProxy.OnSlotDelete(s, tables) - if err != nil { - errs = append(errs, err) - } + appendUnlessExpectedAbsentf(&errs, err, isIPTablesNotExist, "%w") } // Delete routing from host to FC namespace @@ -320,9 +367,7 @@ func (s *Slot) RemoveNetwork() error { Gw: s.VpeerIP(), Dst: s.HostNet(), }) - if err != nil { - errs = append(errs, fmt.Errorf("error deleting route from host to FC: %w", err)) - } + appendUnlessExpectedAbsentf(&errs, err, isRouteNotExist, "error deleting route from host to FC: %w") // Delete veth device // We explicitly delete the veth device from the host namespace because even though deleting @@ -330,12 +375,10 @@ func (s *Slot) RemoveNetwork() error { // the same name immediately after deleting the namespace. veth, err := netlink.LinkByName(s.VethName()) if err != nil { - errs = append(errs, fmt.Errorf("error finding veth: %w", err)) + appendUnlessExpectedAbsentf(&errs, err, isLinkNotExist, "error finding veth: %w") } else { err = netlink.LinkDel(veth) - if err != nil { - errs = append(errs, fmt.Errorf("error deleting veth device: %w", err)) - } + appendUnlessExpectedAbsentf(&errs, err, isLinkNotExist, "error deleting veth device: %w") } if tables != nil { @@ -345,9 +388,7 @@ func (s *Slot) RemoveNetwork() error { "--destination", s.config.OrchestratorInSandboxIPAddress, "--dport", "2049", "--jump", "REDIRECT", "--to-port", strconv.Itoa(int(s.config.NFSProxyPort)), ) - if err != nil { - errs = append(errs, fmt.Errorf("error deleting sandbox NFS proxy redirect rule: %w", err)) - } + appendUnlessExpectedAbsentf(&errs, err, isIPTablesNotExist, "error deleting sandbox NFS proxy redirect rule: %w") // Delete portmapper redirect rule err = tables.Delete("nat", "PREROUTING", @@ -355,15 +396,11 @@ func (s *Slot) RemoveNetwork() error { "--destination", s.config.OrchestratorInSandboxIPAddress, "--dport", "111", "--jump", "REDIRECT", "--to-port", strconv.Itoa(int(s.config.PortmapperPort)), ) - if err != nil { - errs = append(errs, fmt.Errorf("error deleting sandbox portmapper redirect rule: %w", err)) - } + appendUnlessExpectedAbsentf(&errs, err, isIPTablesNotExist, "error deleting sandbox portmapper redirect rule: %w") } err = netns.DeleteNamed(s.NamespaceID()) - if err != nil { - errs = append(errs, fmt.Errorf("error deleting namespace: %w", err)) - } + appendUnlessExpectedAbsentf(&errs, err, isNamespaceNotExist, "error deleting namespace: %w") return errors.Join(errs...) } diff --git a/packages/orchestrator/pkg/sandbox/network/network_test.go b/packages/orchestrator/pkg/sandbox/network/network_test.go new file mode 100644 index 0000000000..cfd582c3b2 --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/network/network_test.go @@ -0,0 +1,49 @@ +//go:build linux + +package network + +import ( + "errors" + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" +) + +type fakeNotExistError struct { + notExist bool +} + +func (e fakeNotExistError) Error() string { + return "fake iptables error" +} + +func (e fakeNotExistError) IsNotExist() bool { + return e.notExist +} + +func TestIgnoreExpectedAbsentHandlesWrappedAndJoinedErrors(t *testing.T) { + t.Parallel() + + wrapped := fmt.Errorf("wrapped: %w", fakeNotExistError{notExist: true}) + joined := errors.Join(wrapped, fakeNotExistError{notExist: true}) + + require.True(t, ignoreExpectedAbsent(joined, isIPTablesNotExist)) + require.False(t, ignoreExpectedAbsent(errors.Join(joined, errors.New("boom")), isIPTablesNotExist)) + require.False(t, ignoreExpectedAbsent(fakeNotExistError{notExist: false}, isIPTablesNotExist)) +} + +func TestExpectedAbsentClassifiers(t *testing.T) { + t.Parallel() + + require.True(t, isRouteNotExist(fmt.Errorf("route delete failed: %w", unix.ESRCH))) + require.False(t, isRouteNotExist(unix.EPERM)) + + require.True(t, isLinkNotExist(fmt.Errorf("link delete failed: %w", unix.ENODEV))) + require.False(t, isLinkNotExist(unix.EPERM)) + + require.True(t, isNamespaceNotExist(&os.PathError{Op: "remove", Path: "/var/run/netns/missing", Err: unix.ENOENT})) + require.False(t, isNamespaceNotExist(unix.EPERM)) +}