From 0df9b5892d3797c678bcfef2f0b8e98a9674f676 Mon Sep 17 00:00:00 2001
From: enxebre
Date: Wed, 21 Jan 2026 15:19:35 +0100
Subject: [PATCH] Add node drain condition for sqs

---
 .../rebalance-recommendation-event.go  |  22 ++
 pkg/monitor/sqsevent/spot-itn-event.go |  22 ++
 pkg/node/node.go                       | 153 ++++++++++++
 pkg/node/node_test.go                  | 219 ++++++++++++++++++
 4 files changed, 416 insertions(+)

diff --git a/pkg/monitor/sqsevent/rebalance-recommendation-event.go b/pkg/monitor/sqsevent/rebalance-recommendation-event.go
index baa0aad2..0eacb1c3 100644
--- a/pkg/monitor/sqsevent/rebalance-recommendation-event.go
+++ b/pkg/monitor/sqsevent/rebalance-recommendation-event.go
@@ -71,6 +71,23 @@ func (m SQSMonitor) rebalanceRecommendationToInterruptionEvent(event *EventBridg
 		Description: fmt.Sprintf("Rebalance recommendation event received. Instance %s will be cordoned at %s \n", rebalanceRecDetail.InstanceID, event.getTime()),
 	}
 	interruptionEvent.PostDrainTask = func(interruptionEvent monitor.InterruptionEvent, n node.Node) error {
+		// Use provider ID to resolve the actual Kubernetes node name if UseProviderId is configured
+		nthConfig := n.GetNthConfig()
+		nodeName := interruptionEvent.NodeName
+		if nthConfig.UseProviderId && interruptionEvent.ProviderID != "" {
+			resolvedNodeName, err := n.GetNodeNameFromProviderID(interruptionEvent.ProviderID)
+			if err != nil {
+				log.Warn().Err(err).Str("provider_id", interruptionEvent.ProviderID).Msg("Failed to resolve node name from provider ID, falling back to NodeName from event")
+			} else {
+				nodeName = resolvedNodeName
+			}
+		}
+
+		// Remove the draining condition from the node
+		if err := n.RemoveDrainingCondition(nodeName); err != nil {
+			log.Err(err).Str("node_name", nodeName).Msg("Unable to remove draining condition from node")
+		}
+
 		errs := m.deleteMessages([]*sqs.Message{message})
 		if errs != nil {
 			return errs[0]
@@ -90,6 +107,11 @@ func (m SQSMonitor) rebalanceRecommendationToInterruptionEvent(event *EventBridg
 			}
 		}
 
+		// Set the draining condition on the node
+		if err := n.SetDrainingCondition(nodeName, "RebalanceRecommendation", interruptionEvent.Description); err != nil {
+			log.Err(err).Str("node_name", nodeName).Msg("Unable to set draining condition on node")
+		}
+
 		err := n.TaintRebalanceRecommendation(nodeName, interruptionEvent.EventID)
 		if err != nil {
 			log.Err(err).Msgf("Unable to taint node with taint %s:%s", node.RebalanceRecommendationTaint, interruptionEvent.EventID)
diff --git a/pkg/monitor/sqsevent/spot-itn-event.go b/pkg/monitor/sqsevent/spot-itn-event.go
index a5f21e1e..2449d972 100644
--- a/pkg/monitor/sqsevent/spot-itn-event.go
+++ b/pkg/monitor/sqsevent/spot-itn-event.go
@@ -73,6 +73,23 @@ func (m SQSMonitor) spotITNTerminationToInterruptionEvent(event *EventBridgeEven
 		Description: fmt.Sprintf("Spot Interruption notice for instance %s was sent at %s \n", spotInterruptionDetail.InstanceID, event.getTime()),
 	}
 	interruptionEvent.PostDrainTask = func(interruptionEvent monitor.InterruptionEvent, n node.Node) error {
+		// Use provider ID to resolve the actual Kubernetes node name if UseProviderId is configured
+		nthConfig := n.GetNthConfig()
+		nodeName := interruptionEvent.NodeName
+		if nthConfig.UseProviderId && interruptionEvent.ProviderID != "" {
+			resolvedNodeName, err := n.GetNodeNameFromProviderID(interruptionEvent.ProviderID)
+			if err != nil {
+				log.Warn().Err(err).Str("provider_id", interruptionEvent.ProviderID).Msg("Failed to resolve node name from provider ID, falling back to NodeName from event")
+			} else {
+				nodeName = resolvedNodeName
+			}
+		}
+
+		// Remove the draining condition from the node
+		if err := n.RemoveDrainingCondition(nodeName); err != nil {
+			log.Err(err).Str("node_name", nodeName).Msg("Unable to remove draining condition from node")
+		}
+
 		errs := m.deleteMessages([]*sqs.Message{message})
 		if errs != nil {
 			return errs[0]
@@ -92,6 +109,11 @@ func (m SQSMonitor) spotITNTerminationToInterruptionEvent(event *EventBridgeEven
 			}
 		}
 
+		// Set the draining condition on the node
+		if err := n.SetDrainingCondition(nodeName, "SpotInterruption", interruptionEvent.Description); err != nil {
+			log.Err(err).Str("node_name", nodeName).Msg("Unable to set draining condition on node")
+		}
+
 		err := n.TaintSpotItn(nodeName, interruptionEvent.EventID)
 		if err != nil {
 			log.Err(err).Msgf("Unable to taint node with taint %s:%s", node.SpotInterruptionTaint, interruptionEvent.EventID)
diff --git a/pkg/node/node.go b/pkg/node/node.go
index b80d6ebf..58cde993 100644
--- a/pkg/node/node.go
+++ b/pkg/node/node.go
@@ -51,6 +51,11 @@
 	ExcludeFromLoadBalancersLabelValue = "aws-node-termination-handler"
 )
 
+const (
+	// TerminationHandlerDrainingConditionType is a node condition type indicating the node is being drained
+	TerminationHandlerDrainingConditionType = "TerminationHandlerDraining"
+)
+
 const (
 	// SpotInterruptionTaint is a taint used to make spot instance unschedulable
 	SpotInterruptionTaint = "aws-node-termination-handler/spot-itn"
@@ -583,6 +588,154 @@ func (n Node) RemoveNTHTaints(nodeName string) error {
 	return nil
 }
 
+// SetDrainingCondition adds a condition to the node indicating it is being drained by NTH
+func (n Node) SetDrainingCondition(nodeName string, reason string, message string) error {
+	if n.nthConfig.DryRun {
+		log.Info().Str("node_name", nodeName).Str("reason", reason).Msg("Would have set draining condition on node, but dry-run flag was set")
+		return nil
+	}
+
+	k8sNode, err := n.fetchKubernetesNode(nodeName)
+	if err != nil {
+		return fmt.Errorf("Unable to fetch kubernetes node from API: %w", err)
+	}
+
+	return n.setNodeCondition(k8sNode, TerminationHandlerDrainingConditionType, corev1.ConditionTrue, reason, message)
+}
+
+// RemoveDrainingCondition removes the draining condition from the node
+func (n Node) RemoveDrainingCondition(nodeName string) error {
+	if n.nthConfig.DryRun {
+		log.Info().Str("node_name", nodeName).Msg("Would have removed draining condition from node, but dry-run flag was set")
+		return nil
+	}
+
+	k8sNode, err := n.fetchKubernetesNode(nodeName)
+	if err != nil {
+		return fmt.Errorf("Unable to fetch kubernetes node from API: %w", err)
+	}
+
+	return n.removeNodeCondition(k8sNode, TerminationHandlerDrainingConditionType)
+}
+
+// setNodeCondition adds or updates a condition on the node
+func (n Node) setNodeCondition(node *corev1.Node, conditionType corev1.NodeConditionType, status corev1.ConditionStatus, reason string, message string) error {
+	retryDeadline := time.Now().Add(maxRetryDeadline)
+	freshNode := node.DeepCopy()
+	client := n.drainHelper.Client
+	var err error
+	refresh := false
+
+	for {
+		if refresh {
+			freshNode, err = client.CoreV1().Nodes().Get(context.TODO(), node.Name, metav1.GetOptions{})
+			if err != nil || freshNode == nil {
+				return fmt.Errorf("failed to get node %v: %w", node.Name, err)
+			}
+		}
+
+		now := metav1.Now()
+		newCondition := corev1.NodeCondition{
+			Type:               conditionType,
+			Status:             status,
+			Reason:             reason,
+			Message:            message,
+			LastTransitionTime: now,
+			LastHeartbeatTime:  now,
+		}
+
+		conditionExists := false
+		for i, condition := range freshNode.Status.Conditions {
+			if condition.Type == conditionType {
+				freshNode.Status.Conditions[i] = newCondition
+				conditionExists = true
+				break
+			}
+		}
+		if !conditionExists {
+			freshNode.Status.Conditions = append(freshNode.Status.Conditions, newCondition)
+		}
+
+		_, err = client.CoreV1().Nodes().UpdateStatus(context.TODO(), freshNode, metav1.UpdateOptions{})
+		if err != nil && errors.IsConflict(err) && time.Now().Before(retryDeadline) {
+			refresh = true
+			time.Sleep(conflictRetryInterval)
+			continue
+		}
+
+		if err != nil {
+			log.Err(err).
+				Str("condition_type", string(conditionType)).
+				Str("node_name", node.Name).
+				Msg("Error while setting condition on node")
+			return err
+		}
+		log.Info().
+			Str("condition_type", string(conditionType)).
+			Str("reason", reason).
+			Str("node_name", node.Name).
+			Msg("Successfully set condition on node")
+		return nil
+	}
+}
+
+// removeNodeCondition removes a condition from the node
+func (n Node) removeNodeCondition(node *corev1.Node, conditionType corev1.NodeConditionType) error {
+	retryDeadline := time.Now().Add(maxRetryDeadline)
+	freshNode := node.DeepCopy()
+	client := n.drainHelper.Client
+	var err error
+	refresh := false
+
+	for {
+		if refresh {
+			freshNode, err = client.CoreV1().Nodes().Get(context.TODO(), node.Name, metav1.GetOptions{})
+			if err != nil || freshNode == nil {
+				return fmt.Errorf("failed to get node %v: %w", node.Name, err)
+			}
+		}
+
+		newConditions := make([]corev1.NodeCondition, 0)
+		found := false
+		for _, condition := range freshNode.Status.Conditions {
+			if condition.Type == conditionType {
+				found = true
+				continue
+			}
+			newConditions = append(newConditions, condition)
+		}
+
+		if !found {
+			if !refresh {
+				refresh = true
+				continue
+			}
+			return nil
+		}
+
+		freshNode.Status.Conditions = newConditions
+		_, err = client.CoreV1().Nodes().UpdateStatus(context.TODO(), freshNode, metav1.UpdateOptions{})
+		if err != nil && errors.IsConflict(err) && time.Now().Before(retryDeadline) {
+			refresh = true
+			time.Sleep(conflictRetryInterval)
+			continue
+		}
+
+		if err != nil {
+			log.Err(err).
+				Str("condition_type", string(conditionType)).
+				Str("node_name", node.Name).
+				Msg("Error while removing condition from node")
+			return err
+		}
+		log.Info().
+			Str("condition_type", string(conditionType)).
+			Str("node_name", node.Name).
+			Msg("Successfully removed condition from node")
+		return nil
+	}
+}
+
 // IsLabeledWithAction will return true if the current node is labeled with NTH action labels
 func (n Node) IsLabeledWithAction(nodeName string) (bool, error) {
 	k8sNode, err := n.fetchKubernetesNode(nodeName)
diff --git a/pkg/node/node_test.go b/pkg/node/node_test.go
index 052772a6..5a8b820d 100644
--- a/pkg/node/node_test.go
+++ b/pkg/node/node_test.go
@@ -542,3 +542,222 @@ func TestTaintOutOfService(t *testing.T) {
 	}
 	h.Equals(t, true, taintFound)
 }
+
+func TestSetDrainingConditionSuccess(t *testing.T) {
+	client := fake.NewSimpleClientset()
+	_, err := client.CoreV1().Nodes().Create(
+		context.Background(),
+		&v1.Node{
+			ObjectMeta: metav1.ObjectMeta{Name: nodeName},
+		},
+		metav1.CreateOptions{})
+	h.Ok(t, err)
+
+	tNode, err := newNode(config.Config{}, client)
+	h.Ok(t, err)
+
+	err = tNode.SetDrainingCondition(nodeName, "SpotInterruption", "Node is being drained due to spot interruption")
+	h.Ok(t, err)
+
+	updatedNode, err := client.CoreV1().Nodes().Get(context.Background(), nodeName, metav1.GetOptions{})
+	h.Ok(t, err)
+
+	conditionFound := false
+	for _, condition := range updatedNode.Status.Conditions {
+		if condition.Type == node.TerminationHandlerDrainingConditionType {
+			h.Equals(t, corev1.ConditionTrue, condition.Status)
+			h.Equals(t, "SpotInterruption", condition.Reason)
+			h.Equals(t, "Node is being drained due to spot interruption", condition.Message)
+			conditionFound = true
+			break
+		}
+	}
+	h.Equals(t, true, conditionFound)
+}
+
+func TestSetDrainingConditionNodeNotFound(t *testing.T) {
+	client := fake.NewSimpleClientset()
+	tNode, err := newNode(config.Config{}, client)
+	h.Ok(t, err)
+
+	err = tNode.SetDrainingCondition(nodeName, "SpotInterruption", "Node is being drained")
+	h.Assert(t, err != nil, "Expected error when node not found")
+}
+
+func TestSetDrainingConditionDryRun(t *testing.T) {
+	client := fake.NewSimpleClientset()
+	_, err := client.CoreV1().Nodes().Create(
+		context.Background(),
+		&v1.Node{
+			ObjectMeta: metav1.ObjectMeta{Name: nodeName},
+		},
+		metav1.CreateOptions{})
+	h.Ok(t, err)
+
+	tNode, err := newNode(config.Config{DryRun: true}, client)
+	h.Ok(t, err)
+
+	err = tNode.SetDrainingCondition(nodeName, "SpotInterruption", "Node is being drained")
+	h.Ok(t, err)
+
+	// Verify condition was NOT added in dry-run mode
+	updatedNode, err := client.CoreV1().Nodes().Get(context.Background(), nodeName, metav1.GetOptions{})
+	h.Ok(t, err)
+
+	conditionFound := false
+	for _, condition := range updatedNode.Status.Conditions {
+		if condition.Type == node.TerminationHandlerDrainingConditionType {
+			conditionFound = true
+			break
+		}
+	}
+	h.Equals(t, false, conditionFound)
+}
+
+func TestRemoveDrainingConditionSuccess(t *testing.T) {
+	client := fake.NewSimpleClientset()
+	_, err := client.CoreV1().Nodes().Create(
+		context.Background(),
+		&v1.Node{
+			ObjectMeta: metav1.ObjectMeta{Name: nodeName},
+			Status: v1.NodeStatus{
+				Conditions: []v1.NodeCondition{
+					{
+						Type:    node.TerminationHandlerDrainingConditionType,
+						Status:  corev1.ConditionTrue,
+						Reason:  "SpotInterruption",
+						Message: "Node is being drained",
+					},
+				},
+			},
+		},
+		metav1.CreateOptions{})
+	h.Ok(t, err)
+
+	tNode, err := newNode(config.Config{}, client)
+	h.Ok(t, err)
+
+	err = tNode.RemoveDrainingCondition(nodeName)
+	h.Ok(t, err)
+
+	updatedNode, err := client.CoreV1().Nodes().Get(context.Background(), nodeName, metav1.GetOptions{})
+	h.Ok(t, err)
+
+	conditionFound := false
+	for _, condition := range updatedNode.Status.Conditions {
+		if condition.Type == node.TerminationHandlerDrainingConditionType {
+			conditionFound = true
+			break
+		}
+	}
+	h.Equals(t, false, conditionFound)
+}
+
+func TestRemoveDrainingConditionNotPresent(t *testing.T) {
+	client := fake.NewSimpleClientset()
+	_, err := client.CoreV1().Nodes().Create(
+		context.Background(),
+		&v1.Node{
+			ObjectMeta: metav1.ObjectMeta{Name: nodeName},
+		},
+		metav1.CreateOptions{})
+	h.Ok(t, err)
+
+	tNode, err := newNode(config.Config{}, client)
+	h.Ok(t, err)
+
+	// Should not error when condition is not present
+	err = tNode.RemoveDrainingCondition(nodeName)
+	h.Ok(t, err)
+}
+
+func TestRemoveDrainingConditionNodeNotFound(t *testing.T) {
+	client := fake.NewSimpleClientset()
+	tNode, err := newNode(config.Config{}, client)
+	h.Ok(t, err)
+
+	err = tNode.RemoveDrainingCondition(nodeName)
+	h.Assert(t, err != nil, "Expected error when node not found")
+}
+
+func TestRemoveDrainingConditionDryRun(t *testing.T) {
+	client := fake.NewSimpleClientset()
+	_, err := client.CoreV1().Nodes().Create(
+		context.Background(),
+		&v1.Node{
+			ObjectMeta: metav1.ObjectMeta{Name: nodeName},
+			Status: v1.NodeStatus{
+				Conditions: []v1.NodeCondition{
+					{
+						Type:    node.TerminationHandlerDrainingConditionType,
+						Status:  corev1.ConditionTrue,
+						Reason:  "SpotInterruption",
+						Message: "Node is being drained",
+					},
+				},
+			},
+		},
+		metav1.CreateOptions{})
+	h.Ok(t, err)
+
+	tNode, err := newNode(config.Config{DryRun: true}, client)
+	h.Ok(t, err)
+
+	err = tNode.RemoveDrainingCondition(nodeName)
+	h.Ok(t, err)
+
+	// Verify condition was NOT removed in dry-run mode
+	updatedNode, err := client.CoreV1().Nodes().Get(context.Background(), nodeName, metav1.GetOptions{})
+	h.Ok(t, err)
+
+	conditionFound := false
+	for _, condition := range updatedNode.Status.Conditions {
+		if condition.Type == node.TerminationHandlerDrainingConditionType {
+			conditionFound = true
+			break
+		}
+	}
+	h.Equals(t, true, conditionFound)
+}
+
+func TestSetDrainingConditionUpdateExisting(t *testing.T) {
+	client := fake.NewSimpleClientset()
+	_, err := client.CoreV1().Nodes().Create(
+		context.Background(),
+		&v1.Node{
+			ObjectMeta: metav1.ObjectMeta{Name: nodeName},
+			Status: v1.NodeStatus{
+				Conditions: []v1.NodeCondition{
+					{
+						Type:    node.TerminationHandlerDrainingConditionType,
+						Status:  corev1.ConditionTrue,
+						Reason:  "OldReason",
+						Message: "Old message",
+					},
+				},
+			},
+		},
+		metav1.CreateOptions{})
+	h.Ok(t, err)
+
+	tNode, err := newNode(config.Config{}, client)
+	h.Ok(t, err)
+
+	err = tNode.SetDrainingCondition(nodeName, "NewReason", "New message")
+	h.Ok(t, err)
+
+	updatedNode, err := client.CoreV1().Nodes().Get(context.Background(), nodeName, metav1.GetOptions{})
+	h.Ok(t, err)
+
+	conditionFound := false
+	for _, condition := range updatedNode.Status.Conditions {
+		if condition.Type == node.TerminationHandlerDrainingConditionType {
+			h.Equals(t, corev1.ConditionTrue, condition.Status)
+			h.Equals(t, "NewReason", condition.Reason)
+			h.Equals(t, "New message", condition.Message)
+			conditionFound = true
+			break
+		}
+	}
+	h.Equals(t, true, conditionFound)
+}
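Reviewer note, not part of the patch: a minimal sketch of how the TerminationHandlerDraining condition introduced above could be read back with client-go, for example from an e2e check or an external controller. The in-cluster config and the node name below are illustrative assumptions; the condition type string mirrors the constant added in pkg/node/node.go.

package main

import (
	"context"
	"fmt"

	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/client-go/kubernetes"
	"k8s.io/client-go/rest"
)

// isNTHDraining reports whether the TerminationHandlerDraining condition
// (set by the SQS event hooks in this patch) is present and true on a node.
func isNTHDraining(client kubernetes.Interface, nodeName string) (bool, error) {
	n, err := client.CoreV1().Nodes().Get(context.TODO(), nodeName, metav1.GetOptions{})
	if err != nil {
		return false, err
	}
	for _, c := range n.Status.Conditions {
		if c.Type == corev1.NodeConditionType("TerminationHandlerDraining") && c.Status == corev1.ConditionTrue {
			return true, nil
		}
	}
	return false, nil
}

func main() {
	// In-cluster config is assumed here; any rest.Config would work.
	cfg, err := rest.InClusterConfig()
	if err != nil {
		panic(err)
	}
	client := kubernetes.NewForConfigOrDie(cfg)

	// "ip-10-0-1-23.ec2.internal" is a hypothetical node name.
	draining, err := isNTHDraining(client, "ip-10-0-1-23.ec2.internal")
	if err != nil {
		panic(err)
	}
	fmt.Println("draining:", draining)
}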