diff --git a/auth/host_session.go b/auth/host_session.go index 04e25009e..50cfb9f6e 100644 --- a/auth/host_session.go +++ b/auth/host_session.go @@ -4,8 +4,6 @@ import ( "context" "encoding/json" "fmt" - "log/slog" - "strings" "time" "github.com/google/uuid" @@ -13,10 +11,10 @@ import ( "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" - "github.com/gravitl/netmaker/logic/hostactions" "github.com/gravitl/netmaker/logic/pro/netcache" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" + "github.com/gravitl/netmaker/orchestrator" "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/servercfg" ) @@ -236,18 +234,18 @@ func SessionHandler(conn *websocket.Conn) { } // CheckNetRegAndHostUpdate - run through networks and send a host update -func CheckNetRegAndHostUpdate(key models.EnrollmentKey, h *schema.Host, username string) { +func CheckNetRegAndHostUpdate(key models.EnrollmentKey, host *schema.Host, username string) { // publish host update through MQ featureFlags := logic.GetFeatureFlags() for _, netID := range key.Networks { network := &schema.Network{Name: netID} if err := network.Get(db.WithContext(context.TODO())); err == nil { if featureFlags.EnableDeviceApproval && !network.AutoJoin { - if logic.DoesHostExistinTheNetworkAlready(h, schema.NetworkID(netID)) { + if logic.DoesHostExistInTheNetworkAlready(host, schema.NetworkID(netID)) { continue } if err := (&schema.PendingHost{ - HostID: h.ID.String(), + HostID: host.ID.String(), Network: netID, }).CheckIfPendingHostExists(db.WithContext(context.TODO())); err == nil { continue @@ -256,13 +254,13 @@ func CheckNetRegAndHostUpdate(key models.EnrollmentKey, h *schema.Host, username // add host to pending host table p := schema.PendingHost{ ID: uuid.NewString(), - HostID: h.ID.String(), - Hostname: h.Name, + HostID: host.ID.String(), + Hostname: host.Name, Network: netID, - PublicKey: h.PublicKey.String(), - OS: h.OS, - Location: h.Location, - Version: h.Version, + PublicKey: host.PublicKey.String(), + OS: host.OS, + Location: host.Location, + Version: host.Version, EnrollmentKey: keyB, RequestedAt: time.Now().UTC(), } @@ -270,102 +268,58 @@ func CheckNetRegAndHostUpdate(key models.EnrollmentKey, h *schema.Host, username continue } - if len(username) > 0 { - logic.LogEvent(&models.Event{ - Action: schema.JoinHostToNet, - Source: models.Subject{ - ID: username, - Name: username, - Type: schema.UserSub, - }, - TriggeredBy: username, - Target: models.Subject{ - ID: h.ID.String(), - Name: h.Name, - Type: schema.DeviceSub, - }, - NetworkID: schema.NetworkID(netID), - Origin: schema.Dashboard, - }) + _, err := orchestrator.GetRepository().NodeOrchestrator().CreateNode( + db.WithContext(context.TODO()), + host, + network, + orchestrator.UseKey(&key), + orchestrator.SkipPublishPeerUpdate(), + ) + if err != nil { + logger.Log(0, fmt.Sprintf("failed to add host (%s, %s) to network (%s): %v", host.ID.String(), host.Name, netID, err.Error())) } else { - logic.LogEvent(&models.Event{ - Action: schema.JoinHostToNet, - Source: models.Subject{ - ID: key.Value, - Name: key.Tags[0], - Type: schema.EnrollmentKeySub, - }, - TriggeredBy: username, - Target: models.Subject{ - ID: h.ID.String(), - Name: h.Name, - Type: schema.DeviceSub, - }, - NetworkID: schema.NetworkID(netID), - Origin: schema.Dashboard, - }) - } - - newNode, err := logic.UpdateHostNetwork(h, netID, true) - if servercfg.IsPro && key.AutoAssignGateway { - newNode.AutoAssignGateway = true - logic.UpsertNode(newNode) - } - if err == nil || strings.Contains(err.Error(), "host already part of network") { - if len(key.Groups) > 0 { - newNode.Tags = make(map[models.TagID]struct{}) - for _, tagI := range key.Groups { - newNode.Tags[tagI] = struct{}{} - } - logic.UpsertNode(newNode) - } - if key.Relay != uuid.Nil && !newNode.IsRelayed { - // check if relay node exists and acting as relay - relaynode, err := logic.GetNodeByID(key.Relay.String()) - if err == nil && relaynode.IsGw && relaynode.Network == newNode.Network { - slog.Error(fmt.Sprintf("adding relayed node %s to relay %s on network %s", newNode.ID.String(), key.Relay.String(), netID)) - newNode.IsRelayed = true - newNode.RelayedBy = key.Relay.String() - updatedRelayNode := relaynode - updatedRelayNode.RelayedNodes = append(updatedRelayNode.RelayedNodes, newNode.ID.String()) - logic.UpdateRelayed(&relaynode, &updatedRelayNode) - if err := logic.UpsertNode(&updatedRelayNode); err != nil { - slog.Error("failed to update node", "nodeid", key.Relay.String()) - } - if err := logic.UpsertNode(newNode); err != nil { - slog.Error("failed to update node", "nodeid", key.Relay.String()) - } - } else { - slog.Error("failed to relay node. maybe specified relay node is actually not a relay? Or the relayed node is not in the same network with relay?", "err", err) - } - } - if err != nil && strings.Contains(err.Error(), "host already part of network") { - continue + if len(username) > 0 { + logic.LogEvent(&models.Event{ + Action: schema.JoinHostToNet, + Source: models.Subject{ + ID: username, + Name: username, + Type: schema.UserSub, + }, + TriggeredBy: username, + Target: models.Subject{ + ID: host.ID.String(), + Name: host.Name, + Type: schema.DeviceSub, + }, + NetworkID: schema.NetworkID(netID), + Origin: schema.Dashboard, + }) + } else { + logic.LogEvent(&models.Event{ + Action: schema.JoinHostToNet, + Source: models.Subject{ + ID: key.Value, + Name: key.Tags[0], + Type: schema.EnrollmentKeySub, + }, + TriggeredBy: username, + Target: models.Subject{ + ID: host.ID.String(), + Name: host.Name, + Type: schema.DeviceSub, + }, + NetworkID: schema.NetworkID(netID), + Origin: schema.Dashboard, + }) } - } else { - logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, netID, err.Error()) - continue - } - logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name) - hostactions.AddAction(models.HostUpdate{ - Action: models.JoinHostToNetwork, - Host: *h, - Node: *newNode, - }) - if h.IsDefault { - // make host gateway - logic.CreateIngressGateway(netID, newNode.ID.String(), models.IngressRequest{}) - logic.CreateRelay(models.RelayRequest{ - NodeID: newNode.ID.String(), - NetID: netID, - }) } } } if servercfg.IsMessageQueueBackend() { mq.HostUpdate(&models.HostUpdate{ Action: models.RequestAck, - Host: *h, + Host: *host, }) if err := mq.PublishPeerUpdate(false); err != nil { logger.Log(0, "failed to publish peer update during registration -", err.Error()) diff --git a/cli/cmd/failover/disable.go b/cli/cmd/failover/disable.go deleted file mode 100644 index 886843aab..000000000 --- a/cli/cmd/failover/disable.go +++ /dev/null @@ -1,20 +0,0 @@ -package failover - -import ( - "github.com/gravitl/netmaker/cli/functions" - "github.com/spf13/cobra" -) - -var disableFailoverCmd = &cobra.Command{ - Use: "disable [NODE ID]", - Args: cobra.ExactArgs(1), - Short: "Disable failover for a given Node", - Long: `Disable failover for a given Node`, - Run: func(cmd *cobra.Command, args []string) { - functions.PrettyPrint(functions.DisableNodeFailover(args[0])) - }, -} - -func init() { - rootCmd.AddCommand(disableFailoverCmd) -} diff --git a/cli/cmd/failover/enable.go b/cli/cmd/failover/enable.go deleted file mode 100644 index d7ee6a3d2..000000000 --- a/cli/cmd/failover/enable.go +++ /dev/null @@ -1,20 +0,0 @@ -package failover - -import ( - "github.com/gravitl/netmaker/cli/functions" - "github.com/spf13/cobra" -) - -var enableFailoverCmd = &cobra.Command{ - Use: "enable [NODE ID]", - Args: cobra.ExactArgs(1), - Short: "Enable failover for a given Node", - Long: `Enable failover for a given Node`, - Run: func(cmd *cobra.Command, args []string) { - functions.PrettyPrint(functions.EnableNodeFailover(args[0])) - }, -} - -func init() { - rootCmd.AddCommand(enableFailoverCmd) -} diff --git a/cli/cmd/failover/root.go b/cli/cmd/failover/root.go deleted file mode 100644 index 390e49c58..000000000 --- a/cli/cmd/failover/root.go +++ /dev/null @@ -1,28 +0,0 @@ -package failover - -import ( - "os" - - "github.com/spf13/cobra" -) - -// rootCmd represents the base command when called without any subcommands -var rootCmd = &cobra.Command{ - Use: "failover", - Short: "Enable/Disable failover for a node associated with a network", - Long: `Enable/Disable failover for a node associated with a network`, -} - -// GetRoot returns the root subcommand -func GetRoot() *cobra.Command { - return rootCmd -} - -// Execute adds all child commands to the root command and sets flags appropriately. -// This is called by main.main(). It only needs to happen once to the rootCmd. -func Execute() { - err := rootCmd.Execute() - if err != nil { - os.Exit(1) - } -} diff --git a/cli/cmd/node/create_ingress.go b/cli/cmd/node/create_ingress.go index 717e2e81f..5b524c4e4 100644 --- a/cli/cmd/node/create_ingress.go +++ b/cli/cmd/node/create_ingress.go @@ -13,11 +13,10 @@ var nodeCreateIngressCmd = &cobra.Command{ Deprecated: "in favour of the `gateway` subcommand, in Netmaker v0.90.0.", Aliases: []string{"create_rag"}, Run: func(cmd *cobra.Command, args []string) { - functions.PrettyPrint(functions.CreateIngress(args[0], args[1], failover)) + functions.PrettyPrint(functions.CreateIngress(args[0], args[1])) }, } func init() { - nodeCreateIngressCmd.Flags().BoolVar(&failover, "failover", false, "Enable FailOver ?") rootCmd.AddCommand(nodeCreateIngressCmd) } diff --git a/cli/cmd/node/flags.go b/cli/cmd/node/flags.go index 2ed805d75..504dd849c 100644 --- a/cli/cmd/node/flags.go +++ b/cli/cmd/node/flags.go @@ -2,7 +2,6 @@ package node var ( natEnabled bool - failover bool networkName string nodeDefinitionFilePath string address string diff --git a/cli/cmd/root.go b/cli/cmd/root.go index 2ba503e10..9045f3095 100644 --- a/cli/cmd/root.go +++ b/cli/cmd/root.go @@ -9,7 +9,6 @@ import ( "github.com/gravitl/netmaker/cli/cmd/dns" "github.com/gravitl/netmaker/cli/cmd/enrollment_key" "github.com/gravitl/netmaker/cli/cmd/ext_client" - "github.com/gravitl/netmaker/cli/cmd/failover" "github.com/gravitl/netmaker/cli/cmd/gateway" "github.com/gravitl/netmaker/cli/cmd/host" "github.com/gravitl/netmaker/cli/cmd/metrics" @@ -55,7 +54,6 @@ func init() { rootCmd.AddCommand(metrics.GetRoot()) rootCmd.AddCommand(host.GetRoot()) rootCmd.AddCommand(enrollment_key.GetRoot()) - rootCmd.AddCommand(failover.GetRoot()) rootCmd.AddCommand(gateway.GetRoot()) rootCmd.AddCommand(access_token.GetRoot()) } diff --git a/cli/functions/failover.go b/cli/functions/failover.go deleted file mode 100644 index 1c1b3f767..000000000 --- a/cli/functions/failover.go +++ /dev/null @@ -1,18 +0,0 @@ -package functions - -import ( - "fmt" - "net/http" - - "github.com/gravitl/netmaker/models" -) - -// EnableNodeFailover - Enable failover for a given Node -func EnableNodeFailover(nodeID string) *models.SuccessResponse { - return request[models.SuccessResponse](http.MethodPost, fmt.Sprintf("/api/v1/node/%s/failover", nodeID), nil) -} - -// DisableNodeFailover - Disable failover for a given Node -func DisableNodeFailover(nodeID string) *models.SuccessResponse { - return request[models.SuccessResponse](http.MethodDelete, fmt.Sprintf("/api/v1/node/%s/failover", nodeID), nil) -} diff --git a/cli/functions/node.go b/cli/functions/node.go index 50eed3572..6e3a73950 100644 --- a/cli/functions/node.go +++ b/cli/functions/node.go @@ -42,10 +42,10 @@ func DeleteEgress(networkName, nodeID string) *models.ApiNode { } // CreateIngress - turn a node into an ingress -func CreateIngress(networkName, nodeID string, failover bool) *models.ApiNode { +func CreateIngress(networkName, nodeID string) *models.ApiNode { return request[models.ApiNode](http.MethodPost, fmt.Sprintf("/api/nodes/%s/%s/createingress", networkName, nodeID), &struct { Failover bool `json:"failover"` - }{Failover: failover}) + }{Failover: false}) } // DeleteIngress - remove ingress role from a node diff --git a/controllers/controller.go b/controllers/controller.go index d5b6e215e..9c7f7358f 100644 --- a/controllers/controller.go +++ b/controllers/controller.go @@ -40,7 +40,6 @@ var HttpHandlers = []interface{}{ enrollmentKeyHandlers, aclHandlers, egressHandlers, - legacyHandlers, } func HandleRESTRequests(wg *sync.WaitGroup, ctx context.Context) { diff --git a/controllers/dns.go b/controllers/dns.go index f8066e23b..37bcafbe8 100644 --- a/controllers/dns.go +++ b/controllers/dns.go @@ -516,15 +516,6 @@ func createDNS(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - if servercfg.IsDNSMode() { - err = logic.SetDNS() - if err != nil { - logger.Log(0, r.Header.Get("user"), - fmt.Sprintf("Failed to set DNS entries on file: %v", err)) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - return - } - } if logic.GetManageDNS() { mq.SendDNSSyncByNetwork(netID) @@ -562,15 +553,6 @@ func deleteDNS(w http.ResponseWriter, r *http.Request) { return } logger.Log(1, "deleted dns entry: ", entrytext) - if servercfg.IsDNSMode() { - err = logic.SetDNS() - if err != nil { - logger.Log(0, r.Header.Get("user"), - fmt.Sprintf("Failed to set DNS entries on file: %v", err)) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - return - } - } if logic.GetManageDNS() { mq.SendDNSSyncByNetwork(netID) @@ -614,14 +596,9 @@ func pushDNS(w http.ResponseWriter, r *http.Request) { ) return } - err := logic.SetDNS() - if err != nil { - logger.Log(0, r.Header.Get("user"), - fmt.Sprintf("Failed to set DNS entries on file: %v", err)) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - return - } + // TODO: deprecate API. does nothing. + logger.Log(1, r.Header.Get("user"), "pushed DNS updates to nameserver") json.NewEncoder(w).Encode("DNS Pushed to CoreDNS") } diff --git a/controllers/dns_test.go b/controllers/dns_test.go index 334348ab7..9fbeb846b 100644 --- a/controllers/dns_test.go +++ b/controllers/dns_test.go @@ -2,7 +2,6 @@ package controller import ( "fmt" - "net" "testing" "github.com/google/uuid" @@ -10,7 +9,6 @@ import ( "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/gravitl/netmaker/functions" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" ) @@ -47,51 +45,6 @@ func TestGetAllDNS(t *testing.T) { }) } -func TestGetNodeDNS(t *testing.T) { - deleteAllDNS(t) - deleteAllNetworks() - createNet() - createHost() - err := functions.SetDNSDir() - assert.Nil(t, err) - t.Run("NoNodes", func(t *testing.T) { - dns, _ := logic.GetNodeDNS("skynet") - assert.Equal(t, []models.DNSEntry(nil), dns) - }) - t.Run("NodeExists", func(t *testing.T) { - createHost() - _, ipnet, _ := net.ParseCIDR("10.0.0.1/32") - tmpCNode := models.CommonNode{ - ID: uuid.New(), - Network: "skynet", - Address: *ipnet, - } - createnode := models.Node{ - CommonNode: tmpCNode, - } - err := logic.AssociateNodeToHost(&createnode, &dnsHost) - assert.Nil(t, err) - dns, err := logic.GetNodeDNS("skynet") - assert.Nil(t, err) - assert.Equal(t, "10.0.0.1", dns[0].Address) - }) - t.Run("MultipleNodes", func(t *testing.T) { - _, ipnet, _ := net.ParseCIDR("10.100.100.3/32") - tmpCNode := models.CommonNode{ - ID: uuid.New(), - Network: "skynet", - Address: *ipnet, - } - createnode := models.Node{ - CommonNode: tmpCNode, - } - err := logic.AssociateNodeToHost(&createnode, &dnsHost) - assert.Nil(t, err) - dns, err := logic.GetNodeDNS("skynet") - assert.Nil(t, err) - assert.Equal(t, 2, len(dns)) - }) -} func TestGetCustomDNS(t *testing.T) { deleteAllDNS(t) deleteAllNetworks() diff --git a/controllers/ext_client.go b/controllers/ext_client.go index 63f2da9f1..73b1d9a21 100644 --- a/controllers/ext_client.go +++ b/controllers/ext_client.go @@ -7,9 +7,9 @@ import ( "fmt" "net" "net/http" - "strconv" "strings" "sync" + "time" "github.com/go-playground/validator/v10" "github.com/gorilla/mux" @@ -17,10 +17,9 @@ import ( "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" - "github.com/gravitl/netmaker/schema" - "github.com/gravitl/netmaker/servercfg" - "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/orchestrator" + "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/mq" "github.com/skip2/go-qrcode" @@ -301,11 +300,8 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) { } keepalive := "" - if network.DefaultKeepAlive != 0 { - keepalive = "PersistentKeepalive = " + strconv.Itoa(int(network.DefaultKeepAlive)) - } - if gwnode.IngressPersistentKeepalive != 0 { - keepalive = "PersistentKeepalive = " + strconv.Itoa(int(gwnode.IngressPersistentKeepalive)) + if host.PersistentKeepalive != 0 { + keepalive = fmt.Sprintf("PersistentKeepalive = %d", int(host.PersistentKeepalive.Seconds())) } gwendpoint := "" @@ -350,9 +346,6 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) { if host.MTU != 0 { defaultMTU = host.MTU } - if gwnode.IngressMTU != 0 { - defaultMTU = int(gwnode.IngressMTU) - } postUp := strings.Builder{} if client.PostUp != "" && params["type"] != "qr" { @@ -429,237 +422,6 @@ Endpoint = %s json.NewEncoder(w).Encode(client) } -// @Summary Get config file HA configuration -// @Router /api/v1/client_conf/{network} [get] -// @Tags Config Files -// @Security oauth -// @Param network path string true "Network ID" -// @Success 200 {string} string "WireGuard config file" -// @Failure 500 {object} models.ErrorResponse -// @Failure 403 {object} models.ErrorResponse -func GetExtClientHAConf(w http.ResponseWriter, r *http.Request) { - - var params = mux.Vars(r) - networkid := params["network"] - network := &schema.Network{Name: networkid} - err := network.Get(r.Context()) - if err != nil { - logger.Log( - 1, - r.Header.Get("user"), - "Could not retrieve Ingress Gateway Network", - networkid, - ) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - return - } - // fetch client based on availability - nodes, _ := logic.GetNetworkNodes(networkid) - defaultPolicy, _ := logic.GetDefaultPolicy(schema.NetworkID(networkid), models.DevicePolicy) - var targetGwID string - var connectionCnt int = -1 - for _, nodeI := range nodes { - if nodeI.IsGw { - // check health status - logic.GetNodeStatus(&nodeI, defaultPolicy.Enabled) - if nodeI.Status != models.OnlineSt { - continue - } - // Get Total connections on the gw - clients := logic.GetGwExtclients(nodeI.ID.String(), networkid) - - if connectionCnt == -1 || len(clients) < connectionCnt { - connectionCnt = len(clients) - targetGwID = nodeI.ID.String() - } - - } - } - gwnode, err := logic.GetNodeByID(targetGwID) - if err != nil { - logger.Log( - 0, - r.Header.Get("user"), - fmt.Sprintf( - "failed to get ingress gateway node [%s] info: %v", - gwnode.ID, - err, - ), - ) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - return - } - host := &schema.Host{ - ID: gwnode.HostID, - } - err = host.Get(r.Context()) - if err != nil { - logger.Log(0, r.Header.Get("user"), - fmt.Sprintf("failed to get ingress gateway host for node [%s] info: %v", gwnode.ID, err)) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - return - } - - var userName string - if r.Header.Get("ismaster") == "yes" { - userName = logic.MasterUser - } else { - caller := &schema.User{Username: r.Header.Get("user")} - err = caller.Get(r.Context()) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - return - } - userName = caller.Username - } - // create client - var extclient models.ExtClient - extclient.OwnerID = userName - extclient.IngressGatewayID = targetGwID - extclient.Network = networkid - extclient.Tags = make(map[models.TagID]struct{}) - - listenPort := logic.GetPeerListenPort(host) - extclient.IngressGatewayEndpoint = fmt.Sprintf("%s:%d", host.EndpointIP.String(), listenPort) - extclient.Enabled = true - - if err = logic.CreateExtClient(&extclient); err != nil { - slog.Error( - "failed to create extclient", - "user", - r.Header.Get("user"), - "network", - networkid, - "error", - err, - ) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - return - } - client, err := logic.GetExtClient(extclient.ClientID, networkid) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - return - } - logic.SetDNSOnWgConfig(&gwnode, &client) - defaultDNS := "" - if client.DNS != "" { - defaultDNS = "DNS = " + client.DNS - } - addrString := client.Address - if addrString != "" { - addrString += "/32" - } - if client.Address6 != "" { - if addrString != "" { - addrString += "," - } - addrString += client.Address6 + "/128" - } - - keepalive := "" - if network.DefaultKeepAlive != 0 { - keepalive = "PersistentKeepalive = " + strconv.Itoa(int(network.DefaultKeepAlive)) - } - if gwnode.IngressPersistentKeepalive != 0 { - keepalive = "PersistentKeepalive = " + strconv.Itoa(int(gwnode.IngressPersistentKeepalive)) - } - var newAllowedIPs string - if logic.IsInternetGw(gwnode) || gwnode.InternetGwID != "" { - egressrange := "0.0.0.0/0" - if gwnode.Address6.IP != nil && client.Address6 != "" { - egressrange += "," + "::/0" - } - newAllowedIPs = egressrange - } else { - newAllowedIPs = network.AddressRange - if newAllowedIPs != "" && network.AddressRange6 != "" { - newAllowedIPs += "," - } - if network.AddressRange6 != "" { - newAllowedIPs += network.AddressRange6 - } - if egressGatewayRanges, err := logic.GetEgressRangesOnNetwork(&client); err == nil { - for _, egressGatewayRange := range egressGatewayRanges { - newAllowedIPs += "," + egressGatewayRange - } - } - } - gwendpoint := "" - if host.EndpointIP.To4() == nil { - gwendpoint = fmt.Sprintf("[%s]:%d", host.EndpointIPv6.String(), host.ListenPort) - } else { - gwendpoint = fmt.Sprintf("%s:%d", host.EndpointIP.String(), host.ListenPort) - } - defaultMTU := 1420 - if host.MTU != 0 { - defaultMTU = host.MTU - } - if gwnode.IngressMTU != 0 { - defaultMTU = int(gwnode.IngressMTU) - } - - postUp := strings.Builder{} - if client.PostUp != "" && params["type"] != "qr" { - for _, loc := range strings.Split(client.PostUp, "\n") { - postUp.WriteString(fmt.Sprintf("PostUp = %s\n", loc)) - } - } - - postDown := strings.Builder{} - if client.PostDown != "" && params["type"] != "qr" { - for _, loc := range strings.Split(client.PostDown, "\n") { - postDown.WriteString(fmt.Sprintf("PostDown = %s\n", loc)) - } - } - - config := fmt.Sprintf(`[Interface] -Address = %s -PrivateKey = %s -MTU = %d -%s -%s -%s - -[Peer] -PublicKey = %s -AllowedIPs = %s -Endpoint = %s -%s - -`, addrString, - client.PrivateKey, - defaultMTU, - defaultDNS, - postUp.String(), - postDown.String(), - host.PublicKey, - newAllowedIPs, - gwendpoint, - keepalive, - ) - - go func() { - if err := mq.PublishPeerUpdate(false); err != nil { - logger.Log(1, "error publishing peer update ", err.Error()) - } - if servercfg.IsDNSMode() { - logic.SetDNS() - } - }() - - name := client.ClientID + ".conf" - w.Header().Set("Content-Type", "application/config") - w.Header().Set("Client-ID", client.ClientID) - w.Header().Set("Content-Disposition", "attachment; filename=\""+name+"\"") - w.WriteHeader(http.StatusOK) - _, err = fmt.Fprint(w, config) - if err != nil { - logger.Log(1, r.Header.Get("user"), "response writer error (file) ", err.Error()) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - } -} - // @Summary Create a config file // @Router /api/extclients/{network}/{nodeid} [post] // @Tags Config Files @@ -846,7 +608,106 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { return } } - if err = logic.CreateExtClient(&extclient); err != nil { + + if len(extclient.PublicKey) == 0 { + privateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + slog.Error( + "failed to create extclient", + "user", + r.Header.Get("user"), + "network", + node.Network, + "error", + err, + ) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) + return + } + extclient.PrivateKey = privateKey.String() + extclient.PublicKey = privateKey.PublicKey().String() + } else if len(extclient.PrivateKey) == 0 && len(extclient.PublicKey) > 0 { + extclient.PrivateKey = "[ENTER PRIVATE KEY]" + } + if extclient.ExtraAllowedIPs == nil { + extclient.ExtraAllowedIPs = []string{} + } + + parentNetwork := &schema.Network{Name: extclient.Network} + err = parentNetwork.Get(db.WithContext(context.TODO())) + if err != nil { + slog.Error( + "failed to create extclient", + "user", + r.Header.Get("user"), + "network", + node.Network, + "error", + err, + ) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) + return + } + + if extclient.Address == "" { + if parentNetwork.AddressRange != "" { + newAddress, err := orchestrator.GetRepository().NetworkOrchestrator().AllocateExtclientIP(db.WithContext(context.TODO()), parentNetwork) + if err != nil { + slog.Error( + "failed to create extclient", + "user", + r.Header.Get("user"), + "network", + node.Network, + "error", + err, + ) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) + return + } + extclient.Address = newAddress.String() + } + } + + if extclient.Address6 == "" { + if parentNetwork.AddressRange6 != "" { + addr6, err := orchestrator.GetRepository().NetworkOrchestrator().AllocateExtclientIPv6(db.WithContext(context.TODO()), parentNetwork) + if err != nil { + slog.Error( + "failed to create extclient", + "user", + r.Header.Get("user"), + "network", + node.Network, + "error", + err, + ) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) + return + } + extclient.Address6 = addr6.String() + } + } + + if extclient.ClientID == "" { + extclient.ClientID, err = logic.GenerateNodeName(extclient.Network) + if err != nil { + slog.Error( + "failed to create extclient", + "user", + r.Header.Get("user"), + "network", + node.Network, + "error", + err, + ) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) + return + } + } + + extclient.LastModified = time.Now().Unix() + if err = logic.SaveExtClient(&extclient); err != nil { slog.Error( "failed to create extclient", "user", @@ -915,9 +776,6 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { extUpdateMutex.Lock() mq.PublishPeerUpdate(false) extUpdateMutex.Unlock() - if servercfg.IsDNSMode() { - logic.SetDNS() - } }() } @@ -986,8 +844,6 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) { return } - var changedID = update.ClientID != oldExtClient.ClientID - if update.PublicKey != oldExtClient.PublicKey { //remove old peer entry replacePeers = true @@ -1072,9 +928,6 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(newclient) go func() { - if changedID && servercfg.IsDNSMode() { - logic.SetDNS() - } if replacePeers || !update.Enabled { if err := mq.PublishDeletedClientPeerUpdate(&oldExtClient); err != nil { slog.Error("error deleting old ext peers", "error", err.Error()) @@ -1147,9 +1000,6 @@ func deleteExtClient(w http.ResponseWriter, r *http.Request) { if err := mq.PublishDeletedClientPeerUpdate(&extclient); err != nil { slog.Error("error setting ext peers on " + ingressnode.ID.String() + ": " + err.Error()) } - if servercfg.IsDNSMode() { - logic.SetDNS() - } }() logger.Log(0, r.Header.Get("user"), @@ -1236,9 +1086,6 @@ func bulkDeleteExtClients(w http.ResponseWriter, r *http.Request) { } go mq.PublishPeerUpdate(false) - if servercfg.IsDNSMode() { - logic.SetDNS() - } } slog.Info("bulk extclient delete completed", "deleted", deleted, "total", len(req.IDs)) }() @@ -1381,9 +1228,6 @@ func bulkUpdateExtClientStatus(w http.ResponseWriter, r *http.Request) { } if updated > 0 { mq.PublishPeerUpdate(false) - if servercfg.IsDNSMode() { - logic.SetDNS() - } } slog.Info("bulk extclient status completed", "action", eventAction, "updated", updated, "total", len(req.IDs)) }() diff --git a/controllers/gateway.go b/controllers/gateway.go index ab8c992e0..d787fd7ab 100644 --- a/controllers/gateway.go +++ b/controllers/gateway.go @@ -8,7 +8,6 @@ import ( "net/http" "slices" - "github.com/google/uuid" "github.com/gorilla/mux" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" @@ -150,9 +149,6 @@ func createGateway(w http.ResponseWriter, r *http.Request) { for _, relayedNodeID := range relayNode.RelayedNodes { relayedNode, err := logic.GetNodeByID(relayedNodeID) if err == nil { - if relayedNode.FailedOverBy != uuid.Nil { - logic.ResetFailedOverPeer(&relayedNode) - } if len(relayedNode.AutoRelayedPeers) > 0 { logic.ResetAutoRelayedPeer(&relayedNode) } @@ -281,10 +277,6 @@ func deleteGateway(w http.ResponseWriter, r *http.Request) { err, ) } - if servercfg.IsDNSMode() { - logic.SetDNS() - } - } logic.RemoveNodeFromEnrollmentKeys(&node) @@ -376,9 +368,6 @@ func assignGw(w http.ResponseWriter, r *http.Request) { logic.UpsertNode(&node) logic.GetNodeStatus(&node, false) go func() { - if node.FailedOverBy != uuid.Nil { - logic.ResetFailedOverPeer(&node) - } if len(node.AutoRelayedPeers) > 0 { logic.ResetAutoRelayedPeer(&node) } @@ -449,10 +438,6 @@ func assignGw(w http.ResponseWriter, r *http.Request) { apiNode := node.ConvertToAPINode() go func() { - - if node.FailedOverBy != uuid.Nil { - logic.ResetFailedOverPeer(&node) - } if len(node.AutoRelayedPeers) > 0 { logic.ResetAutoRelayedPeer(&node) } diff --git a/controllers/hosts.go b/controllers/hosts.go index f315f2c27..573e816eb 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -18,6 +18,7 @@ import ( "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" + "github.com/gravitl/netmaker/orchestrator" "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/servercfg" "golang.org/x/crypto/bcrypt" @@ -318,7 +319,6 @@ func pull(w http.ResponseWriter, r *http.Request) { if err != nil { continue } - logic.ResetFailedOverPeer(&node) logic.ResetAutoRelayedPeer(&node) } go mq.PublishPeerUpdate(false) @@ -446,11 +446,6 @@ func updateHost(w http.ResponseWriter, r *http.Request) { if err := mq.PublishPeerUpdate(false); err != nil { logger.Log(0, "fail to publish peer update: ", err.Error()) } - if newHost.Name != currHost.Name { - if servercfg.IsDNSMode() { - logic.SetDNS() - } - } }() logic.LogEvent(&models.Event{ @@ -816,8 +811,8 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) { var params = mux.Vars(r) hostIDStr := params["hostid"] - network := params["network"] - if hostIDStr == "" || network == "" { + networkID := params["network"] + if hostIDStr == "" || networkID == "" { logic.ReturnErrorResponse( w, r, @@ -833,67 +828,54 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) { } // confirm host exists - currHost := &schema.Host{ + host := &schema.Host{ ID: hostID, } - err = currHost.Get(r.Context()) + err = host.Get(r.Context()) if err != nil { logger.Log(0, r.Header.Get("user"), "failed to find host:", hostIDStr, err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) return } + network := &schema.Network{ + Name: networkID, + } + err = network.Get(r.Context()) + if err != nil { + err = fmt.Errorf("failed to add host (%s) to network (%s): error getting network: %v", hostID, networkID, err) + logger.Log(0, err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) + return + } + violations, _ := logic.CheckPostureViolations(models.PostureCheckDeviceInfo{ - ClientLocation: currHost.CountryCode, - ClientVersion: currHost.Version, - OS: currHost.OS, - OSFamily: currHost.OSFamily, - OSVersion: currHost.OSVersion, - KernelVersion: currHost.KernelVersion, + ClientLocation: host.CountryCode, + ClientVersion: host.Version, + OS: host.OS, + OSFamily: host.OSFamily, + OSVersion: host.OSVersion, + KernelVersion: host.KernelVersion, SkipAutoUpdate: true, - }, schema.NetworkID(network)) + }, schema.NetworkID(networkID)) if len(violations) > 0 { logic.ReturnErrorResponseWithJson(w, r, violations, logic.FormatError(errors.New("posture check violations"), logic.BadReq)) return } - newNode, err := logic.UpdateHostNetwork(currHost, network, true) + + _, err = orchestrator.GetRepository().NodeOrchestrator().CreateNode(r.Context(), host, network) if err != nil { - logger.Log( - 0, - r.Header.Get("user"), - "failed to add host to network:", - hostIDStr, - network, - err.Error(), - ) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + err = fmt.Errorf("failed to add host (%s) to network (%s): error creating node: %v", hostID, networkID, err) + logger.Log(0, err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) return } - logger.Log(1, "added new node", newNode.ID.String(), "to host", currHost.Name) - if currHost.IsDefault { - // make host gateway - logic.CreateIngressGateway(network, newNode.ID.String(), models.IngressRequest{}) - logic.CreateRelay(models.RelayRequest{ - NodeID: newNode.ID.String(), - NetID: network, - }) - } - go func() { - mq.HostUpdate(&models.HostUpdate{ - Action: models.JoinHostToNetwork, - Host: *currHost, - Node: *newNode, - }) - mq.PublishPeerUpdate(false) - if servercfg.IsDNSMode() { - logic.SetDNS() - } - }() + logger.Log( 2, r.Header.Get("user"), - fmt.Sprintf("added host %s to network %s", currHost.Name, network), + fmt.Sprintf("added host %s to network %s", host.Name, networkID), ) logic.LogEvent(&models.Event{ Action: schema.JoinHostToNet, @@ -904,11 +886,11 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) { }, TriggeredBy: r.Header.Get("user"), Target: models.Subject{ - ID: currHost.ID.String(), - Name: currHost.Name, + ID: host.ID.String(), + Name: host.Name, Type: schema.DeviceSub, }, - NetworkID: schema.NetworkID(network), + NetworkID: schema.NetworkID(networkID), Origin: schema.Dashboard, }) w.WriteHeader(http.StatusOK) @@ -1043,9 +1025,6 @@ func deleteHostFromNetwork(w http.ResponseWriter, r *http.Request) { } go func() { mq.PublishMqUpdatesForDeletedNode(*node, true) - if servercfg.IsDNSMode() { - logic.SetDNS() - } }() logic.LogEvent(&models.Event{ Action: schema.RemoveHostFromNet, @@ -1603,10 +1582,10 @@ func approvePendingHost(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq)) return } - h := &schema.Host{ + host := &schema.Host{ ID: hostID, } - err = h.Get(r.Context()) + err = host.Get(r.Context()) if err != nil { logic.ReturnErrorResponse(w, r, models.ErrorResponse{ Code: http.StatusBadRequest, @@ -1616,66 +1595,28 @@ func approvePendingHost(w http.ResponseWriter, r *http.Request) { } key := models.EnrollmentKey{} json.Unmarshal(p.EnrollmentKey, &key) - newNode, err := logic.UpdateHostNetwork(h, p.Network, true) + + network := &schema.Network{ + Name: p.Network, + } + err = network.Get(r.Context()) if err != nil { - logic.ReturnErrorResponse(w, r, models.ErrorResponse{ - Code: http.StatusBadRequest, - Message: err.Error(), - }) + err = fmt.Errorf("failed to approve pending host (%s): error getting network (%s): %w", id, p.Network, err) + logger.Log(0, err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) return } - if key.AutoAssignGateway { - newNode.AutoAssignGateway = true - } - if len(key.Groups) > 0 { - newNode.Tags = make(map[models.TagID]struct{}) - for _, tagI := range key.Groups { - newNode.Tags[tagI] = struct{}{} - } - } - if key.Relay != uuid.Nil && !newNode.IsRelayed { - // check if relay node exists and acting as relay - relaynode, err := logic.GetNodeByID(key.Relay.String()) - if err == nil && relaynode.IsGw && relaynode.Network == newNode.Network { - slog.Error(fmt.Sprintf("adding relayed node %s to relay %s on network %s", newNode.ID.String(), key.Relay.String(), p.Network)) - newNode.IsRelayed = true - newNode.RelayedBy = key.Relay.String() - updatedRelayNode := relaynode - updatedRelayNode.RelayedNodes = append(updatedRelayNode.RelayedNodes, newNode.ID.String()) - logic.UpdateRelayed(&relaynode, &updatedRelayNode) - if err := logic.UpsertNode(&updatedRelayNode); err != nil { - slog.Error("failed to update node", "nodeid", key.Relay.String()) - } - } else { - slog.Error("failed to relay node. maybe specified relay node is actually not a relay? Or the relayed node is not in the same network with relay?", "err", err) - } - } - err = logic.UpsertNode(newNode) + newNode, err := orchestrator.GetRepository().NodeOrchestrator().CreateNode(r.Context(), host, network, orchestrator.UseKey(&key)) if err != nil { - err = fmt.Errorf("failed to update node: %w", err) - slog.Error("failed to update node", "nodeid", newNode.ID.String()) + err = fmt.Errorf("failed to approve pending host (%s): error creating node: %w", id, err) + logger.Log(0, err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) return } - logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name) - mq.HostUpdate(&models.HostUpdate{ - Action: models.JoinHostToNetwork, - Host: *h, - Node: *newNode, - }) - if h.IsDefault { - // make host gateway - logic.CreateIngressGateway(p.Network, newNode.ID.String(), models.IngressRequest{}) - logic.CreateRelay(models.RelayRequest{ - NodeID: newNode.ID.String(), - NetID: p.Network, - }) - } p.Delete(db.WithContext(r.Context())) - go mq.PublishPeerUpdate(false) - logic.ReturnSuccessResponseWithJson(w, r, newNode.ConvertToAPINode(), "added pending host to "+p.Network) + logic.ReturnSuccessResponseWithJson(w, r, logic.ConvertSchemaNodeToApiNode(newNode), "added pending host to "+p.Network) } // @Summary Reject pending host in a network @@ -1721,29 +1662,10 @@ func addDefaultHostToNetworks(host *schema.Host) { if !network.AutoJoin { continue } - newNode, err := logic.UpdateHostNetwork(host, network.Name, true) + _, err := orchestrator.GetRepository().NodeOrchestrator().CreateNode(db.WithContext(context.TODO()), host, &network, orchestrator.SkipPublishPeerUpdate()) if err != nil { logger.Log(2, "skipping network", network.Name, "for default host", host.Name, ":", err.Error()) continue } - logger.Log(1, "added default host", host.Name, "to network", network.Name) - if len(host.Nodes) == 1 { - mq.HostUpdate(&models.HostUpdate{ - Action: models.RequestPull, - Host: *host, - Node: *newNode, - }) - } else { - mq.HostUpdate(&models.HostUpdate{ - Action: models.JoinHostToNetwork, - Host: *host, - Node: *newNode, - }) - } - logic.CreateIngressGateway(network.Name, newNode.ID.String(), models.IngressRequest{}) - logic.CreateRelay(models.RelayRequest{ - NodeID: newNode.ID.String(), - NetID: network.Name, - }) } } diff --git a/controllers/inet_gws.go b/controllers/inet_gws.go deleted file mode 100644 index 399274511..000000000 --- a/controllers/inet_gws.go +++ /dev/null @@ -1,176 +0,0 @@ -package controller - -import ( - "encoding/json" - "errors" - "net/http" - - "github.com/gorilla/mux" - "github.com/gravitl/netmaker/logger" - "github.com/gravitl/netmaker/logic" - "github.com/gravitl/netmaker/models" - "github.com/gravitl/netmaker/mq" - "github.com/gravitl/netmaker/schema" - "github.com/gravitl/netmaker/servercfg" -) - -func createInternetGw(w http.ResponseWriter, r *http.Request) { - var params = mux.Vars(r) - w.Header().Set("Content-Type", "application/json") - nodeid := params["nodeid"] - netid := params["network"] - node, err := logic.ValidateParams(nodeid, netid) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - return - } - if node.IsInternetGateway { - logic.ReturnSuccessResponse(w, r, "node is already acting as internet gateway") - return - } - var request models.InetNodeReq - err = json.NewDecoder(r.Body).Decode(&request) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - return - } - host := &schema.Host{ - ID: node.HostID, - } - err = host.Get(r.Context()) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - return - } - if host.OS != models.OS_Types.Linux { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError( - errors.New("only linux nodes can be made internet gws"), - "badrequest", - ), - ) - return - } - err = logic.ValidateInetGwReq(node, request, false) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - return - } - logic.SetInternetGw(&node, request) - if servercfg.IsPro { - if _, exists := logic.FailOverExists(node.Network); exists { - go func() { - logic.ResetFailedOverPeer(&node) - mq.PublishPeerUpdate(false) - }() - } - go func() { - logic.ResetAutoRelayedPeer(&node) - mq.PublishPeerUpdate(false) - }() - - } - if node.IsGw && node.IngressDNS == "" { - node.IngressDNS = "1.1.1.1" - } - err = logic.UpsertNode(&node) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - return - } - apiNode := node.ConvertToAPINode() - logger.Log( - 1, - r.Header.Get("user"), - "created ingress gateway on node", - nodeid, - "on network", - netid, - ) - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(apiNode) - go mq.PublishPeerUpdate(false) -} - -func updateInternetGw(w http.ResponseWriter, r *http.Request) { - var params = mux.Vars(r) - w.Header().Set("Content-Type", "application/json") - nodeid := params["nodeid"] - netid := params["network"] - node, err := logic.ValidateParams(nodeid, netid) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - return - } - var request models.InetNodeReq - err = json.NewDecoder(r.Body).Decode(&request) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - return - } - if !node.IsInternetGateway { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError(errors.New("node is not a internet gw"), "badrequest"), - ) - return - } - err = logic.ValidateInetGwReq(node, request, true) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - return - } - logic.UnsetInternetGw(&node) - logic.SetInternetGw(&node, request) - err = logic.UpsertNode(&node) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - return - } - apiNode := node.ConvertToAPINode() - logger.Log( - 1, - r.Header.Get("user"), - "created ingress gateway on node", - nodeid, - "on network", - netid, - ) - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(apiNode) - go mq.PublishPeerUpdate(false) -} - -func deleteInternetGw(w http.ResponseWriter, r *http.Request) { - var params = mux.Vars(r) - w.Header().Set("Content-Type", "application/json") - nodeid := params["nodeid"] - netid := params["network"] - node, err := logic.ValidateParams(nodeid, netid) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - return - } - - logic.UnsetInternetGw(&node) - err = logic.UpsertNode(&node) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - return - } - apiNode := node.ConvertToAPINode() - logger.Log( - 1, - r.Header.Get("user"), - "created ingress gateway on node", - nodeid, - "on network", - netid, - ) - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(apiNode) - go mq.PublishPeerUpdate(false) -} diff --git a/controllers/legacy.go b/controllers/legacy.go deleted file mode 100644 index 18486b6a5..000000000 --- a/controllers/legacy.go +++ /dev/null @@ -1,38 +0,0 @@ -package controller - -import ( - "net/http" - - "github.com/gorilla/mux" - "github.com/gravitl/netmaker/logger" - "github.com/gravitl/netmaker/logic" -) - -func legacyHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/legacy/nodes", logic.SecurityCheck(true, http.HandlerFunc(wipeLegacyNodes))). - Methods(http.MethodDelete) - r.HandleFunc("/api/nodes/{network}/{nodeid}/inet_gw", logic.SecurityCheck(true, http.HandlerFunc(createInternetGw))). - Methods(http.MethodPost) - r.HandleFunc("/api/nodes/{network}/{nodeid}/inet_gw", logic.SecurityCheck(true, http.HandlerFunc(updateInternetGw))). - Methods(http.MethodPut) - r.HandleFunc("/api/nodes/{network}/{nodeid}/inet_gw", logic.SecurityCheck(true, http.HandlerFunc(deleteInternetGw))). - Methods(http.MethodDelete) -} - -// @Summary Delete all legacy nodes from DB. -// @Router /api/v1/legacy/nodes [delete] -// @Tags Nodes -// @Security oauth -// @Produce json -// @Success 200 {string} string "Wiped all legacy nodes." -// @Failure 400 {object} models.ErrorResponse -func wipeLegacyNodes(w http.ResponseWriter, r *http.Request) { - // Set header - w.Header().Set("Content-Type", "application/json") - if err := logic.RemoveAllLegacyNodes(); err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - logger.Log(0, "error occurred when removing legacy nodes", err.Error()) - } - logger.Log(0, r.Header.Get("user"), "wiped legacy nodes") - logic.ReturnSuccessResponse(w, r, "wiped all legacy nodes") -} diff --git a/controllers/migrate.go b/controllers/migrate.go deleted file mode 100644 index 36028d969..000000000 --- a/controllers/migrate.go +++ /dev/null @@ -1,218 +0,0 @@ -package controller - -import ( - "encoding/json" - "fmt" - "net" - "net/http" - "time" - - "github.com/google/uuid" - "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/logger" - "github.com/gravitl/netmaker/logic" - "github.com/gravitl/netmaker/models" - "github.com/gravitl/netmaker/mq" - "github.com/gravitl/netmaker/schema" - "github.com/gravitl/netmaker/servercfg" - "golang.org/x/crypto/bcrypt" - "golang.org/x/exp/slog" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -// @Summary Used to migrate a legacy node. -// @Router /api/v1/nodes/migrate [put] -// @Tags Nodes -// @Security oauth -// @Accept json -// @Produce json -// @Param body body models.MigrationData true "Migration data" -// @Success 200 {object} models.HostPull -// @Failure 400 {object} models.ErrorResponse -func migrate(w http.ResponseWriter, r *http.Request) { - data := models.MigrationData{} - host := schema.Host{} - node := models.Node{} - var nodes []models.Node - server := models.ServerConfig{} - err := json.NewDecoder(r.Body).Decode(&data) - if err != nil { - logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - return - } - for i, legacy := range data.LegacyNodes { - record, err := database.FetchRecord(database.NODES_TABLE_NAME, legacy.ID) - if err != nil { - slog.Error("legacy node not found", "error", err) - logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("legacy node not found %w", err), "badrequest")) - return - } - var legacyNode models.LegacyNode - if err = json.Unmarshal([]byte(record), &legacyNode); err != nil { - slog.Error("decoding legacy node", "error", err) - logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("decode legacy node %w", err), "badrequest")) - return - } - if err := bcrypt.CompareHashAndPassword([]byte(legacyNode.Password), []byte(legacy.Password)); err != nil { - slog.Error("legacy node invalid password", "error", err) - logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("invalid password %w", err), "unauthorized")) - return - } - if i == 0 { - host, node = convertLegacyHostNode(legacy) - host.Name = data.HostName - host.HostPass = data.Password - host.OS = data.OS - if err := logic.CreateHost(&host); err != nil { - slog.Error("create host", "error", err) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - return - } - server = logic.GetServerInfo() - key, keyErr := logic.RetrievePublicTrafficKey() - if keyErr != nil { - slog.Error("retrieving traffickey", "error", keyErr) - logic.ReturnErrorResponse(w, r, logic.FormatError(keyErr, "internal")) - return - } - server.TrafficKey = key - } else { - node = convertLegacyNode(legacyNode, host.ID) - } - if err := logic.UpsertNode(&node); err != nil { - slog.Error("update node", "error", err) - continue - } - host.Nodes = append(host.Nodes, node.ID.String()) - - nodes = append(nodes, node) - } - if err := logic.UpsertHost(&host); err != nil { - slog.Error("save host", "error", err) - } - go mq.PublishPeerUpdate(false) - response := models.HostPull{ - Host: host, - Nodes: nodes, - ServerConfig: server, - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(&response) - - slog.Info("migrated nodes") - // check for gateways - for _, node := range data.LegacyNodes { - if node.IsEgressGateway == "yes" { - egressGateway := models.EgressGatewayRequest{ - NodeID: node.ID, - Ranges: node.EgressGatewayRanges, - NatEnabled: node.EgressGatewayNatEnabled, - } - if _, err := logic.CreateEgressGateway(egressGateway); err != nil { - logger.Log(0, "error creating egress gateway for node", node.ID, err.Error()) - } - } - if node.IsIngressGateway == "yes" { - ingressGateway := models.IngressRequest{} - ingressNode, err := logic.CreateIngressGateway(node.Network, node.ID, ingressGateway) - if err != nil { - logger.Log(0, "error creating ingress gateway for node", node.ID, err.Error()) - } - go func() { - if err := mq.NodeUpdate(&ingressNode); err != nil { - slog.Error("error publishing node update to node", "node", ingressNode.ID, "error", err) - } - }() - } - } -} - -func convertLegacyHostNode(legacy models.LegacyNode) (schema.Host, models.Node) { - //convert host - host := schema.Host{} - host.ID = uuid.New() - host.IPForwarding = models.ParseBool(legacy.IPForwarding) - host.AutoUpdate = logic.AutoUpdateEnabled() - host.Interface = "netmaker" - host.ListenPort = int(legacy.ListenPort) - if host.ListenPort == 0 { - host.ListenPort = 51821 - } - host.MTU = int(legacy.MTU) - pubKey, _ := wgtypes.ParseKey(legacy.PublicKey) - host.PublicKey = schema.WgKey{Key: pubKey} - host.MacAddress = net.HardwareAddr(legacy.MacAddress) - host.TrafficKeyPublic = legacy.TrafficKeys.Mine - host.Nodes = append([]string{}, legacy.ID) - host.Interfaces = legacy.Interfaces - //host.DefaultInterface = legacy.Defaul - host.EndpointIP = net.ParseIP(legacy.Endpoint) - host.IsDocker = models.ParseBool(legacy.IsDocker) - host.IsK8S = models.ParseBool(legacy.IsK8S) - host.IsStaticPort = models.ParseBool(legacy.IsStatic) - host.IsStatic = models.ParseBool(legacy.IsStatic) - host.PersistentKeepalive = time.Duration(legacy.PersistentKeepalive) * time.Second - if host.PersistentKeepalive == 0 { - host.PersistentKeepalive = models.DefaultPersistentKeepAlive - } - - node := convertLegacyNode(legacy, host.ID) - return host, node -} - -func convertLegacyNode(legacy models.LegacyNode, hostID uuid.UUID) models.Node { - //convert node - node := models.Node{} - node.ID, _ = uuid.Parse(legacy.ID) - node.HostID = hostID - node.Network = legacy.Network - valid4 := true - valid6 := true - _, cidr4, err := net.ParseCIDR(legacy.NetworkSettings.AddressRange) - if err != nil { - valid4 = false - slog.Warn("parsing address range", "error", err) - } else { - node.NetworkRange = *cidr4 - } - _, cidr6, err := net.ParseCIDR(legacy.NetworkSettings.AddressRange6) - if err != nil { - valid6 = false - slog.Warn("parsing address range6", "error", err) - } else { - node.NetworkRange6 = *cidr6 - } - node.Server = servercfg.GetServer() - node.Connected = models.ParseBool(legacy.Connected) - if valid4 { - node.Address = net.IPNet{ - IP: net.ParseIP(legacy.Address), - Mask: cidr4.Mask, - } - } - if valid6 { - node.Address6 = net.IPNet{ - IP: net.ParseIP(legacy.Address6), - Mask: cidr6.Mask, - } - } - node.Action = models.NODE_NOOP - node.LocalAddress = net.IPNet{ - IP: net.ParseIP(legacy.LocalAddress), - } - node.IsEgressGateway = models.ParseBool(legacy.IsEgressGateway) - node.EgressGatewayRanges = legacy.EgressGatewayRanges - node.IsIngressGateway = models.ParseBool(legacy.IsIngressGateway) - node.IsRelayed = false - node.IsRelay = false - node.RelayedNodes = []string{} - node.LastModified = time.Now().UTC() - node.ExpirationDateTime = time.Unix(legacy.ExpirationDateTime, 0) - node.EgressGatewayNatEnabled = models.ParseBool(legacy.EgressGatewayNatEnabled) - node.EgressGatewayRequest = legacy.EgressGatewayRequest - node.IngressGatewayRange = legacy.IngressGatewayRange - node.IngressGatewayRange6 = legacy.IngressGatewayRange6 - node.OwnerID = legacy.OwnerID - return node -} diff --git a/controllers/network.go b/controllers/network.go index df155cadc..f45278366 100644 --- a/controllers/network.go +++ b/controllers/network.go @@ -1,6 +1,7 @@ package controller import ( + "context" "encoding/json" "errors" "fmt" @@ -9,6 +10,8 @@ import ( "strings" "github.com/gorilla/mux" + "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/orchestrator" "github.com/gravitl/netmaker/schema" "golang.org/x/exp/slog" @@ -17,7 +20,6 @@ import ( "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" - "github.com/gravitl/netmaker/servercfg" ) func networkHandlers(r *mux.Router) { @@ -217,8 +219,6 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) { go logic.DeleteNetworkRoles(network) go logic.DeleteAllNetworkTags(schema.NetworkID(network)) go logic.DeleteNetworkPolicies(schema.NetworkID(network)) - //delete network from allocated ip map - go logic.RemoveNetworkFromAllocatedIpMap(network) go func() { <-doneCh mq.PublishPeerUpdate(true) @@ -226,16 +226,13 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) { for _, node := range networkNodes { node := node node.PendingDelete = true - node.Action = models.NODE_DELETE + node.Action = schema.NODE_DELETE if err := mq.NodeUpdate(&node); err != nil { slog.Error("error publishing node update to node", "node", node.ID, "error", err) } } _ = logic.DeleteNetworkNameservers(network) - if servercfg.IsDNSMode() { - logic.SetDNS() - } }() logic.LogEvent(&models.Event{ Action: schema.Delete, @@ -359,7 +356,6 @@ func createNetwork(w http.ResponseWriter, r *http.Request) { logic.CreateDefaultNetworkRolesAndGroups(schema.NetworkID(network.Name)) logic.CreateDefaultAclNetworkPolicies(schema.NetworkID(network.Name)) logic.CreateDefaultTags(schema.NetworkID(network.Name)) - logic.AddNetworkToAllocatedIpMap(network.Name) logic.CreateFallbackNameserver(network.Name) if featureFlags.EnableOverlappingEgressRanges { if err := logic.AllocateUniqueVNATPool(&network); err != nil { @@ -371,65 +367,21 @@ func createNetwork(w http.ResponseWriter, r *http.Request) { go func() { defaultHosts := logic.GetDefaultHosts() for i := range defaultHosts { - currHost := &defaultHosts[i] - newNode, err := logic.UpdateHostNetwork(currHost, network.Name, true) + host := &defaultHosts[i] + newNode, err := orchestrator.GetRepository().NodeOrchestrator().CreateNode(db.WithContext(context.TODO()), host, &network) if err != nil { logger.Log( 0, r.Header.Get("user"), "failed to add host to network:", - currHost.ID.String(), + host.ID.String(), network.Name, err.Error(), ) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - logger.Log(1, "added new node", newNode.ID.String(), "to host", currHost.Name) - if len(currHost.Nodes) == 1 { - if err = mq.HostUpdate(&models.HostUpdate{ - Action: models.RequestPull, - Host: *currHost, - Node: *newNode, - }); err != nil { - logger.Log( - 0, - r.Header.Get("user"), - "failed to add host to network:", - currHost.ID.String(), - network.Name, - err.Error(), - ) - } - } else { - if err = mq.HostUpdate(&models.HostUpdate{ - Action: models.JoinHostToNetwork, - Host: *currHost, - Node: *newNode, - }); err != nil { - logger.Log( - 0, - r.Header.Get("user"), - "failed to add host to network:", - currHost.ID.String(), - network.Name, - err.Error(), - ) - } - } - - // make host failover - logic.CreateFailOver(*newNode) - // make host remote access gateway - logic.CreateIngressGateway(network.Name, newNode.ID.String(), models.IngressRequest{}) - logic.CreateRelay(models.RelayRequest{ - NodeID: newNode.ID.String(), - NetID: network.Name, - }) - } - // send peer updates - if err = mq.PublishPeerUpdate(false); err != nil { - logger.Log(1, "failed to publish peer update for default hosts after network is added") + logger.Log(1, "added new node", newNode.ID, "to host", host.Name) } }() logic.LogEvent(&models.Event{ diff --git a/controllers/network_test.go b/controllers/network_test.go index 2276fbdb6..9201b71c2 100644 --- a/controllers/network_test.go +++ b/controllers/network_test.go @@ -9,13 +9,11 @@ import ( "github.com/gravitl/netmaker/schema" "gorm.io/gorm" - "github.com/google/uuid" "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/stretchr/testify/assert" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) type NetworkValidationTestCase struct { @@ -180,28 +178,6 @@ func TestValidateNetwork(t *testing.T) { } } -func TestIpv6Network(t *testing.T) { - //these seem to work but not sure it the tests are really testing the functionality - - os.Setenv("MASTER_KEY", "secretkey") - deleteAllNetworks() - createNet() - createNetDualStack() - network := &schema.Network{Name: "skynet6"} - err := network.Get(db.WithContext(context.TODO())) - t.Run("Test Network Create IPv6", func(t *testing.T) { - assert.Nil(t, err) - assert.Equal(t, network.AddressRange6, "fde6:be04:fa5e:d076::/64") - }) - node1 := createNodeWithParams("skynet6", "") - createNetHost() - nodeErr := logic.AssociateNodeToHost(node1, &netHost) - t.Run("Test node on network IPv6", func(t *testing.T) { - assert.Nil(t, nodeErr) - assert.Equal(t, "fde6:be04:fa5e:d076::1", node1.Address6.IP.String()) - }) -} - func deleteAllNetworks() { deleteAllNodes() _networks, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO())) @@ -228,26 +204,3 @@ func createNetv1(netId string) { logic.CreateNetwork(&network) } } - -func createNetDualStack() { - var network schema.Network - network.Name = "skynet6" - network.AddressRange = "10.1.2.0/24" - network.AddressRange6 = "fde6:be04:fa5e:d076::/64" - err := (&schema.Network{Name: "skynet6"}).Get(db.WithContext(context.TODO())) - if err != nil { - logic.CreateNetwork(&network) - } -} - -func createNetHost() { - k, _ := wgtypes.ParseKey("DM5qhLAE20PG9BbfBCger+Ac9D2NDOwCtY1rbYDLf34=") - netHost = schema.Host{ - ID: uuid.New(), - PublicKey: schema.WgKey{Key: k.PublicKey()}, - HostPass: "password", - OS: "linux", - Name: "nethost", - } - _ = logic.CreateHost(&netHost) -} diff --git a/controllers/node.go b/controllers/node.go index ed00b69b0..ab1e1f2c8 100644 --- a/controllers/node.go +++ b/controllers/node.go @@ -3,22 +3,27 @@ package controller import ( "context" "encoding/json" + "errors" "fmt" "net/http" + "strconv" "strings" "time" "github.com/gorilla/mux" "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/db" + dbtypes "github.com/gravitl/netmaker/db/types" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" + "github.com/gravitl/netmaker/orchestrator" "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/servercfg" "golang.org/x/crypto/bcrypt" "golang.org/x/exp/slog" + "gorm.io/gorm" ) var hostIDHeader = "host-id" @@ -27,6 +32,7 @@ func nodeHandlers(r *mux.Router) { r.HandleFunc("/api/nodes", logic.SecurityCheck(true, http.HandlerFunc(getAllNodes))).Methods(http.MethodGet) r.HandleFunc("/api/nodes/{network}", logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodes))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/nodes/{network}", logic.SecurityCheck(true, http.HandlerFunc(listNetworkNodes))).Methods(http.MethodGet) r.HandleFunc("/api/nodes/{network}/{nodeid}", AuthorizeHost(http.HandlerFunc(getNode))).Methods(http.MethodGet) r.HandleFunc("/api/nodes/{network}/{nodeid}", logic.SecurityCheck(true, http.HandlerFunc(updateNode))).Methods(http.MethodPut) r.HandleFunc("/api/nodes/{network}/{nodeid}", AuthorizeHost(http.HandlerFunc(deleteNode))).Methods(http.MethodDelete) @@ -38,7 +44,6 @@ func nodeHandlers(r *mux.Router) { r.HandleFunc("/api/v1/nodes/{network}/bulk", logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteNodes))).Methods(http.MethodDelete) r.HandleFunc("/api/v1/nodes/{network}/bulk/status", logic.SecurityCheck(true, http.HandlerFunc(bulkUpdateNodeStatus))).Methods(http.MethodPut) r.HandleFunc("/api/v1/nodes/{network}/status", logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodeStatus))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/nodes/migrate", migrate).Methods(http.MethodPost) } func authenticate(response http.ResponseWriter, request *http.Request) { @@ -193,6 +198,118 @@ func AuthorizeHost( } } +// @Summary List all nodes in the network +// @Router /api/v1/nodes/{network} [get] +// @Tags Nodes +// @Security oauth +// @Produce json +// @Param network path string true "Network ID" +// @Param os query []string false "Filter by OS" Enums(windows, linux, darwin) +// @Param device_type query string false "Device Type" Enums(gw, igw, gw_assigned, gw_unassigned) +// @Param page query int false "Page number" +// @Param per_page query int false "Items per page" +// @Success 200 {array} models.ApiNode +// @Failure 500 {object} models.ErrorResponse +func listNetworkNodes(w http.ResponseWriter, r *http.Request) { + networkName := mux.Vars(r)["network"] + + var osFilters []interface{} + for _, filter := range r.URL.Query()["os"] { + osFilters = append(osFilters, filter) + } + + deviceType := r.URL.Query().Get("device_type") + + var page, pageSize int + page, _ = strconv.Atoi(r.URL.Query().Get("page")) + if page == 0 { + page = 1 + } + + pageSize, _ = strconv.Atoi(r.URL.Query().Get("per_page")) + if pageSize < 1 || pageSize > 100 { + pageSize = 10 + } + + network := &schema.Network{ + Name: networkName, + } + err := network.Get(r.Context()) + if err != nil { + errType := logic.Internal + if errors.Is(err, gorm.ErrRecordNotFound) { + errType = logic.BadReq + } + + err = fmt.Errorf("failed to fetch nodes in network %s: error fetching network: %v", networkName, err) + logger.Log(0, r.Header.Get("user"), err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, errType)) + return + } + + var filters, options []dbtypes.Option + filters = append(filters, dbtypes.WithFilter("network_id", network.ID)) + filters = append(filters, dbtypes.WithJoin("Host", dbtypes.WithFilter("os", osFilters...))) + + if deviceType != "" { + switch deviceType { + case "gw": + filters = append(filters, dbtypes.WithFilter("is_gateway", true)) + case "igw": + filters = append(filters, dbtypes.WithFilter("is_internet_gateway", true)) + case "gw_assigned": + filters = append(filters, dbtypes.WithNotFilter("relaying_node_id", nil)) + case "gw_unassigned": + filters = append(filters, dbtypes.WithFilter("relaying_node_id", nil)) + } + } + + options = append(options, filters...) + options = append(options, dbtypes.InAscOrder(fmt.Sprintf("%s.created_at", (&schema.Node{}).TableName()))) + options = append(options, dbtypes.WithPagination(page, pageSize)) + + _nodes, err := (&schema.Node{}).ListAll(r.Context(), options...) + if err != nil { + err = fmt.Errorf("failed to fetch nodes in network %s: error fetching nodes: %v", networkName, err) + logger.Log(0, r.Header.Get("user"), err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) + return + } + + nodes := make([]models.NodeWithHost, 0, len(_nodes)) + for _, _node := range _nodes { + var node models.NodeWithHost + _node.Status = logic.GetNodeCheckInStatus(&_node) + node.Fill(&_node) + nodes = append(nodes, node) + } + + logger.Log(2, r.Header.Get("user"), "fetched nodes in network", networkName) + + total, err := (&schema.Node{}).Count(r.Context(), filters...) + if err != nil { + err = fmt.Errorf("failed to fetch nodes in network %s: error constructing page: %v", networkName, err) + logger.Log(0, r.Header.Get("user"), err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) + return + } + + totalPages := (total + pageSize - 1) / pageSize + if totalPages == 0 { + totalPages = 1 + } + + response := models.PaginatedResponse{ + Data: nodes, + Page: page, + PerPage: pageSize, + Total: total, + TotalPages: totalPages, + } + + logic.ReturnSuccessResponseWithJson(w, r, response, "fetched network nodes") +} + // @Summary Gets all nodes associated with network including pending nodes // @Router /api/nodes/{network} [get] // @Tags Nodes @@ -518,12 +635,29 @@ func updateNode(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } - err = logic.ValidateNodeIp(¤tNode, &newData) + + network := &schema.Network{Name: currentNode.Network} + err = network.Get(db.WithContext(context.TODO())) if err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } + if currentNode.Address.IP != nil && currentNode.Address.String() != newData.Address { + if !orchestrator.GetRepository().NetworkOrchestrator().IsIPv4Unique(r.Context(), network, newData.Address) { + err = errors.New("ip specified is already allocated: " + newData.Address) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + } + if currentNode.Address6.IP != nil && currentNode.Address6.String() != newData.Address6 { + if !orchestrator.GetRepository().NetworkOrchestrator().IsIPv6Unique(r.Context(), network, newData.Address6) { + err = errors.New("ip specified is already allocated: " + newData.Address6) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + } + if !servercfg.IsPro { newData.AdditionalRagIps = []string{} } @@ -689,9 +823,6 @@ func updateNode(w http.ResponseWriter, r *http.Request) { if servercfg.IsPro && newNode.AutoAssignGateway { mq.HostUpdate(&models.HostUpdate{Action: models.CheckAutoAssignGw, Host: *host, Node: *newNode}) } - if servercfg.IsDNSMode() { - logic.SetDNS() - } if !newNode.Connected { metrics, err := logic.GetMetrics(newNode.ID.String()) if err == nil { @@ -863,22 +994,32 @@ func bulkUpdateNodeStatus(w http.ResponseWriter, r *http.Request) { logic.ReturnAcceptedResponse(w, r, fmt.Sprintf("bulk %s of %d node(s) accepted", eventAction, len(req.IDs))) go func() { - updated := 0 + var nodeIDs []interface{} + // filter out invalid node IDs. for _, nodeID := range req.IDs { - node, err := logic.GetNodeByID(nodeID) - if err != nil { - slog.Error("bulk node status: node not found", "id", nodeID, "error", err) - continue + node := &schema.Node{ + ID: nodeID, } - if node.Connected == req.Connected || node.Network != network { - continue - } - newNode := node - newNode.Connected = req.Connected - if err := logic.UpdateNode(&node, &newNode); err != nil { - slog.Error("bulk node status: failed to update node", "id", nodeID, "error", err) - continue + exists, err := node.Exists(db.WithContext(context.TODO())) + if err == nil && exists { + nodeIDs = append(nodeIDs, nodeID) } + } + + nodeUpdate := &schema.Node{ + Connected: req.Connected, + } + err := nodeUpdate.UpdateConnectedStatus( + db.WithContext(context.TODO()), + dbtypes.WithFilter("id", nodeIDs...), + ) + if err != nil { + slog.Error("bulk node status: failed to update nodes connected status", "error", err) + return + } + + for i := range nodeIDs { + nodeID := nodeIDs[i].(string) if !req.Connected { metrics, err := logic.GetMetrics(nodeID) if err == nil { @@ -898,18 +1039,16 @@ func bulkUpdateNodeStatus(w http.ResponseWriter, r *http.Request) { }, TriggeredBy: user, Target: models.Subject{ - ID: node.ID.String(), - Name: node.ID.String(), + ID: nodeID, + Name: nodeID, Type: schema.NodeSub, }, NetworkID: schema.NetworkID(network), Origin: schema.Dashboard, }) - updated++ - } - if updated > 0 { - mq.PublishPeerUpdate(false) } - slog.Info("bulk node status completed", "action", eventAction, "updated", updated, "total", len(req.IDs)) + + mq.PublishPeerUpdate(false) + slog.Info("bulk node status completed", "action", eventAction, "total", len(req.IDs)) }() } diff --git a/controllers/node_test.go b/controllers/node_test.go index 2d15791e8..b251cee19 100644 --- a/controllers/node_test.go +++ b/controllers/node_test.go @@ -9,7 +9,6 @@ import ( "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" - "github.com/gravitl/netmaker/servercfg" "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -34,7 +33,7 @@ func TestGetNetworkNodes(t *testing.T) { createTestNode() node, err := logic.GetNetworkNodes("skynet") assert.Nil(t, err) - assert.NotEqual(t, []models.LegacyNode(nil), node) + assert.NotEqual(t, []models.Node(nil), node) }) } @@ -50,9 +49,6 @@ func TestValidateEgressGateway(t *testing.T) { } func deleteAllNodes() { - if servercfg.CacheEnabled() { - logic.ClearNodeCache() - } database.DeleteAllRecords(database.NODES_TABLE_NAME) } diff --git a/controllers/user.go b/controllers/user.go index aa4c2a046..2e1bef2c8 100644 --- a/controllers/user.go +++ b/controllers/user.go @@ -1742,9 +1742,6 @@ func deleteUser(w http.ResponseWriter, r *http.Request) { } _ = logic.DeleteUserInvite(user.Username) mq.PublishPeerUpdate(false) - if servercfg.IsDNSMode() { - logic.SetDNS() - } }() logger.Log(1, username, "was deleted") json.NewEncoder(w).Encode(params["username"] + " deleted.") @@ -1869,9 +1866,6 @@ func bulkDeleteUsers(w http.ResponseWriter, r *http.Request) { } if deleted > 0 { mq.PublishPeerUpdate(false) - if servercfg.IsDNSMode() { - logic.SetDNS() - } } slog.Info("bulk user delete completed", "deleted", deleted, "total", len(req.IDs)) }() diff --git a/database/database.go b/database/database.go index adf73105b..e0e8fbb93 100644 --- a/database/database.go +++ b/database/database.go @@ -98,7 +98,6 @@ const ( ) var Tables = []string{ - NODES_TABLE_NAME, CERTS_TABLE_NAME, DELETED_NODES_TABLE_NAME, DNS_TABLE_NAME, @@ -127,6 +126,7 @@ var Tables = []string{ HOSTS_TABLE_NAME, PENDING_USERS_TABLE_NAME, USER_INVITES_TABLE_NAME, + NODES_TABLE_NAME, } func getCurrentDB() map[string]interface{} { diff --git a/db/types/options.go b/db/types/options.go index 428186f57..0f3000949 100644 --- a/db/types/options.go +++ b/db/types/options.go @@ -5,10 +5,40 @@ import ( "strings" "gorm.io/gorm" + "gorm.io/gorm/clause" ) type Option func(db *gorm.DB) *gorm.DB +func WithPreloads(associations ...string) Option { + return func(db *gorm.DB) *gorm.DB { + for _, association := range associations { + db = db.Preload(association) + } + return db + } +} + +func WithAllPreloads() Option { + return func(db *gorm.DB) *gorm.DB { + return db.Preload(clause.Associations) + } +} + +// WithJoin joins the given association and optionally scopes the join with conditions. +// Conditions are applied only to the join clause, not the outer query. +func WithJoin(association string, conditions ...Option) Option { + return func(db *gorm.DB) *gorm.DB { + // NewDB: true creates a fresh *gorm.DB with no inherited clauses (no WHERE, ORDER BY, etc.) + // This ensures conditions passed here are scoped only to the JOIN, not the outer query. + condDB := db.Session(&gorm.Session{NewDB: true}) + for _, condition := range conditions { + condDB = condition(condDB) + } + return db.Joins(association, condDB) + } +} + func WithPagination(page, pageSize int) Option { return func(db *gorm.DB) *gorm.DB { if page < 1 { @@ -34,6 +64,9 @@ func WithFilter(field string, value ...interface{}) Option { } if len(value) == 1 { + if value[0] == nil { + return db.Where(fmt.Sprintf("%s IS NULL", db.Statement.Quote(field))) + } return db.Where(fmt.Sprintf("%s = ?", db.Statement.Quote(field)), value[0]) } @@ -41,6 +74,24 @@ func WithFilter(field string, value ...interface{}) Option { } } +// WithNotFilter applies a WHERE NOT clause for the given column. +// IMPORTANT: `field` MUST be a trusted, hardcoded column name. +// NEVER pass user-supplied strings as `field`. +func WithNotFilter(field string, value ...interface{}) Option { + return func(db *gorm.DB) *gorm.DB { + if len(value) == 0 { + return db + } + if len(value) == 1 { + if value[0] == nil { + return db.Where(fmt.Sprintf("%s IS NOT NULL", db.Statement.Quote(field))) + } + return db.Where(fmt.Sprintf("%s != ?", db.Statement.Quote(field)), value[0]) + } + return db.Where(fmt.Sprintf("%s NOT IN ?", field), value) + } +} + // WithSearchQuery applies a WHERE clause searching `q` across multiple text fields using OR. // Uses LOWER() for case-insensitive matching across SQLite and PostgreSQL. // IMPORTANT: `fields` MUST be trusted, hardcoded column names. diff --git a/functions/local.go b/functions/local.go deleted file mode 100644 index a65bcd345..000000000 --- a/functions/local.go +++ /dev/null @@ -1,52 +0,0 @@ -package functions - -import ( - "os" - - "github.com/gravitl/netmaker/logger" - "github.com/gravitl/netmaker/logic" -) - -// LINUX_APP_DATA_PATH - linux path -const LINUX_APP_DATA_PATH = "/etc/netmaker" - -// FileExists - checks if file exists -func FileExists(f string) bool { - info, err := os.Stat(f) - if os.IsNotExist(err) { - return false - } - return !info.IsDir() -} - -// SetDNSDir - sets the dns directory of the system -func SetDNSDir() error { - dir, err := os.Getwd() - if err != nil { - return err - } - - err = os.MkdirAll(dir+"/config/dnsconfig", 0744) - if err != nil { - logger.Log(0, "couldnt find or create /config/dnsconfig") - return err - } - - err = logic.SetCorefile(".") - if err != nil { - logger.Log(0, err.Error()) - } - _, err = os.Stat(dir + "/config/dnsconfig/netmaker.hosts") - if os.IsNotExist(err) { - _, err = os.Create(dir + "/config/dnsconfig/netmaker.hosts") - if err != nil { - logger.Log(0, err.Error()) - } - } - return nil -} - -// GetNetmakerPath - gets netmaker path locally -func GetNetmakerPath() string { - return LINUX_APP_DATA_PATH -} diff --git a/logic/dns.go b/logic/dns.go index 3d193798c..707c53aa7 100644 --- a/logic/dns.go +++ b/logic/dns.go @@ -108,11 +108,6 @@ func DeleteNetworkNameservers(networkID string) error { }).Delete(db.WithContext(context.TODO())) } -// SetDNS - sets the dns on file -func SetDNS() error { - return nil -} - // GetDNS - gets the DNS of a current network func GetDNS(network string) ([]models.DNSEntry, error) { diff --git a/logic/extpeers.go b/logic/extpeers.go index 54649a055..fc6e37b98 100644 --- a/logic/extpeers.go +++ b/logic/extpeers.go @@ -162,13 +162,6 @@ func DeleteExtClient(network string, clientid string, isUpdate bool) error { return err } if servercfg.CacheEnabled() { - // recycle ip address - if extClient.Address != "" { - RemoveIpFromAllocatedIpMap(network, extClient.Address) - } - if extClient.Address6 != "" { - RemoveIpFromAllocatedIpMap(network, extClient.Address6) - } deleteExtClientFromCache(key) } if !isUpdate && extClient.RemoteAccessClientID != "" { @@ -276,94 +269,6 @@ func GetExtClient(clientid string, network string) (models.ExtClient, error) { return extclient, err } -// GetGwExtclients - return all ext clients attached to the passed gw id -func GetGwExtclients(nodeID, network string) []models.ExtClient { - gwClients := []models.ExtClient{} - clients, err := GetNetworkExtClients(network) - if err != nil { - return gwClients - } - for _, client := range clients { - if client.IngressGatewayID == nodeID { - gwClients = append(gwClients, client) - } - } - return gwClients -} - -// GetExtClient - gets a single ext client on a network -func GetExtClientByPubKey(publicKey string, network string) (*models.ExtClient, error) { - netClients, err := GetNetworkExtClients(network) - if err != nil { - return nil, err - } - for i := range netClients { - ec := netClients[i] - if ec.PublicKey == publicKey { - return &ec, nil - } - } - - return nil, fmt.Errorf("no client found") -} - -// CreateExtClient - creates and saves an extclient -func CreateExtClient(extclient *models.ExtClient) error { - // lock because we may need unique IPs and having it concurrent makes parallel calls result in same "unique" IPs - addressLock.Lock() - defer addressLock.Unlock() - - if len(extclient.PublicKey) == 0 { - privateKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - return err - } - extclient.PrivateKey = privateKey.String() - extclient.PublicKey = privateKey.PublicKey().String() - } else if len(extclient.PrivateKey) == 0 && len(extclient.PublicKey) > 0 { - extclient.PrivateKey = "[ENTER PRIVATE KEY]" - } - if extclient.ExtraAllowedIPs == nil { - extclient.ExtraAllowedIPs = []string{} - } - - parentNetwork := &schema.Network{Name: extclient.Network} - err := parentNetwork.Get(db.WithContext(context.TODO())) - if err != nil { - return err - } - if extclient.Address == "" { - if parentNetwork.AddressRange != "" { - newAddress, err := UniqueAddress(extclient.Network, true) - if err != nil { - return err - } - extclient.Address = newAddress.String() - } - } - - if extclient.Address6 == "" { - if parentNetwork.AddressRange6 != "" { - addr6, err := UniqueAddress6(extclient.Network, true) - if err != nil { - return err - } - extclient.Address6 = addr6.String() - } - } - - if extclient.ClientID == "" { - extclient.ClientID, err = GenerateNodeName(extclient.Network) - if err != nil { - return err - } - } - - extclient.LastModified = time.Now().Unix() - return SaveExtClient(extclient) -} - -// GenerateNodeName - generates a random node name func GenerateNodeName(network string) (string, error) { seed := time.Now().UTC().UnixNano() nameGenerator := namegenerator.NewNameGenerator(seed) @@ -403,12 +308,6 @@ func SaveExtClient(extclient *models.ExtClient) error { } if servercfg.CacheEnabled() { storeExtClientInCache(key, *extclient) - if extclient.Address != "" { - AddIpToAllocatedIpMap(extclient.Network, net.ParseIP(extclient.Address)) - } - if extclient.Address6 != "" { - AddIpToAllocatedIpMap(extclient.Network, net.ParseIP(extclient.Address6)) - } } return SetNetworkNodesLastModified(extclient.Network) @@ -501,7 +400,7 @@ func GetAllExtClients() ([]models.ExtClient, error) { // GetAllExtClientsWithStatus - returns all external clients with // given status. -func GetAllExtClientsWithStatus(status models.NodeStatus) ([]models.ExtClient, error) { +func GetAllExtClientsWithStatus(status schema.NodeStatus) ([]models.ExtClient, error) { extClients, err := GetAllExtClients() if err != nil { return nil, err diff --git a/logic/gateway.go b/logic/gateway.go index ed5e3b97f..7b5b18446 100644 --- a/logic/gateway.go +++ b/logic/gateway.go @@ -10,7 +10,6 @@ import ( "context" - "github.com/google/uuid" "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" @@ -30,51 +29,6 @@ func IsInternetGw(node models.Node) bool { return node.IsInternetGateway } -// GetInternetGateways - gets all the nodes that are internet gateways -func GetInternetGateways() ([]models.Node, error) { - nodes, err := GetAllNodes() - if err != nil { - return nil, err - } - igs := make([]models.Node, 0) - for _, node := range nodes { - if node.IsInternetGateway { - igs = append(igs, node) - } - } - return igs, nil -} - -// GetAllIngresses - gets all the nodes that are ingresses -func GetAllIngresses() ([]models.Node, error) { - nodes, err := GetAllNodes() - if err != nil { - return nil, err - } - ingresses := make([]models.Node, 0) - for _, node := range nodes { - if node.IsIngressGateway { - ingresses = append(ingresses, node) - } - } - return ingresses, nil -} - -// GetAllEgresses - gets all the nodes that are egresses -func GetAllEgresses() ([]models.Node, error) { - nodes, err := GetAllNodes() - if err != nil { - return nil, err - } - egresses := make([]models.Node, 0) - for _, node := range nodes { - if node.EgressDetails.IsEgressGateway { - egresses = append(egresses, node) - } - } - return egresses, nil -} - // CreateEgressGateway - creates an egress gateway func CreateEgressGateway(gateway models.EgressGatewayRequest) (models.Node, error) { node, err := GetNodeByID(gateway.NodeID) @@ -91,7 +45,7 @@ func CreateEgressGateway(gateway models.EgressGatewayRequest) (models.Node, erro if host.OS != "linux" { // support for other OS to be added return models.Node{}, errors.New(host.OS + " is unsupported for egress gateways") } - if host.FirewallInUse == models.FIREWALL_NONE { + if host.FirewallInUse == schema.FIREWALL_NONE { return models.Node{}, errors.New("please install iptables or nftables on the device") } if len(gateway.RangesWithMetric) == 0 && len(gateway.Ranges) > 0 { @@ -226,12 +180,7 @@ func CreateIngressGateway(netid string, nodeid string, ingress models.IngressReq node.IngressMTU = ingress.MTU } if servercfg.IsPro { - if _, exists := FailOverExists(node.Network); exists { - ResetFailedOverPeer(&node) - } - ResetAutoRelayedPeer(&node) - } node.SetLastModified() node.Metadata = ingress.Metadata @@ -349,7 +298,7 @@ func ValidateInetGwReq(inetNode models.Node, req models.InetNodeReq, update bool if err != nil { return err } - if inetHost.FirewallInUse == models.FIREWALL_NONE { + if inetHost.FirewallInUse == schema.FIREWALL_NONE { return errors.New("iptables or nftables needs to be installed") } if inetNode.InternetGwID != "" { @@ -364,7 +313,7 @@ func ValidateInetGwReq(inetNode models.Node, req models.InetNodeReq, update bool if err != nil { return err } - if clientNode.IsFailOver || clientNode.IsAutoRelay { + if clientNode.IsAutoRelay { return errors.New("failover node cannot be set to use internet gateway") } clientHost := &schema.Host{ @@ -389,9 +338,6 @@ func ValidateInetGwReq(inetNode models.Node, req models.InetNodeReq, update bool return fmt.Errorf("node %s is already using a internet gateway", clientHost.Name) } } - if clientNode.FailedOverBy != uuid.Nil { - ResetFailedOverPeer(&clientNode) - } if len(clientNode.AutoRelayedPeers) > 0 { ResetAutoRelayedPeer(&clientNode) } diff --git a/logic/hosts.go b/logic/hosts.go index a8df16777..4f8a424e1 100644 --- a/logic/hosts.go +++ b/logic/hosts.go @@ -49,7 +49,7 @@ const ( // GetAllHostsWithStatus - returns all hosts with at least one // node with given status. -func GetAllHostsWithStatus(status models.NodeStatus) ([]schema.Host, error) { +func GetAllHostsWithStatus(status schema.NodeStatus) ([]schema.Host, error) { hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO())) if err != nil { return nil, err @@ -84,16 +84,17 @@ func GetAllHostsAPI(hosts []schema.Host) []models.ApiHost { return apiHosts[:] } -func DoesHostExistinTheNetworkAlready(h *schema.Host, network schema.NetworkID) bool { - if len(h.Nodes) > 0 { - for _, nodeID := range h.Nodes { - node, err := GetNodeByID(nodeID) - if err == nil && node.Network == network.String() { - return true - } - } +func DoesHostExistInTheNetworkAlready(h *schema.Host, networkID schema.NetworkID) bool { + node := &schema.Node{ + HostID: h.ID.String(), + NetworkID: networkID.String(), } - return false + err := node.GetByHostAndNetwork(db.WithContext(context.TODO())) + if err != nil { + return false + } + + return true } // CreateHost - creates a host if not exist @@ -216,9 +217,6 @@ func UpdateHostFromClient(newHost, currHost *schema.Host) (isEndpointChanged, se slog.Error("failed to get node:", "id", node.ID, "error", err) continue } - if node.FailedOverBy != uuid.Nil { - ResetFailedOverPeer(&node) - } if len(node.AutoRelayedPeers) > 0 { ResetAutoRelayedPeer(&node) } @@ -287,7 +285,6 @@ func UpdateHostNode(h *schema.Host, newNode *models.Node) (publishDeletedNodeUpd } } publishPeerUpdate = true - ResetFailedOverPeer(newNode) ResetAutoRelayedPeer(newNode) return @@ -331,17 +328,7 @@ func RemoveHost(h *schema.Host, forceDelete bool) error { } } - err := h.Delete(db.WithContext(context.TODO())) - if err != nil { - return err - } - go func() { - if servercfg.IsDNSMode() { - SetDNS() - } - }() - - return nil + return h.Delete(db.WithContext(context.TODO())) } // UpdateHostNetwork - adds/deletes host from a network @@ -352,46 +339,11 @@ func UpdateHostNetwork(h *schema.Host, network string, add bool) (*models.Node, continue } if node.Network == network { - if !add { - return &node, nil - } else { - return &node, errors.New("host already part of network " + network) - } - } - } - if !add { - return nil, errors.New("host not part of the network " + network) - } else { - newNode := models.Node{} - newNode.Server = servercfg.GetServer() - newNode.Network = network - newNode.HostID = h.ID - if err := AssociateNodeToHost(&newNode, h); err != nil { - return nil, err + return &node, nil } - return &newNode, nil } -} -// AssociateNodeToHost - associates and creates a node with a given host -// should be the only way nodes get created as of 0.18 -func AssociateNodeToHost(n *models.Node, h *schema.Host) error { - if len(h.ID.String()) == 0 || h.ID == uuid.Nil { - return ErrInvalidHostID - } - n.HostID = h.ID - err := createNode(n) - if err != nil { - return err - } - currentHost := &schema.Host{ID: h.ID} - if err := currentHost.Get(db.WithContext(context.TODO())); err != nil { - return fmt.Errorf("failed to fetch host before node association: %w", err) - } - h.Nodes = currentHost.Nodes - h.HostPass = currentHost.HostPass - h.Nodes = append(h.Nodes, n.ID.String()) - return UpsertHost(h) + return nil, errors.New("host not part of the network " + network) } // DissasociateNodeFromHost - deletes a node and removes from host nodes diff --git a/logic/legacy.go b/logic/legacy.go deleted file mode 100644 index 5f858a3d5..000000000 --- a/logic/legacy.go +++ /dev/null @@ -1,46 +0,0 @@ -package logic - -import ( - "encoding/json" - - "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/logger" - "github.com/gravitl/netmaker/models" -) - -// IsLegacyNode - checks if a node is legacy or not -func IsLegacyNode(nodeID string) bool { - record, err := database.FetchRecord(database.NODES_TABLE_NAME, nodeID) - if err != nil { - return false - } - var currentNode models.Node - var legacyNode models.LegacyNode - currentNodeErr := json.Unmarshal([]byte(record), ¤tNode) - legacyNodeErr := json.Unmarshal([]byte(record), &legacyNode) - return currentNodeErr != nil && legacyNodeErr == nil -} - -// CheckAndRemoveLegacyNode - checks for legacy node and removes -func CheckAndRemoveLegacyNode(nodeID string) bool { - if IsLegacyNode(nodeID) { - if err := database.DeleteRecord(database.NODES_TABLE_NAME, nodeID); err == nil { - return true - } - } - return false -} - -// RemoveAllLegacyNodes - fetches all legacy nodes from DB and removes -func RemoveAllLegacyNodes() error { - records, err := database.FetchRecords(database.NODES_TABLE_NAME) - if err != nil { - return err - } - for k := range records { - if CheckAndRemoveLegacyNode(k) { - logger.Log(0, "removed legacy node", k) - } - } - return nil -} diff --git a/logic/networks.go b/logic/networks.go index 5b2692fc3..229c48560 100644 --- a/logic/networks.go +++ b/logic/networks.go @@ -10,134 +10,16 @@ import ( "net" "sort" "strings" - "sync" "time" - "github.com/c-robinson/iplib" "github.com/google/uuid" "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/db" - "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" - "github.com/gravitl/netmaker/servercfg" - "golang.org/x/exp/slog" "gorm.io/gorm" ) -var ( - networkCacheMutex = &sync.RWMutex{} - allocatedIpMap = make(map[string]map[string]net.IP) -) - -// SetAllocatedIpMap - set allocated ip map for networks -func SetAllocatedIpMap() error { - if !servercfg.CacheEnabled() { - return nil - } - logger.Log(0, "start setting up allocated ip map") - if allocatedIpMap == nil { - allocatedIpMap = map[string]map[string]net.IP{} - } - - currentNetworks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO())) - if err != nil { - return err - } - - for _, v := range currentNetworks { - pMap := map[string]net.IP{} - netName := v.Name - - //nodes - nodes, err := GetNetworkNodes(netName) - if err != nil { - slog.Error("could not load node for network", netName, "error", err.Error()) - } else { - for _, n := range nodes { - - if n.Address.IP != nil { - pMap[n.Address.IP.String()] = n.Address.IP - } - if n.Address6.IP != nil { - pMap[n.Address6.IP.String()] = n.Address6.IP - } - } - - } - - //extClients - extClients, err := GetNetworkExtClients(netName) - if err != nil { - slog.Error("could not load extClient for network", netName, "error", err.Error()) - } else { - for _, extClient := range extClients { - if extClient.Address != "" { - pMap[extClient.Address] = net.ParseIP(extClient.Address) - } - if extClient.Address6 != "" { - pMap[extClient.Address6] = net.ParseIP(extClient.Address6) - } - } - } - - allocatedIpMap[netName] = pMap - } - logger.Log(0, "setting up allocated ip map done") - return nil -} - -// ClearAllocatedIpMap - set allocatedIpMap to nil -func ClearAllocatedIpMap() { - if !servercfg.CacheEnabled() { - return - } - allocatedIpMap = nil -} - -func AddIpToAllocatedIpMap(networkName string, ip net.IP) { - if !servercfg.CacheEnabled() { - return - } - networkCacheMutex.Lock() - if m, ok := allocatedIpMap[networkName]; ok { - m[ip.String()] = ip - } - networkCacheMutex.Unlock() -} - -func RemoveIpFromAllocatedIpMap(networkName string, ip string) { - if !servercfg.CacheEnabled() { - return - } - networkCacheMutex.Lock() - if m, ok := allocatedIpMap[networkName]; ok { - delete(m, ip) - } - networkCacheMutex.Unlock() -} - -// AddNetworkToAllocatedIpMap - add network to allocated ip map when network is added -func AddNetworkToAllocatedIpMap(networkName string) { - //add new network to allocated ip map - if !servercfg.CacheEnabled() { - return - } - networkCacheMutex.Lock() - allocatedIpMap[networkName] = make(map[string]net.IP) - networkCacheMutex.Unlock() -} - -// RemoveNetworkFromAllocatedIpMap - remove network from allocated ip map when network is deleted -func RemoveNetworkFromAllocatedIpMap(networkName string) { - if !servercfg.CacheEnabled() { - return - } - networkCacheMutex.Lock() - delete(allocatedIpMap, networkName) - networkCacheMutex.Unlock() -} - // DeleteNetwork - deletes a network func DeleteNetwork(network string, force bool, done chan struct{}) error { @@ -479,243 +361,6 @@ func intersect(n1, n2 *net.IPNet) bool { return n2.Contains(n1.IP) || n1.Contains(n2.IP) } -// UniqueAddress - get a unique ipv4 address -func UniqueAddressCache(networkName string, reverse bool) (net.IP, error) { - add := net.IP{} - network := &schema.Network{Name: networkName} - err := network.Get(db.WithContext(context.TODO())) - if err != nil { - logger.Log(0, "UniqueAddressServer encountered an error") - return add, err - } - - if network.AddressRange == "" { - return add, fmt.Errorf("IPv4 not active on network %s", networkName) - } - //ensure AddressRange is valid - if _, _, err := net.ParseCIDR(network.AddressRange); err != nil { - logger.Log(0, "UniqueAddress encountered an error") - return add, err - } - net4 := iplib.Net4FromStr(network.AddressRange) - newAddrs := net4.FirstAddress() - - if reverse { - newAddrs = net4.LastAddress() - } - - networkCacheMutex.RLock() - ipAllocated := allocatedIpMap[networkName] - for { - if _, ok := ipAllocated[newAddrs.String()]; !ok { - networkCacheMutex.RUnlock() - return newAddrs, nil - } - if reverse { - newAddrs, err = net4.PreviousIP(newAddrs) - } else { - newAddrs, err = net4.NextIP(newAddrs) - } - if err != nil { - break - } - } - networkCacheMutex.RUnlock() - - return add, errors.New("ERROR: No unique addresses available. Check network subnet") -} - -// UniqueAddress - get a unique ipv4 address -func UniqueAddressDB(networkName string, reverse bool) (net.IP, error) { - add := net.IP{} - network := &schema.Network{Name: networkName} - err := network.Get(db.WithContext(context.TODO())) - if err != nil { - logger.Log(0, "UniqueAddressServer encountered an error") - return add, err - } - - if network.AddressRange == "" { - return add, fmt.Errorf("IPv4 not active on network %s", networkName) - } - //ensure AddressRange is valid - if _, _, err := net.ParseCIDR(network.AddressRange); err != nil { - logger.Log(0, "UniqueAddress encountered an error") - return add, err - } - net4 := iplib.Net4FromStr(network.AddressRange) - newAddrs := net4.FirstAddress() - - if reverse { - newAddrs = net4.LastAddress() - } - - for { - if IsIPUnique(networkName, newAddrs.String(), database.NODES_TABLE_NAME, false) && - IsIPUnique(networkName, newAddrs.String(), database.EXT_CLIENT_TABLE_NAME, false) { - return newAddrs, nil - } - if reverse { - newAddrs, err = net4.PreviousIP(newAddrs) - } else { - newAddrs, err = net4.NextIP(newAddrs) - } - if err != nil { - break - } - } - - return add, errors.New("ERROR: No unique addresses available. Check network subnet") -} - -// IsIPUnique - checks if an IP is unique -func IsIPUnique(network string, ip string, tableName string, isIpv6 bool) bool { - - isunique := true - if tableName == database.NODES_TABLE_NAME { - nodes, err := GetNetworkNodes(network) - if err != nil { - return isunique - } - for _, node := range nodes { - if isIpv6 { - if node.Address6.IP.String() == ip && node.Network == network { - return false - } - } else { - if node.Address.IP.String() == ip && node.Network == network { - return false - } - } - } - - } else if tableName == database.EXT_CLIENT_TABLE_NAME { - - extClients, err := GetNetworkExtClients(network) - if err != nil { - return isunique - } - for _, extClient := range extClients { // filter - if isIpv6 { - if (extClient.Address6 == ip) && extClient.Network == network { - return false - } - - } else { - if (extClient.Address == ip) && extClient.Network == network { - return false - } - } - } - } - - return isunique -} -func UniqueAddress(networkName string, reverse bool) (net.IP, error) { - if servercfg.CacheEnabled() { - return UniqueAddressCache(networkName, reverse) - } - return UniqueAddressDB(networkName, reverse) -} - -func UniqueAddress6(networkName string, reverse bool) (net.IP, error) { - if servercfg.CacheEnabled() { - return UniqueAddress6Cache(networkName, reverse) - } - return UniqueAddress6DB(networkName, reverse) -} - -// UniqueAddress6DB - see if ipv6 address is unique -func UniqueAddress6DB(networkName string, reverse bool) (net.IP, error) { - add := net.IP{} - network := &schema.Network{Name: networkName} - err := network.Get(db.WithContext(context.TODO())) - if err != nil { - return add, err - } - if network.AddressRange6 == "" { - return add, fmt.Errorf("IPv6 not active on network %s", networkName) - } - - //ensure AddressRange is valid - if _, _, err := net.ParseCIDR(network.AddressRange6); err != nil { - return add, err - } - net6 := iplib.Net6FromStr(network.AddressRange6) - - newAddrs, err := net6.NextIP(net6.FirstAddress()) - if reverse { - newAddrs, err = net6.PreviousIP(net6.LastAddress()) - } - if err != nil { - return add, err - } - - for { - if IsIPUnique(networkName, newAddrs.String(), database.NODES_TABLE_NAME, true) && - IsIPUnique(networkName, newAddrs.String(), database.EXT_CLIENT_TABLE_NAME, true) { - return newAddrs, nil - } - if reverse { - newAddrs, err = net6.PreviousIP(newAddrs) - } else { - newAddrs, err = net6.NextIP(newAddrs) - } - if err != nil { - break - } - } - - return add, errors.New("ERROR: No unique IPv6 addresses available. Check network subnet") -} - -// UniqueAddress6Cache - see if ipv6 address is unique using cache -func UniqueAddress6Cache(networkName string, reverse bool) (net.IP, error) { - add := net.IP{} - network := &schema.Network{Name: networkName} - err := network.Get(db.WithContext(context.TODO())) - if err != nil { - return add, err - } - if network.AddressRange6 == "" { - return add, fmt.Errorf("IPv6 not active on network %s", networkName) - } - - //ensure AddressRange is valid - if _, _, err := net.ParseCIDR(network.AddressRange6); err != nil { - return add, err - } - net6 := iplib.Net6FromStr(network.AddressRange6) - - newAddrs, err := net6.NextIP(net6.FirstAddress()) - if reverse { - newAddrs, err = net6.PreviousIP(net6.LastAddress()) - } - if err != nil { - return add, err - } - - networkCacheMutex.RLock() - ipAllocated := allocatedIpMap[networkName] - for { - if _, ok := ipAllocated[newAddrs.String()]; !ok { - networkCacheMutex.RUnlock() - return newAddrs, nil - } - if reverse { - newAddrs, err = net6.PreviousIP(newAddrs) - } else { - newAddrs, err = net6.NextIP(newAddrs) - } - if err != nil { - break - } - } - networkCacheMutex.RUnlock() - - return add, errors.New("ERROR: No unique IPv6 addresses available. Check network subnet") -} - // IsNetworkNameUnique - checks to see if any other networks have the same name (id) func IsNetworkNameUnique(network *schema.Network) (bool, error) { _network := &schema.Network{ @@ -925,7 +570,7 @@ var NetworkHook models.HookFunc = func(params ...interface{}) error { continue } node.PendingDelete = true - node.Action = models.NODE_DELETE + node.Action = schema.NODE_DELETE DeleteNodesCh <- &node host := &schema.Host{ID: node.HostID} if err := host.Get(db.WithContext(context.TODO())); err == nil && len(host.Nodes) == 0 { @@ -944,7 +589,3 @@ func InitNetworkHooks() { Interval: time.Duration(GetServerSettings().CleanUpInterval) * time.Minute, } } - -// == Private == - -var addressLock = &sync.Mutex{} diff --git a/logic/nodes.go b/logic/nodes.go index 09785a5d4..a2188d9a3 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -5,152 +5,31 @@ import ( "encoding/json" "errors" "fmt" - "maps" "net" - "slices" "sort" "sync" "time" - validator "github.com/go-playground/validator/v10" "github.com/google/uuid" "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/db" + dbtypes "github.com/gravitl/netmaker/db/types" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/servercfg" - "github.com/gravitl/netmaker/validation" "github.com/seancfoley/ipaddress-go/ipaddr" "golang.org/x/exp/slog" + "gorm.io/datatypes" + "gorm.io/gorm" ) var ( - nodeCacheMutex = &sync.RWMutex{} - nodeNetworkCacheMutex = &sync.RWMutex{} - nodesCacheMap = make(map[string]models.Node) - nodesNetworkCacheMap = make(map[string]map[string]models.Node) - DeleteNodesCh = make(chan *models.Node, 100) -) - -func getNodeFromCache(nodeID string) (node models.Node, ok bool) { - nodeCacheMutex.RLock() - node, ok = nodesCacheMap[nodeID] - if !ok { - nodeCacheMutex.RUnlock() - return - } - if node.Mutex != nil { - nodeCacheMutex.RUnlock() - return node, true - } - nodeCacheMutex.RUnlock() - - nodeCacheMutex.Lock() - defer nodeCacheMutex.Unlock() - node, ok = nodesCacheMap[nodeID] - if !ok { - return node, false - } - if node.Mutex == nil { - node.Mutex = &sync.Mutex{} - nodesCacheMap[nodeID] = node - } - return node, true -} -func getNodesFromCache() (nodes []models.Node) { - nodeCacheMutex.RLock() - nodes = make([]models.Node, 0, len(nodesCacheMap)) - for _, node := range nodesCacheMap { - nodes = append(nodes, node) - } - nodeCacheMutex.RUnlock() - return -} - -func deleteNodeFromCache(nodeID string) { - nodeCacheMutex.Lock() - delete(nodesCacheMap, nodeID) - nodeCacheMutex.Unlock() -} -func deleteNodeFromNetworkCache(nodeID string, network string) { - nodeNetworkCacheMutex.Lock() - delete(nodesNetworkCacheMap[network], nodeID) - nodeNetworkCacheMutex.Unlock() -} - -func storeNodeInNetworkCache(node models.Node, network string) { - if node.Mutex == nil { - node.Mutex = &sync.Mutex{} - } - nodeNetworkCacheMutex.Lock() - if nodesNetworkCacheMap[network] == nil { - nodesNetworkCacheMap[network] = make(map[string]models.Node) - } - nodesNetworkCacheMap[network][node.ID.String()] = node - nodeNetworkCacheMutex.Unlock() -} - -func storeNodeInCache(node models.Node) { - if node.Mutex == nil { - node.Mutex = &sync.Mutex{} - } - nodeCacheMutex.Lock() - nodesCacheMap[node.ID.String()] = node - nodeCacheMutex.Unlock() -} -func loadNodesIntoNetworkCache(nMap map[string]models.Node) { - nodeNetworkCacheMutex.Lock() - for _, v := range nMap { - if v.Mutex == nil { - v.Mutex = &sync.Mutex{} - } - network := v.Network - if nodesNetworkCacheMap[network] == nil { - nodesNetworkCacheMap[network] = make(map[string]models.Node) - } - nodesNetworkCacheMap[network][v.ID.String()] = v - } - nodeNetworkCacheMutex.Unlock() -} - -func loadNodesIntoCache(nMap map[string]models.Node) { - for id, v := range nMap { - if v.Mutex == nil { - v.Mutex = &sync.Mutex{} - nMap[id] = v - } - } - nodeCacheMutex.Lock() - nodesCacheMap = nMap - nodeCacheMutex.Unlock() -} -func ClearNodeCache() { - nodeCacheMutex.Lock() - nodesCacheMap = make(map[string]models.Node) - nodesNetworkCacheMap = make(map[string]map[string]models.Node) - nodeCacheMutex.Unlock() -} - -const ( - // RELAY_NODE_ERR - error to return if relay node is unfound - RELAY_NODE_ERR = "could not find relay for node" - // NodePurgeTime time to wait for node to response to a NODE_DELETE actions - NodePurgeTime = time.Second * 10 - // NodePurgeCheckTime is how often to check nodes for Pending Delete - NodePurgeCheckTime = time.Second * 30 + DeleteNodesCh = make(chan *models.Node, 100) ) // GetNetworkNodes - gets the nodes of a network func GetNetworkNodes(network string) ([]models.Node, error) { - - nodeNetworkCacheMutex.RLock() - if networkNodes, ok := nodesNetworkCacheMap[network]; ok { - nodes := slices.Collect(maps.Values(networkNodes)) - nodeNetworkCacheMutex.RUnlock() - return nodes, nil - } - nodeNetworkCacheMutex.RUnlock() allnodes, err := GetAllNodes() if err != nil { return []models.Node{}, err @@ -173,14 +52,6 @@ func GetHostNodes(host *schema.Host) []models.Node { // GetNetworkNodesMemory - gets all nodes belonging to a network from list in memory func GetNetworkNodesMemory(allNodes []models.Node, network string) []models.Node { - - nodeNetworkCacheMutex.RLock() - if networkNodes, ok := nodesNetworkCacheMap[network]; ok { - nodes := slices.Collect(maps.Values(networkNodes)) - nodeNetworkCacheMutex.RUnlock() - return nodes - } - nodeNetworkCacheMutex.RUnlock() var nodes = make([]models.Node, 0, len(allNodes)) for i := range allNodes { node := allNodes[i] @@ -199,22 +70,19 @@ var ( // UpdateNodeCheckin - buffers the checkin timestamp in memory when caching is enabled. // The actual DB write is deferred to FlushNodeCheckins (every 30s). // When caching is disabled (HA mode), writes directly to the DB. -func UpdateNodeCheckin(node *models.Node) error { - node.SetLastCheckIn() - node.EgressDetails = models.EgressDetails{} +func UpdateNodeCheckin(nodeID string) error { if servercfg.CacheEnabled() { pendingCheckinsMu.Lock() - pendingCheckins[node.ID.String()] = node.LastCheckIn + pendingCheckins[nodeID] = time.Now().UTC() pendingCheckinsMu.Unlock() - storeNodeInCache(*node) - storeNodeInNetworkCache(*node, node.Network) return nil } - data, err := json.Marshal(node) - if err != nil { - return err + + node := &schema.Node{ + ID: nodeID, + LastCheckIn: time.Now().UTC(), } - return database.Insert(node.ID.String(), string(data), database.NODES_TABLE_NAME) + return node.UpdateLastCheckIn(db.WithContext(context.TODO())) } // FlushNodeCheckins - writes all buffered check-in updates to the DB in one batch. @@ -229,19 +97,13 @@ func FlushNodeCheckins() { } var failed int for id, checkin := range batch { - node, err := GetNodeByID(id) - if err != nil { - failed++ - continue + node := &schema.Node{ + ID: id, + LastCheckIn: checkin, } - node.LastCheckIn = checkin - data, err := json.Marshal(node) + err := node.UpdateLastCheckIn(db.WithContext(context.TODO())) if err != nil { failed++ - continue - } - if err := database.Insert(id, string(data), database.NODES_TABLE_NAME); err != nil { - failed++ } } if failed > 0 { @@ -251,20 +113,39 @@ func FlushNodeCheckins() { // UpsertNode - updates node in the DB func UpsertNode(newNode *models.Node) error { - newNode.SetLastModified() - data, err := json.Marshal(newNode) - if err != nil { - return err + _node := ConvertModelsNodeToSchemaNode(newNode) + if _node.ID == "" { + return errors.New("error converting models.Node to schema.Node") } - newNode.EgressDetails = models.EgressDetails{} - err = database.Insert(newNode.ID.String(), string(data), database.NODES_TABLE_NAME) + + err := _node.Upsert(db.WithContext(context.TODO())) if err != nil { return err } - if servercfg.CacheEnabled() { - storeNodeInCache(*newNode) - storeNodeInNetworkCache(*newNode, newNode.Network) + + if _node.PostureCheckLastEvaluationCycleID != "" { + evaluatedAt, err := time.Parse(time.RFC3339, _node.PostureCheckLastEvaluationCycleID) + if err != nil { + return err + } + + violations := make([]schema.PostureCheckViolation, 0, len(newNode.PostureChecksViolations)) + for _, violation := range newNode.PostureChecksViolations { + violations = append(violations, schema.PostureCheckViolation{ + EvaluationCycleID: _node.PostureCheckLastEvaluationCycleID, + CheckID: violation.CheckID, + NodeID: _node.ID, + Name: violation.Name, + Attribute: violation.Attribute, + Message: violation.Message, + Severity: violation.Severity, + EvaluatedAt: evaluatedAt, + }) + } + + return _node.UpsertViolations(db.WithContext(context.TODO()), violations) } + return nil } @@ -280,38 +161,17 @@ func UpdateNode(currentNode *models.Node, newNode *models.Node) error { } newNode.Fill(currentNode, servercfg.IsPro) - // check for un-settable server values - if err := ValidateNode(newNode, true); err != nil { - return err - } - if newNode.ID == currentNode.ID { - newNode.EgressDetails = models.EgressDetails{} - newNode.SetLastModified() if !currentNode.Connected && newNode.Connected { newNode.SetLastCheckIn() } - if data, err := json.Marshal(newNode); err != nil { - return err - } else { - err = database.Insert(newNode.ID.String(), string(data), database.NODES_TABLE_NAME) - if err != nil { - return err - } - if servercfg.CacheEnabled() { - storeNodeInCache(*newNode) - storeNodeInNetworkCache(*newNode, newNode.Network) - if newNode.Address.IP != nil && !newNode.Address.IP.Equal(currentNode.Address.IP) { - AddIpToAllocatedIpMap(newNode.Network, newNode.Address.IP) - RemoveIpFromAllocatedIpMap(currentNode.Network, currentNode.Address.IP.String()) - } - if newNode.Address6.IP != nil && !newNode.Address6.IP.Equal(currentNode.Address6.IP) { - AddIpToAllocatedIpMap(newNode.Network, newNode.Address6.IP) - RemoveIpFromAllocatedIpMap(currentNode.Network, currentNode.Address6.IP.String()) - } - } - return nil + + _node := ConvertModelsNodeToSchemaNode(newNode) + if _node.ID == "" { + return errors.New("error converting models.Node to schema.Node") } + + return _node.Update(db.WithContext(context.TODO())) } return fmt.Errorf("failed to update node %s, cannot change ID", currentNode.ID.String()) @@ -341,9 +201,6 @@ func cleanupNodeReferences(node *models.Node) { UpsertNode(&relayNode) } } - if node.FailedOverBy != uuid.Nil { - ResetFailedOverPeer(node) - } if len(node.AutoRelayedPeers) > 0 { ResetAutoRelayedPeer(node) } @@ -393,16 +250,21 @@ func cleanupNodeReferences(node *models.Node) { } func DeleteNode(node *models.Node, purge bool) error { - alreadyDeleted := node.PendingDelete || node.Action == models.NODE_DELETE - node.Action = models.NODE_DELETE + alreadyDeleted := node.PendingDelete || node.Action == schema.NODE_DELETE + node.Action = schema.NODE_DELETE if !purge && !alreadyDeleted { - newnode := *node - newnode.PendingDelete = true - if err := UpdateNode(node, &newnode); err != nil { + nodeID := node.ID + node := &schema.Node{ + ID: nodeID.String(), + Action: schema.NODE_DELETE, + PendingDelete: true, + } + err := node.MarkForDeletion(db.WithContext(context.TODO())) + if err != nil { return err } - newZombie <- node.ID + newZombie <- nodeID return nil } if alreadyDeleted { @@ -443,97 +305,32 @@ func GetNodeByHostRef(hostid, network string) (node models.Node, err error) { // DeleteNodeByID - deletes a node from database func DeleteNodeByID(node *models.Node) error { - var err error - var key = node.ID.String() - if err = database.DeleteRecord(database.NODES_TABLE_NAME, key); err != nil { - if !database.IsEmptyRecord(err) { - return err - } - } - if servercfg.CacheEnabled() { - deleteNodeFromCache(node.ID.String()) - deleteNodeFromNetworkCache(node.ID.String(), node.Network) + _node := &schema.Node{ + ID: node.ID.String(), } - if servercfg.IsDNSMode() { - SetDNS() + err := _node.Delete(db.WithContext(context.TODO())) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err } - // removeZombie <- node.ID + if err = DeleteMetrics(node.ID.String()); err != nil { logger.Log(1, "unable to remove metrics from DB for node", node.ID.String(), err.Error()) } - //recycle ip address - if servercfg.CacheEnabled() { - if node.Address.IP != nil { - RemoveIpFromAllocatedIpMap(node.Network, node.Address.IP.String()) - } - if node.Address6.IP != nil { - RemoveIpFromAllocatedIpMap(node.Network, node.Address6.IP.String()) - } - } - return nil } -// IsNodeIDUnique - checks if node id is unique -func IsNodeIDUnique(node *models.Node) (bool, error) { - _, err := database.FetchRecord(database.NODES_TABLE_NAME, node.ID.String()) - return database.IsEmptyRecord(err), err -} - -// ValidateNode - validates node values -func ValidateNode(node *models.Node, isUpdate bool) error { - v := validator.New() - _ = v.RegisterValidation("id_unique", func(fl validator.FieldLevel) bool { - if isUpdate { - return true - } - isFieldUnique, _ := IsNodeIDUnique(node) - return isFieldUnique - }) - _ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool { - err := (&schema.Network{Name: node.Network}).Get(db.WithContext(context.TODO())) - return err == nil - }) - _ = v.RegisterValidation("checkyesornoorunset", func(f1 validator.FieldLevel) bool { - return validation.CheckYesOrNoOrUnset(f1) - }) - err := v.Struct(node) - return err -} - // GetAllNodes - returns all nodes in the DB func GetAllNodes() ([]models.Node, error) { var nodes []models.Node - if servercfg.CacheEnabled() { - nodes = getNodesFromCache() - if len(nodes) != 0 { - return nodes, nil - } - } - nodesMap := make(map[string]models.Node) - if servercfg.CacheEnabled() { - defer loadNodesIntoCache(nodesMap) - defer loadNodesIntoNetworkCache(nodesMap) - } - collection, err := database.FetchRecords(database.NODES_TABLE_NAME) + _nodes, err := (&schema.Node{}).ListAll(db.WithContext(context.TODO()), dbtypes.WithAllPreloads()) if err != nil { - if database.IsEmptyRecord(err) { - return []models.Node{}, nil - } - return []models.Node{}, err + return nil, err } - for _, value := range collection { - var node models.Node - // ignore legacy nodes in database - if err := json.Unmarshal([]byte(value), &node); err != nil { - logger.Log(3, "legacy node detected: ", err.Error()) - continue - } - ensureNodeMutex(&node) - // add node to our array - nodes = append(nodes, node) - nodesMap[node.ID.String()] = node + for _, _node := range _nodes { + node := ConvertSchemaNodeToModelsNode(&_node) + ensureNodeMutex(node) + nodes = append(nodes, *node) } return nodes, nil @@ -622,26 +419,18 @@ func GetRecordKey(id string, network string) (string, error) { return id + "###" + network, nil } -func GetNodeByID(uuid string) (models.Node, error) { - if servercfg.CacheEnabled() { - if node, ok := getNodeFromCache(uuid); ok { - return node, nil - } +func GetNodeByID(nodeID string) (models.Node, error) { + _node := &schema.Node{ + ID: nodeID, } - var record, err = database.FetchRecord(database.NODES_TABLE_NAME, uuid) + err := _node.Get(db.WithContext(context.TODO()), dbtypes.WithAllPreloads()) if err != nil { return models.Node{}, err } - var node models.Node - if err = json.Unmarshal([]byte(record), &node); err != nil { - return models.Node{}, err - } - ensureNodeMutex(&node) - if servercfg.CacheEnabled() { - storeNodeInCache(node) - storeNodeInNetworkCache(node, node.Network) - } - return node, nil + + node := ConvertSchemaNodeToModelsNode(_node) + ensureNodeMutex(node) + return *node, nil } // GetDeletedNodeByID - get a deleted node @@ -663,16 +452,6 @@ func GetDeletedNodeByID(uuid string) (models.Node, error) { return node, nil } -// FindRelay - returns the node that is the relay for a relayed node -func FindRelay(node *models.Node) *models.Node { - relay, err := GetNodeByID(node.RelayedBy) - if err != nil { - logger.Log(0, "FindRelay: "+err.Error()) - return nil - } - return &relay -} - // GetAllNodesAPI - get all nodes for api usage func GetAllNodesAPI(nodes []models.Node) []models.ApiNode { apiNodes := []models.ApiNode{} @@ -735,113 +514,22 @@ func DeleteExpiredNodes(ctx context.Context) { ticker.Stop() return case <-ticker.C: - allnodes, err := GetAllNodes() + nodes, err := (&schema.Node{}).ListAll(db.WithContext(ctx)) if err != nil { slog.Error("failed to retrieve all nodes", "error", err.Error()) return } - for _, node := range allnodes { - node := node + for _, node := range nodes { + node := ConvertSchemaNodeToModelsNode(&node) if time.Now().After(node.ExpirationDateTime) { - DeleteNodesCh <- &node - slog.Info("deleting expired node", "nodeid", node.ID.String()) + DeleteNodesCh <- node + slog.Info("deleting expired node", "nodeid", node.ID) } } } } } -// createNode - creates a node in database -func createNode(node *models.Node) error { - // lock because we need unique IPs and having it concurrent makes parallel calls result in same "unique" IPs - addressLock.Lock() - defer addressLock.Unlock() - - host := &schema.Host{ - ID: node.HostID, - } - err := host.Get(db.WithContext(context.TODO())) - if err != nil { - return err - } - - SetNodeDefaults(node, true) - parentNetwork := &schema.Network{Name: node.Network} - err = parentNetwork.Get(db.WithContext(context.TODO())) - if err != nil { - return err - } - if node.Address.IP == nil { - if parentNetwork.AddressRange != "" { - if node.Address.IP, err = UniqueAddress(node.Network, false); err != nil { - return err - } - _, cidr, err := net.ParseCIDR(parentNetwork.AddressRange) - if err != nil { - return err - } - node.Address.Mask = net.CIDRMask(cidr.Mask.Size()) - } - } else if !IsIPUnique(node.Network, node.Address.String(), database.NODES_TABLE_NAME, false) { - return fmt.Errorf("invalid address: ipv4 %s is not unique", node.Address.String()) - } - if node.Address6.IP == nil { - if parentNetwork.AddressRange6 != "" { - if node.Address6.IP, err = UniqueAddress6(node.Network, false); err != nil { - return err - } - _, cidr, err := net.ParseCIDR(parentNetwork.AddressRange6) - if err != nil { - return err - } - node.Address6.Mask = net.CIDRMask(cidr.Mask.Size()) - } - } else if !IsIPUnique(node.Network, node.Address6.String(), database.NODES_TABLE_NAME, true) { - return fmt.Errorf("invalid address: ipv6 %s is not unique", node.Address6.String()) - } - node.ID = uuid.New() - //Create a JWT for the node - tokenString, _ := CreateJWT(node.ID.String(), host.MacAddress.String(), node.Network) - if tokenString == "" { - //logic.ReturnErrorResponse(w, r, errorResponse) - return err - } - err = ValidateNode(node, false) - if err != nil { - return err - } - CheckZombies(node) - node.SetLastCheckIn() - nodebytes, err := json.Marshal(&node) - if err != nil { - return err - } - err = database.Insert(node.ID.String(), string(nodebytes), database.NODES_TABLE_NAME) - if err != nil { - return err - } - if servercfg.CacheEnabled() { - storeNodeInCache(*node) - storeNodeInNetworkCache(*node, node.Network) - if node.Address.IP != nil { - AddIpToAllocatedIpMap(node.Network, node.Address.IP) - } - if node.Address6.IP != nil { - AddIpToAllocatedIpMap(node.Network, node.Address6.IP) - } - } - - if err = UpdateMetrics(node.ID.String(), &models.Metrics{Connectivity: make(map[string]models.Metric)}); err != nil { - logger.Log(1, "failed to initialize metrics for node", node.ID.String(), err.Error()) - } - - SetNetworkNodesLastModified(node.Network) - if servercfg.IsDNSMode() { - err = SetDNS() - } - return err -} - // SortApiNodes - Sorts slice of ApiNodes by their ID alphabetically with numbers first func SortApiNodes(unsortedNodes []models.ApiNode) { sort.Slice(unsortedNodes, func(i, j int) bool { @@ -862,24 +550,6 @@ func ValidateParams(nodeid, netid string) (models.Node, error) { return node, nil } -func ValidateNodeIp(currentNode *models.Node, newNode *models.ApiNode) error { - - if currentNode.Address.IP != nil && currentNode.Address.String() != newNode.Address { - if !IsIPUnique(newNode.Network, newNode.Address, database.NODES_TABLE_NAME, false) || - !IsIPUnique(newNode.Network, newNode.Address, database.EXT_CLIENT_TABLE_NAME, false) { - return errors.New("ip specified is already allocated: " + newNode.Address) - } - } - if currentNode.Address6.IP != nil && currentNode.Address6.String() != newNode.Address6 { - if !IsIPUnique(newNode.Network, newNode.Address6, database.NODES_TABLE_NAME, false) || - !IsIPUnique(newNode.Network, newNode.Address6, database.EXT_CLIENT_TABLE_NAME, false) { - return errors.New("ip specified is already allocated: " + newNode.Address6) - } - } - - return nil -} - func ValidateEgressRange(netID string, ranges []string) error { network := &schema.Network{Name: netID} err := network.Get(db.WithContext(context.TODO())) @@ -914,17 +584,216 @@ func ContainsCIDR(net1, net2 string) bool { return one.Contains(two) || two.Contains(one) } -// GetAllFailOvers - gets all the nodes that are failovers -func GetAllFailOvers() ([]models.Node, error) { - nodes, err := GetAllNodes() +func ConvertSchemaNodeToApiNode(_node *schema.Node) *models.ApiNode { + return ConvertSchemaNodeToModelsNode(_node).ConvertToAPINode() +} + +func ConvertSchemaNodeToModelsNode(_node *schema.Node) *models.Node { + nodeID, err := uuid.Parse(_node.ID) if err != nil { - return nil, err + return &models.Node{} } - igs := make([]models.Node, 0) - for _, node := range nodes { - if node.IsFailOver { - igs = append(igs, node) + + var nodeAddr, nodeAddr6 net.IPNet + if _node.Address != "" { + ip, cidr, err := net.ParseCIDR(_node.Address) + if err != nil { + return &models.Node{} + } + + cidr.IP = ip + nodeAddr = *cidr + } + + if _node.Address6 != "" { + ip6, cidr, err := net.ParseCIDR(_node.Address6) + if err != nil { + return &models.Node{} + } + + cidr.IP = ip6 + nodeAddr6 = *cidr + } + + hostID, err := uuid.Parse(_node.HostID) + if err != nil { + return &models.Node{} + } + + if _node.Host == nil { + _node.Host = &schema.Host{ + ID: hostID, + } + err = _node.Host.Get(db.WithContext(context.TODO())) + if err != nil { + return &models.Node{} + } + } + + var netAddrRange, netAddr6Range net.IPNet + if _node.Network == nil { + _node.Network = &schema.Network{ + ID: _node.Network.ID, + } + err = _node.Network.Get(db.WithContext(context.TODO())) + if err != nil { + return &models.Node{} + } + } + + var violations []models.Violation + _violations, err := _node.ListViolations(db.WithContext(context.TODO())) + if err == nil { + for _, _violation := range _violations { + violations = append(violations, models.Violation{ + CheckID: _violation.CheckID, + Name: _violation.Name, + Attribute: _violation.Attribute, + Message: _violation.Message, + Severity: _violation.Severity, + }) + } + } + + node := &models.Node{ + CommonNode: models.CommonNode{ + ID: nodeID, + HostID: hostID, + Network: _node.Network.Name, + NetworkRange: netAddrRange, + NetworkRange6: netAddr6Range, + Server: servercfg.GetServer(), + Connected: _node.Connected, + Address: nodeAddr, + Address6: nodeAddr6, + Action: _node.Action, + IsRelay: _node.IsGateway, + IsGw: _node.IsGateway, + AutoAssignGateway: _node.AutoAssignGateway, + }, + PendingDelete: _node.PendingDelete, + LastModified: _node.UpdatedAt, + LastCheckIn: _node.LastCheckIn, + ExpirationDateTime: _node.ExpirationDateTime, + Metadata: _node.Metadata, + IsAutoRelay: _node.IsAutoRelay, + AutoRelayedPeers: _node.AutoRelayedPeers.Data(), + IsInternetGateway: _node.IsInternetGateway, + Tags: make(map[models.TagID]struct{}), + Status: _node.Status, + PostureChecksViolations: violations, + PostureCheckVolationSeverityLevel: _node.PostureCheckSeverity, + LastEvaluatedAt: _node.UpdatedAt, + Location: _node.Host.Location, + CountryCode: _node.Host.CountryCode, + } + + if _node.IsGateway { + node.IngressGatewayRange = _node.Network.AddressRange + node.IngressGatewayRange6 = _node.Network.AddressRange6 + node.IngressPersistentKeepalive = int32(_node.Host.PersistentKeepalive.Seconds()) + node.IngressMTU = int32(_node.Host.MTU) + node.RelayedNodes = make([]string, 0, len(_node.RelayedClients)) + node.InetNodeReq = models.InetNodeReq{ + InetNodeClientIDs: make([]string, 0, len(_node.RelayedIGWClients)), + } + + for relayedClientID := range _node.RelayedClients { + node.RelayedNodes = append(node.RelayedNodes, relayedClientID) + } + + for relayedIGWClientID := range _node.RelayedIGWClients { + node.InetNodeReq.InetNodeClientIDs = append(node.InetNodeReq.InetNodeClientIDs, relayedIGWClientID) } } - return igs, nil + + if _node.RelayingNodeID != nil { + node.IsRelayed = true + node.RelayedBy = *_node.RelayingNodeID + + if _node.IsIGWClient { + node.InternetGwID = *_node.RelayingNodeID + } + } + + for tagID := range _node.Tags { + node.Tags[models.TagID(tagID)] = struct{}{} + } + + return node +} + +func ConvertModelsNodeToSchemaNode(node *models.Node) *schema.Node { + var address, address6 string + if node.Address.IP != nil { + address = node.Address.String() + } + + if node.Address6.IP != nil { + address6 = node.Address6.String() + } + + host := &schema.Host{ + ID: node.HostID, + } + err := host.Get(db.WithContext(context.TODO())) + if err != nil { + return &schema.Node{} + } + + network := &schema.Network{ + Name: node.Network, + } + err = network.Get(db.WithContext(context.TODO())) + if err != nil { + return &schema.Node{} + } + + relayedClients := make(datatypes.JSONMap) + for _, relayedNodeID := range node.RelayedNodes { + relayedClients[relayedNodeID] = struct{}{} + } + + relayedIGWClients := make(datatypes.JSONMap) + for _, inetNodeClientID := range node.InetNodeReq.InetNodeClientIDs { + relayedIGWClients[inetNodeClientID] = struct{}{} + } + + relayedBy := node.RelayedBy + + tags := make(datatypes.JSONMap) + for tagID := range node.Tags { + tags[tagID.String()] = struct{}{} + } + + return &schema.Node{ + ID: node.ID.String(), + HostID: host.ID.String(), + Host: host, + NetworkID: network.ID, + Network: network, + Address: address, + Address6: address6, + Connected: node.Connected, + Action: node.Action, + Status: node.Status, + PendingDelete: node.PendingDelete, + AutoAssignGateway: node.AutoAssignGateway, + IsGateway: node.IsGw || node.IsRelay || node.IsIngressGateway, + IsAutoRelay: node.IsAutoRelay, + IsInternetGateway: node.IsGw && node.IsInternetGateway, + RelayedClients: relayedClients, + RelayedIGWClients: relayedIGWClients, + RelayingNodeID: &relayedBy, + IsIGWClient: node.IsRelayed && node.InternetGwID != "", + AutoRelayedPeers: datatypes.NewJSONType(node.AutoRelayedPeers), + Tags: tags, + PostureCheckSeverity: node.PostureCheckVolationSeverityLevel, + PostureCheckLastEvaluationCycleID: node.LastEvaluatedAt.Format(time.RFC3339), + Metadata: node.Metadata, + LastCheckIn: node.LastCheckIn, + ExpirationDateTime: node.ExpirationDateTime, + CreatedAt: node.LastModified, + UpdatedAt: node.LastModified, + } } diff --git a/logic/peers.go b/logic/peers.go index ecd4184bf..5201b3daf 100644 --- a/logic/peers.go +++ b/logic/peers.go @@ -21,29 +21,6 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -var ( - // ResetFailOver - function to reset failOvered peers on this node - ResetFailOver = func(failOverNode *models.Node) error { - return nil - } - // ResetFailedOverPeer - removes failed over node from network peers - ResetFailedOverPeer = func(failedOverNode *models.Node) error { - return nil - } - // FailOverExists - check if failover node existed or not - FailOverExists = func(network string) (failOverNode models.Node, exists bool) { - return failOverNode, exists - } - // GetFailOverPeerIps - gets failover peerips - GetFailOverPeerIps = func(peer, node *models.Node) []net.IPNet { - return []net.IPNet{} - } - // CreateFailOver - creates failover in a network - CreateFailOver = func(node models.Node) error { - return nil - } -) - var ( // ResetAutoRelay - function to reset autorelayed peers on this node ResetAutoRelay = func(autoRelayNode *models.Node) error { @@ -176,7 +153,7 @@ func computeHostPeerInfo(host *schema.Host, allNodes []models.Node, serverInfo m continue } - if !node.Connected || node.PendingDelete || node.Action == models.NODE_DELETE { + if !node.Connected || node.PendingDelete || node.Action == schema.NODE_DELETE { continue } networkPeersInfo := make(models.PeerMap) @@ -204,7 +181,7 @@ func computeHostPeerInfo(host *schema.Host, allNodes []models.Node, serverInfo m } else { allowedToComm = IsPeerAllowed(node, peer, false) } - if peer.Action != models.NODE_DELETE && + if peer.Action != schema.NODE_DELETE && !peer.PendingDelete && peer.Connected && (allowedToComm) { @@ -306,7 +283,7 @@ func GetPeerUpdateForHost(network string, host *schema.Host, allNodes []models.N continue } - if !node.Connected || node.PendingDelete || node.Action == models.NODE_DELETE || + if !node.Connected || node.PendingDelete || node.Action == schema.NODE_DELETE || (!node.LastCheckIn.IsZero() && time.Since(node.LastCheckIn) > time.Hour) { if deletedNode == nil || deletedNode.ID != node.ID { continue @@ -400,7 +377,7 @@ func GetPeerUpdateForHost(network string, host *schema.Host, allNodes []models.N if node.Mutex != nil { node.Mutex.Lock() } - _, isFailOverPeer := node.FailOverPeers[peer.ID.String()] + peerAutoRelayID, isAutoRelayPeer := node.AutoRelayedPeers[peer.ID.String()] if node.Mutex != nil { node.Mutex.Unlock() @@ -408,19 +385,6 @@ func GetPeerUpdateForHost(network string, host *schema.Host, allNodes []models.N if peer.EgressDetails.IsEgressGateway { peerKey := peerHost.PublicKey.String() - if isFailOverPeer && peer.FailedOverBy.String() != node.ID.String() { - // get relay host - failOverNode, err := GetNodeByID(peer.FailedOverBy.String()) - if err == nil { - relayHost := &schema.Host{ - ID: failOverNode.HostID, - } - err := relayHost.Get(db.WithContext(context.TODO())) - if err == nil { - peerKey = relayHost.PublicKey.String() - } - } - } if isAutoRelayPeer && peerAutoRelayID != node.ID.String() { // get relay host autoRelayNode, err := GetNodeByID(peerAutoRelayID) @@ -470,7 +434,7 @@ func GetPeerUpdateForHost(network string, host *schema.Host, allNodes []models.N } if (node.IsRelayed && node.RelayedBy != peer.ID.String()) || - (peer.IsRelayed && peer.RelayedBy != node.ID.String()) || isFailOverPeer || isAutoRelayPeer { + (peer.IsRelayed && peer.RelayedBy != node.ID.String()) || isAutoRelayPeer { // if node is relayed and peer is not the relay, set remove to true if _, ok := peerIndexMap[peerHost.PublicKey.String()]; ok { continue @@ -522,9 +486,6 @@ func GetPeerUpdateForHost(network string, host *schema.Host, allNodes []models.N // don't set endpoint on relayed peer peerEndpoint = nil } - if isFailOverPeer && peer.FailedOverBy == node.ID && !peer.IsStatic { - peerEndpoint = nil - } if isAutoRelayPeer && peerAutoRelayID == node.ID.String() && !peer.IsStatic { peerEndpoint = nil } @@ -538,7 +499,7 @@ func GetPeerUpdateForHost(network string, host *schema.Host, allNodes []models.N peerConfig.Endpoint.Port = peerHost.ListenPort } - if peer.Action != models.NODE_DELETE && + if peer.Action != schema.NODE_DELETE && !peer.PendingDelete && peer.Connected && (allowedToComm) && @@ -926,9 +887,6 @@ func getNodeAllowedIPs(peer, node *models.Node) []net.IPNet { if peer.IsRelay { allowedips = append(allowedips, RelayedAllowedIPs(peer, node)...) } - if peer.IsFailOver { - allowedips = append(allowedips, GetFailOverPeerIps(peer, node)...) - } if peer.IsAutoRelay { allowedips = append(allowedips, GetAutoRelayPeerIps(peer, node)...) } diff --git a/logic/relay.go b/logic/relay.go index 0d29afe79..a7ccabba9 100644 --- a/logic/relay.go +++ b/logic/relay.go @@ -6,28 +6,12 @@ import ( "fmt" "net" - "github.com/google/uuid" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" ) -// GetRelays - gets all the nodes that are relays -func GetRelays() ([]models.Node, error) { - nodes, err := GetAllNodes() - if err != nil { - return nil, err - } - relays := make([]models.Node, 0) - for _, node := range nodes { - if node.IsRelay { - relays = append(relays, node) - } - } - return relays, nil -} - // CreateRelay - creates a relay func CreateRelay(relay models.RelayRequest) ([]models.Node, models.Node, error) { var returnnodes []models.Node @@ -94,22 +78,6 @@ func SetRelayedNodes(setRelayed bool, relay string, relayed []string) []models.N return returnnodes } -// func GetRelayedNodes(relayNode *models.Node) (models.Node, error) { -// var returnnodes []models.Node -// networkNodes, err := GetNetworkNodes(relayNode.Network) -// if err != nil { -// return returnnodes, err -// } -// for _, node := range networkNodes { -// for _, addr := range relayNode.RelayAddrs { -// if addr == node.Address.IP.String() || addr == node.Address6.IP.String() { -// returnnodes = append(returnnodes, node) -// } -// } -// } -// return returnnodes, nil -// } - // ValidateRelay - checks if relay is valid func ValidateRelay(relay models.RelayRequest, update bool) error { var err error @@ -121,14 +89,11 @@ func ValidateRelay(relay models.RelayRequest, update bool) error { if !update && node.IsRelay { return errors.New("node is already acting as a relay") } - eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO())) - acls, _ := ListAclsByNetwork(schema.NetworkID(node.Network)) for _, relayedNodeID := range relay.RelayedNodes { relayedNode, err := GetNodeByID(relayedNodeID) if err != nil { return err } - GetNodeEgressInfo(&relayedNode, eli, acls) if relayedNode.IsIngressGateway { return errors.New("cannot relay an ingress gateway (" + relayedNodeID + ")") } @@ -138,12 +103,9 @@ func ValidateRelay(relay models.RelayRequest, update bool) error { if relayedNode.InternetGwID != "" && relayedNode.InternetGwID != relay.NodeID { return errors.New("cannot relay an internet client (" + relayedNodeID + ")") } - if relayedNode.IsFailOver || relayedNode.IsAutoRelay { + if relayedNode.IsAutoRelay { return errors.New("cannot relay a auto relay node (" + relayedNodeID + ")") } - if relayedNode.FailedOverBy != uuid.Nil { - ResetFailedOverPeer(&relayedNode) - } if len(relayedNode.AutoRelayedPeers) > 0 { ResetAutoRelayedPeer(&relayedNode) } @@ -179,7 +141,6 @@ func UpdateRelayed(currentNode, newNode *models.Node) { if len(updatenodes) > 0 { for _, relayedNode := range updatenodes { node := relayedNode - ResetFailedOverPeer(&node) ResetAutoRelayedPeer(&node) } } diff --git a/logic/status.go b/logic/status.go index 76b170fc5..7fbd16d9b 100644 --- a/logic/status.go +++ b/logic/status.go @@ -4,6 +4,7 @@ import ( "time" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" ) var GetNodeStatus = getNodeCheckInStatus @@ -12,19 +13,29 @@ func getNodeCheckInStatus(node *models.Node, t bool) { // On CE check only last check-in time if node.IsStatic { if !node.StaticNode.Enabled { - node.Status = models.OfflineSt + node.Status = schema.OfflineSt return } - node.Status = models.OnlineSt + node.Status = schema.OnlineSt return } if !node.Connected { - node.Status = models.Disconnected + node.Status = schema.Disconnected return } if time.Since(node.LastCheckIn) > time.Minute*10 { - node.Status = models.OfflineSt + node.Status = schema.OfflineSt return } - node.Status = models.OnlineSt + node.Status = schema.OnlineSt +} + +func GetNodeCheckInStatus(node *schema.Node) schema.NodeStatus { + if !node.Connected { + return schema.Disconnected + } + if time.Since(node.LastCheckIn) > time.Minute*10 { + return schema.OfflineSt + } + return schema.OnlineSt } diff --git a/logic/telemetry.go b/logic/telemetry.go index 83c82f29e..0c4df43ae 100644 --- a/logic/telemetry.go +++ b/logic/telemetry.go @@ -8,6 +8,7 @@ import ( "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/db" + dbtypes "github.com/gravitl/netmaker/db/types" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/schema" @@ -100,7 +101,7 @@ func FetchTelemetryData() telemetryData { data.Hosts, _ = (&schema.Host{}).Count(db.WithContext(context.TODO())) data.Version = servercfg.GetVersion() data.Servers = getServerCount() - nodes, _ := GetAllNodes() + nodes, _ := (&schema.Node{}).ListAll(db.WithContext(context.TODO()), dbtypes.WithPreloads("Host")) data.Nodes = len(nodes) data.Count = getClientCount(nodes) endDate, _ := GetTrialEndDate() @@ -143,17 +144,10 @@ func setTelemetryTimestamp(telRecord *models.Telemetry) error { } // getClientCount - returns counts of nodes with various OS types and conditions -func getClientCount(nodes []models.Node) clientCount { +func getClientCount(nodes []schema.Node) clientCount { var count clientCount for _, node := range nodes { - host := &schema.Host{ - ID: node.HostID, - } - err := host.Get(db.WithContext(context.TODO())) - if err != nil { - continue - } - switch host.OS { + switch node.Host.OS { case "darwin": count.MacOS += 1 case "windows": diff --git a/logic/usage.go b/logic/usage.go index bcd1b5d4b..4f5099f4c 100644 --- a/logic/usage.go +++ b/logic/usage.go @@ -10,11 +10,11 @@ import ( func GetCurrentServerUsage() (limits models.Usage) { limits.SetDefaults() - hosts, hErr := GetAllHostsWithStatus(models.OnlineSt) + hosts, hErr := GetAllHostsWithStatus(schema.OnlineSt) if hErr == nil { limits.Hosts = len(hosts) } - clients, cErr := GetAllExtClientsWithStatus(models.OnlineSt) + clients, cErr := GetAllExtClientsWithStatus(schema.OnlineSt) if cErr == nil { limits.Clients = len(clients) } diff --git a/logic/zombie.go b/logic/zombie.go index 09389f9e1..fbe2503ea 100644 --- a/logic/zombie.go +++ b/logic/zombie.go @@ -7,7 +7,6 @@ import ( "github.com/google/uuid" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" - "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" ) @@ -28,18 +27,18 @@ var ( // CheckZombies - checks if new node has same hostid as existing node // if so, existing node is added to zombie node quarantine list // also cleans up nodes past their expiration date -func CheckZombies(newnode *models.Node) { - nodes, err := GetNetworkNodes(newnode.Network) +func CheckZombies(_node *schema.Node) { + nodes, err := GetNetworkNodes(_node.Network.Name) if err != nil { - logger.Log(1, "Failed to retrieve network nodes", newnode.Network, err.Error()) + logger.Log(1, "Failed to retrieve network nodes", _node.Network.Name, err.Error()) return } for _, node := range nodes { - if node.ID == newnode.ID { + if node.ID.String() == _node.ID { //skip self continue } - if node.HostID == newnode.HostID { + if node.HostID.String() == _node.HostID { logger.Log(0, "adding ", node.ID.String(), " to zombie list") newZombie <- node.ID } @@ -110,7 +109,7 @@ func ManageZombies(ctx context.Context) { continue } node.PendingDelete = true - node.Action = models.NODE_DELETE + node.Action = schema.NODE_DELETE DeleteNodesCh <- &node logger.Log(1, "deleting zombie node", node.ID.String()) zombies = append(zombies[:i], zombies[i+1:]...) @@ -144,7 +143,7 @@ func checkPendingRemovalNodes() { nodes, _ := GetAllNodes() for _, node := range nodes { node := node - pendingDelete := node.PendingDelete || node.Action == models.NODE_DELETE + pendingDelete := node.PendingDelete || node.Action == schema.NODE_DELETE if pendingDelete { DeleteNode(&node, true) DeleteNodesCh <- &node diff --git a/main.go b/main.go index 08fbbbc9e..6f8c5123d 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,8 @@ import ( ch "github.com/gravitl/netmaker/clickhouse" "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/orchestrator" + "github.com/gravitl/netmaker/orchestrator/extensions" "github.com/gravitl/netmaker/schema" "github.com/google/uuid" @@ -49,6 +51,10 @@ var version = "v1.5.1" // Start DB Connection and start API Request Handler func main() { + // Initializes repository with a CE extensions factory as the default. If built with 'ee' tag, the EE init() + // will have already registered the Pro factory and this call will be a no-op. + orchestrator.InitializeRepository(extensions.NewCEFactory()) + absoluteConfigPath := flag.String("c", "", "absolute path to configuration file") flag.Parse() setVerbosity() @@ -56,8 +62,6 @@ func main() { servercfg.SetVersion(version) fmt.Println(models.RetrieveLogo()) // print the logo initialize() // initial db and acls - logic.SetAllocatedIpMap() - defer logic.ClearAllocatedIpMap() setGarbageCollection() defer db.CloseDB() defer database.CloseDB() @@ -138,7 +142,6 @@ func initialize() { // Client Mode Prereq Check initializeUUID() //initialize cache - _, _ = logic.GetAllNodes() _, _ = logic.GetAllExtClients() _ = logic.ListAcls() _, _ = logic.GetAllEnrollmentKeys() @@ -208,7 +211,7 @@ func runMessageQueue(wg *sync.WaitGroup, ctx context.Context) { continue } node := nodeUpdate - node.Action = models.NODE_DELETE + node.Action = schema.NODE_DELETE node.PendingDelete = true if err := mq.NodeUpdate(node); err != nil { logger.Log( @@ -222,7 +225,7 @@ func runMessageQueue(wg *sync.WaitGroup, ctx context.Context) { slog.Error( "error deleting expired node", "nodeid", - node.ID.String(), + node.ID, "error", err.Error(), ) diff --git a/migrate/migrate_v1_5_2.go b/migrate/migrate_v1_5_2.go index 95fbeb6cd..7d2c3030b 100644 --- a/migrate/migrate_v1_5_2.go +++ b/migrate/migrate_v1_5_2.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "time" "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logger" @@ -18,7 +19,12 @@ func migrateV1_5_2(ctx context.Context) error { return err } - return migrateUserInvites(ctx) + err = migrateUserInvites(ctx) + if err != nil { + return err + } + + return migrateNodes(ctx) } func migratePendingUsers(ctx context.Context) error { @@ -72,13 +78,125 @@ func migrateUserInvites(ctx context.Context) error { UserGroups: datatypes.NewJSONType(userInvite.UserGroups), } - logger.Log(4, fmt.Sprintf("migrating user invite %s", _userInvite.InviteCode)) + logger.Log(4, fmt.Sprintf("migrating user invite %s/%s", _userInvite.InviteCode, _userInvite.Email)) err = _userInvite.Create(ctx) if err != nil { - logger.Log(4, fmt.Sprintf("migrating user invite (%s/%s) failed: %v", _userInvite.InviteCode, _userInvite.Email, err)) + logger.Log(4, fmt.Sprintf("migrating user invite %s/%s failed: %v", _userInvite.InviteCode, _userInvite.Email, err)) + return err + } + } + + return nil +} + +func migrateNodes(ctx context.Context) error { + records, err := fetchAll(ctx, database.NODES_TABLE_NAME) + if err != nil && !database.IsEmptyRecord(err) { + return err + } + + for _, record := range records { + var node models.Node + err = json.Unmarshal([]byte(record), &node) + if err != nil { + return err + } + + var address, address6 string + if node.Address.IP != nil { + address = node.Address.String() + } + + if node.Address6.IP != nil { + address6 = node.Address6.String() + } + + network := &schema.Network{ + Name: node.Network, + } + err = network.Get(ctx) + if err != nil { return err } + + relayedClients := make(datatypes.JSONMap) + for _, relayedNodeID := range node.RelayedNodes { + relayedClients[relayedNodeID] = struct{}{} + } + + relayedIGWClients := make(datatypes.JSONMap) + for _, inetNodeClientID := range node.InetNodeReq.InetNodeClientIDs { + relayedIGWClients[inetNodeClientID] = struct{}{} + } + + relayedBy := node.RelayedBy + + tags := make(datatypes.JSONMap) + for tagID := range node.Tags { + tags[tagID.String()] = struct{}{} + } + + _node := &schema.Node{ + ID: node.ID.String(), + HostID: node.HostID.String(), + NetworkID: network.ID, + Address: address, + Address6: address6, + Connected: node.Connected, + Action: node.Action, + Status: node.Status, + PendingDelete: node.PendingDelete, + AutoAssignGateway: node.AutoAssignGateway, + IsGateway: node.IsGw || node.IsRelay || node.IsIngressGateway, + IsAutoRelay: node.IsAutoRelay, + IsInternetGateway: node.IsGw && node.IsInternetGateway, + RelayedClients: relayedClients, + RelayedIGWClients: relayedIGWClients, + RelayingNodeID: &relayedBy, + IsIGWClient: node.IsRelayed && node.InternetGwID != "", + AutoRelayedPeers: datatypes.NewJSONType(node.AutoRelayedPeers), + Tags: tags, + PostureCheckSeverity: node.PostureCheckVolationSeverityLevel, + PostureCheckLastEvaluationCycleID: node.LastEvaluatedAt.Format(time.RFC3339), + Metadata: node.Metadata, + LastCheckIn: node.LastCheckIn, + ExpirationDateTime: node.ExpirationDateTime, + CreatedAt: node.LastModified, + UpdatedAt: node.LastModified, + } + + logger.Log(4, fmt.Sprintf("migrating node %s", _node.ID)) + + err = _node.Create(ctx) + if err != nil { + logger.Log(4, fmt.Sprintf("migrating node %s failed: %v", _node.ID, err)) + return err + } + + if !node.LastEvaluatedAt.IsZero() { + violations := make([]schema.PostureCheckViolation, 0, len(node.PostureChecksViolations)) + for _, violation := range node.PostureChecksViolations { + violations = append(violations, schema.PostureCheckViolation{ + EvaluationCycleID: node.LastEvaluatedAt.Format(time.RFC3339), + CheckID: violation.CheckID, + NodeID: _node.ID, + Name: violation.Name, + Attribute: violation.Attribute, + Message: violation.Message, + Severity: violation.Severity, + EvaluatedAt: node.LastEvaluatedAt, + }) + } + + logger.Log(4, fmt.Sprintf("migrating node %s violations", _node.ID)) + + err = _node.UpsertViolations(ctx, violations) + if err != nil { + logger.Log(4, fmt.Sprintf("migrating node %s violations failed: %v", _node.ID, err)) + return err + } + } } return nil diff --git a/models/api_node.go b/models/api_node.go index 76923d3b7..051c3f2d0 100644 --- a/models/api_node.go +++ b/models/api_node.go @@ -10,10 +10,10 @@ import ( ) type ApiNodeStatus struct { - ID string `json:"id"` - IsStatic bool `json:"is_static"` - IsUserNode bool `json:"is_user_node"` - Status NodeStatus `json:"status"` + ID string `json:"id"` + IsStatic bool `json:"is_static"` + IsUserNode bool `json:"is_user_node"` + Status schema.NodeStatus `json:"status"` } // ApiNode is a stripped down Node DTO that exposes only required fields to external systems @@ -66,7 +66,7 @@ type ApiNode struct { IsStatic bool `json:"is_static"` IsUserNode bool `json:"is_user_node"` StaticNode ExtClient `json:"static_node"` - Status NodeStatus `json:"status"` + Status schema.NodeStatus `json:"status"` Location string `json:"location"` Country string `json:"country"` PostureChecksViolations []Violation `json:"posture_check_violations"` diff --git a/models/migrate.go b/models/migrate.go deleted file mode 100644 index 978354cb4..000000000 --- a/models/migrate.go +++ /dev/null @@ -1,9 +0,0 @@ -package models - -// MigrationData struct needed to create new v0.18.0 node from v.0.17.X node -type MigrationData struct { - HostName string - Password string - OS string - LegacyNodes []LegacyNode -} diff --git a/models/mqtt.go b/models/mqtt.go index 372654988..67e297a18 100644 --- a/models/mqtt.go +++ b/models/mqtt.go @@ -151,11 +151,6 @@ type FwUpdate struct { AclRules map[string]AclRule `json:"acl_rules"` } -// FailOverMeReq - struct for failover req -type FailOverMeReq struct { - NodeID string `json:"node_id"` -} - // AutoRelayMeReq - struct for autorelay req type AutoRelayMeReq struct { NodeID string `json:"node_id"` diff --git a/models/node.go b/models/node.go index 8899c534f..ac7c9be54 100644 --- a/models/node.go +++ b/models/node.go @@ -4,51 +4,17 @@ import ( "bytes" "math/rand" "net" - "strings" "sync" "time" "github.com/google/uuid" "github.com/gravitl/netmaker/schema" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -type NodeStatus string - -const ( - OnlineSt NodeStatus = "online" - OfflineSt NodeStatus = "offline" - WarningSt NodeStatus = "warning" - ErrorSt NodeStatus = "error" - UnKnown NodeStatus = "unknown" - Disconnected NodeStatus = "disconnected" + "gorm.io/datatypes" ) // LastCheckInThreshold - if node's checkin more than this threshold,then node is declared as offline const LastCheckInThreshold = time.Minute * 10 -const ( - // NODE_SERVER_NAME - the default server name - NODE_SERVER_NAME = "netmaker" - // MAX_NAME_LENGTH - max name length of node - MAX_NAME_LENGTH = 62 - // == ACTIONS == (can only be set by server) - // NODE_DELETE - delete node action - NODE_DELETE = "delete" - // NODE_IS_PENDING - node pending status - NODE_IS_PENDING = "pending" - // NODE_NOOP - node no op action - NODE_NOOP = "noop" - // NODE_FORCE_UPDATE - indicates a node should pull all changes - NODE_FORCE_UPDATE = "force" - // FIREWALL_IPTABLES - indicates that iptables is the firewall in use - FIREWALL_IPTABLES = "iptables" - // FIREWALL_NFTABLES - indicates nftables is in use (Linux only) - FIREWALL_NFTABLES = "nftables" - // FIREWALL_NONE - indicates that no supported firewall in use - FIREWALL_NONE = "none" -) - var seededRand *rand.Rand = rand.New( rand.NewSource(time.Now().UnixNano())) @@ -116,7 +82,7 @@ type Node struct { IsStatic bool `json:"is_static"` IsUserNode bool `json:"is_user_node"` StaticNode ExtClient `json:"static_node"` - Status NodeStatus `json:"node_status"` + Status schema.NodeStatus `json:"node_status"` Mutex *sync.Mutex `json:"-"` EgressDetails EgressDetails `json:"-"` PostureChecksViolations []Violation `json:"posture_check_violations"` @@ -135,66 +101,6 @@ type EgressDetails struct { // InternetGwID string `json:"internetgw_node_id" yaml:"internetgw_node_id"` } -// LegacyNode - legacy struct for node model -type LegacyNode struct { - ID string `json:"id,omitempty" bson:"id,omitempty" yaml:"id,omitempty" validate:"required,min=5,id_unique"` - Address string `json:"address" bson:"address" yaml:"address" validate:"omitempty,ipv4"` - Address6 string `json:"address6" bson:"address6" yaml:"address6" validate:"omitempty,ipv6"` - LocalAddress string `json:"localaddress" bson:"localaddress" yaml:"localaddress" validate:"omitempty"` - Interfaces []schema.Iface `json:"interfaces" yaml:"interfaces"` - Name string `json:"name" bson:"name" yaml:"name" validate:"omitempty,max=62,in_charset"` - NetworkSettings Network `json:"networksettings" bson:"networksettings" yaml:"networksettings" validate:"-"` - ListenPort int32 `json:"listenport" bson:"listenport" yaml:"listenport" validate:"omitempty,numeric,min=1024,max=65535"` - LocalListenPort int32 `json:"locallistenport" bson:"locallistenport" yaml:"locallistenport" validate:"numeric,min=0,max=65535"` - PublicKey string `json:"publickey" bson:"publickey" yaml:"publickey" validate:"required,base64"` - Endpoint string `json:"endpoint" bson:"endpoint" yaml:"endpoint" validate:"required,ip"` - AllowedIPs []string `json:"allowedips" bson:"allowedips" yaml:"allowedips"` - PersistentKeepalive int32 `json:"persistentkeepalive" bson:"persistentkeepalive" yaml:"persistentkeepalive" validate:"omitempty,numeric,max=1000"` - IsHub string `json:"ishub" bson:"ishub" yaml:"ishub" validate:"checkyesorno"` - AccessKey string `json:"accesskey" bson:"accesskey" yaml:"accesskey"` - Interface string `json:"interface" bson:"interface" yaml:"interface"` - LastModified int64 `json:"lastmodified" bson:"lastmodified" yaml:"lastmodified" swaggertype:"primitive,integer" format:"int64"` - ExpirationDateTime int64 `json:"expdatetime" bson:"expdatetime" yaml:"expdatetime" swaggertype:"primitive,integer" format:"int64"` - LastPeerUpdate int64 `json:"lastpeerupdate" bson:"lastpeerupdate" yaml:"lastpeerupdate" swaggertype:"primitive,integer" format:"int64"` - LastCheckIn int64 `json:"lastcheckin" bson:"lastcheckin" yaml:"lastcheckin" swaggertype:"primitive,integer" format:"int64"` - MacAddress string `json:"macaddress" bson:"macaddress" yaml:"macaddress"` - Password string `json:"password" bson:"password" yaml:"password" validate:"required,min=6"` - Network string `json:"network" bson:"network" yaml:"network" validate:"network_exists"` - IsRelayed string `json:"isrelayed" bson:"isrelayed" yaml:"isrelayed"` - IsPending string `json:"ispending" bson:"ispending" yaml:"ispending"` - IsRelay string `json:"isrelay" bson:"isrelay" yaml:"isrelay" validate:"checkyesorno"` - IsDocker string `json:"isdocker" bson:"isdocker" yaml:"isdocker" validate:"checkyesorno"` - IsK8S string `json:"isk8s" bson:"isk8s" yaml:"isk8s" validate:"checkyesorno"` - IsEgressGateway string `json:"isegressgateway" bson:"isegressgateway" yaml:"isegressgateway" validate:"checkyesorno"` - IsIngressGateway string `json:"isingressgateway" bson:"isingressgateway" yaml:"isingressgateway" validate:"checkyesorno"` - EgressGatewayRanges []string `json:"egressgatewayranges" bson:"egressgatewayranges" yaml:"egressgatewayranges"` - EgressGatewayNatEnabled string `json:"egressgatewaynatenabled" bson:"egressgatewaynatenabled" yaml:"egressgatewaynatenabled"` - EgressGatewayRequest EgressGatewayRequest `json:"egressgatewayrequest" bson:"egressgatewayrequest" yaml:"egressgatewayrequest"` - RelayAddrs []string `json:"relayaddrs" bson:"relayaddrs" yaml:"relayaddrs"` - FailoverNode string `json:"failovernode" bson:"failovernode" yaml:"failovernode"` - IngressGatewayRange string `json:"ingressgatewayrange" bson:"ingressgatewayrange" yaml:"ingressgatewayrange"` - IngressGatewayRange6 string `json:"ingressgatewayrange6" bson:"ingressgatewayrange6" yaml:"ingressgatewayrange6"` - // IsStatic - refers to if the Endpoint is set manually or dynamically - IsStatic string `json:"isstatic" bson:"isstatic" yaml:"isstatic" validate:"checkyesorno"` - UDPHolePunch string `json:"udpholepunch" bson:"udpholepunch" yaml:"udpholepunch" validate:"checkyesorno"` - DNSOn string `json:"dnson" bson:"dnson" yaml:"dnson" validate:"checkyesorno"` - IsServer string `json:"isserver" bson:"isserver" yaml:"isserver" validate:"checkyesorno"` - Action string `json:"action" bson:"action" yaml:"action"` - IPForwarding string `json:"ipforwarding" bson:"ipforwarding" yaml:"ipforwarding" validate:"checkyesorno"` - OS string `json:"os" bson:"os" yaml:"os"` - MTU int32 `json:"mtu" bson:"mtu" yaml:"mtu"` - Version string `json:"version" bson:"version" yaml:"version"` - Server string `json:"server" bson:"server" yaml:"server"` - TrafficKeys TrafficKeys `json:"traffickeys" bson:"traffickeys" yaml:"traffickeys"` - FirewallInUse string `json:"firewallinuse" bson:"firewallinuse" yaml:"firewallinuse"` - InternetGateway string `json:"internetgateway" bson:"internetgateway" yaml:"internetgateway"` - Connected string `json:"connected" bson:"connected" yaml:"connected" validate:"checkyesorno"` - // == PRO == - DefaultACL string `json:"defaultacl,omitempty" bson:"defaultacl,omitempty" yaml:"defaultacl,omitempty" validate:"checkyesornoorunset"` - OwnerID string `json:"ownerid,omitempty" bson:"ownerid,omitempty" yaml:"ownerid,omitempty"` - Failover string `json:"failover" bson:"failover" yaml:"failover" validate:"checkyesorno"` -} - // NodesArray - used for node sorting type NodesArray []Node @@ -273,120 +179,6 @@ func (node *Node) SetDefaultConnected() { node.Connected = true } -// Node.SetDefaultACL -func (node *LegacyNode) SetDefaultACL() { - if node.DefaultACL == "" { - node.DefaultACL = "yes" - } -} - -// Node.SetDefaultMTU - sets default MTU of a node -func (node *LegacyNode) SetDefaultMTU() { - if node.MTU == 0 { - node.MTU = 1280 - } -} - -// Node.SetDefaultNFTablesPresent - sets default for nftables check -func (node *LegacyNode) SetDefaultNFTablesPresent() { - if node.FirewallInUse == "" { - node.FirewallInUse = FIREWALL_IPTABLES // default to iptables - } -} - -// Node.SetDefaultIsRelayed - set default is relayed -func (node *LegacyNode) SetDefaultIsRelayed() { - if node.IsRelayed == "" { - node.IsRelayed = "no" - } -} - -// Node.SetDefaultIsRelayed - set default is relayed -func (node *LegacyNode) SetDefaultIsHub() { - if node.IsHub == "" { - node.IsHub = "no" - } -} - -// Node.SetDefaultIsRelay - set default isrelay -func (node *LegacyNode) SetDefaultIsRelay() { - if node.IsRelay == "" { - node.IsRelay = "no" - } -} - -// Node.SetDefaultIsDocker - set default isdocker -func (node *LegacyNode) SetDefaultIsDocker() { - if node.IsDocker == "" { - node.IsDocker = "no" - } -} - -// Node.SetDefaultIsK8S - set default isk8s -func (node *LegacyNode) SetDefaultIsK8S() { - if node.IsK8S == "" { - node.IsK8S = "no" - } -} - -// Node.SetDefaultEgressGateway - sets default egress gateway status -func (node *LegacyNode) SetDefaultEgressGateway() { - if node.IsEgressGateway == "" { - node.IsEgressGateway = "no" - } -} - -// Node.SetDefaultIngressGateway - sets default ingress gateway status -func (node *LegacyNode) SetDefaultIngressGateway() { - if node.IsIngressGateway == "" { - node.IsIngressGateway = "no" - } -} - -// Node.SetDefaultAction - sets default action status -func (node *LegacyNode) SetDefaultAction() { - if node.Action == "" { - node.Action = NODE_NOOP - } -} - -// Node.SetRoamingDefault - sets default roaming status -//func (node *Node) SetRoamingDefault() { -// if node.Roaming == "" { -// node.Roaming = "yes" -// } -//} - -// Node.SetIPForwardingDefault - set ip forwarding default -func (node *LegacyNode) SetIPForwardingDefault() { - if node.IPForwarding == "" { - node.IPForwarding = "yes" - } -} - -// Node.SetDNSOnDefault - sets dns on default -func (node *LegacyNode) SetDNSOnDefault() { - if node.DNSOn == "" { - node.DNSOn = "yes" - } -} - -// Node.SetIsServerDefault - sets node isserver default -func (node *LegacyNode) SetIsServerDefault() { - if node.IsServer != "yes" { - node.IsServer = "no" - } -} - -// Node.SetIsStaticDefault - set is static default -func (node *LegacyNode) SetIsStaticDefault() { - if node.IsServer == "yes" { - node.IsStatic = "yes" - } else if node.IsStatic != "yes" { - node.IsStatic = "no" - } -} - // Node.SetLastModified - set last modified initial time func (node *Node) SetLastModified() { node.LastModified = time.Now().UTC() @@ -409,20 +201,6 @@ func (node *Node) SetExpirationDateTime() { } } -// Node.SetDefaultName - sets a random name to node -func (node *LegacyNode) SetDefaultName() { - if node.Name == "" { - node.Name = GenerateNodeName() - } -} - -// Node.SetDefaultFailover - sets default value of failover status to no if not set -func (node *LegacyNode) SetDefaultFailover() { - if node.Failover == "" { - node.Failover = "no" - } -} - // Node.Fill - fills other node data into calling node data if not set on calling node (skips DNSOn) func (newNode *Node) Fill( currentNode *Node, @@ -490,139 +268,6 @@ func (newNode *Node) Fill( } } -// StringWithCharset - returns random string inside defined charset -func StringWithCharset(length int, charset string) string { - b := make([]byte, length) - for i := range b { - b[i] = charset[seededRand.Intn(len(charset))] - } - return string(b) -} - -// IsIpv4Net - check for valid IPv4 address -// Note: We dont handle IPv6 AT ALL!!!!! This definitely is needed at some point -// But for iteration 1, lets just stick to IPv4. Keep it simple stupid. -func IsIpv4Net(host string) bool { - return net.ParseIP(host) != nil -} - -// Node.NameInNodeCharset - returns if name is in charset below or not -func (node *LegacyNode) NameInNodeCharSet() bool { - - charset := "abcdefghijklmnopqrstuvwxyz1234567890-" - - for _, char := range node.Name { - if !strings.Contains(charset, strings.ToLower(string(char))) { - return false - } - } - return true -} - -func (ln *LegacyNode) ConvertToNewNode() (*schema.Host, *Node) { - var node Node - //host:= logic.GetHost(node.HostID) - var host schema.Host - if host.ID.String() == "" { - host.ID = uuid.New() - host.FirewallInUse = ln.FirewallInUse - host.Version = ln.Version - host.IPForwarding = parseBool(ln.IPForwarding) - host.HostPass = ln.Password - host.Name = ln.Name - host.ListenPort = int(ln.ListenPort) - host.MTU = int(ln.MTU) - pubkey, _ := wgtypes.ParseKey(ln.PublicKey) - host.PublicKey = schema.WgKey{Key: pubkey} - host.MacAddress, _ = net.ParseMAC(ln.MacAddress) - host.TrafficKeyPublic = ln.TrafficKeys.Mine - id, _ := uuid.Parse(ln.ID) - host.Nodes = append(host.Nodes, id.String()) - host.Interfaces = ln.Interfaces - host.EndpointIP = net.ParseIP(ln.Endpoint) - // host.ProxyEnabled = ln.Proxy // this will always be false.. - } - id, _ := uuid.Parse(ln.ID) - node.ID = id - node.Network = ln.Network - if _, cidr, err := net.ParseCIDR(ln.NetworkSettings.AddressRange); err == nil { - node.NetworkRange = *cidr - } - if _, cidr, err := net.ParseCIDR(ln.NetworkSettings.AddressRange6); err == nil { - node.NetworkRange6 = *cidr - } - node.Server = ln.Server - node.Connected = parseBool(ln.Connected) - if ln.Address != "" { - node.Address = net.IPNet{ - IP: net.ParseIP(ln.Address), - Mask: net.CIDRMask(32, 32), - } - } - if ln.Address6 != "" { - node.Address = net.IPNet{ - IP: net.ParseIP(ln.Address6), - Mask: net.CIDRMask(128, 128), - } - } - node.Action = ln.Action - node.IsIngressGateway = parseBool(ln.IsIngressGateway) - - return &host, &node -} - -// Node.Legacy converts node to legacy format -func (n *Node) Legacy(h *Host, s *ServerConfig, net *Network) *LegacyNode { - l := LegacyNode{} - l.ID = n.ID.String() - //l.HostID = h.ID.String() - l.Address = n.Address.String() - l.Address6 = n.Address6.String() - l.Interfaces = h.Interfaces - l.Name = h.Name - l.NetworkSettings = *net - l.ListenPort = int32(h.ListenPort) - l.PublicKey = h.PublicKey.String() - l.Endpoint = h.EndpointIP.String() - //l.AllowedIPs = - l.AccessKey = "" - l.Interface = WIREGUARD_INTERFACE - //l.LastModified = - //l.ExpirationDateTime - //l.LastPeerUpdate - //l.LastCheckIn - l.MacAddress = h.MacAddress.String() - l.Password = h.HostPass - l.Network = n.Network - //l.IsRelayed = formatBool(n.Is) - //l.IsRelay = formatBool(n.IsRelay) - //l.IsDocker = formatBool(n.IsDocker) - //l.IsK8S = formatBool(n.IsK8S) - l.IsIngressGateway = formatBool(n.IsIngressGateway) - //l.EgressGatewayRanges = n.EgressGatewayRanges - //l.EgressGatewayNatEnabled = n.EgressGatewayNatEnabled - //l.RelayAddrs = n.RelayAddrs - //l.FailoverNode = n.FailoverNode - //l.IngressGatewayRange = n.IngressGatewayRange - //l.IngressGatewayRange6 = n.IngressGatewayRange6 - l.IsStatic = formatBool(h.IsStatic) - l.UDPHolePunch = formatBool(true) - l.Action = n.Action - l.IPForwarding = formatBool(h.IPForwarding) - l.OS = h.OS - l.MTU = int32(h.MTU) - l.Version = h.Version - l.Server = n.Server - l.TrafficKeys.Mine = h.TrafficKeyPublic - l.TrafficKeys.Server = s.TrafficKey - l.FirewallInUse = h.FirewallInUse - l.Connected = formatBool(n.Connected) - //l.PendingDelete = formatBool(n.PendingDelete) - l.OwnerID = n.OwnerID - //l.Failover = n.Failover - return &l -} - // Node.NetworkSettings updates a node with network settings func (node *Node) NetworkSettings(n Network) { _, cidr, err := net.ParseCIDR(n.AddressRange) @@ -635,18 +280,62 @@ func (node *Node) NetworkSettings(n Network) { } } -func parseBool(s string) bool { - b := false - if s == "yes" { - b = true - } - return b -} - -func formatBool(b bool) string { - s := "no" - if b { - s = "yes" - } - return s +type NodeWithHost struct { + ID string `json:"id"` + HostID string `json:"host_id"` + Host *ApiHost `json:"host,omitempty"` + NetworkID string `json:"network_id"` + Address string `json:"address"` + Address6 string `json:"address6"` + Connected bool `json:"connected"` + Action string `json:"action"` + Status schema.NodeStatus `json:"status"` + PendingDelete bool `json:"pending_delete"` + AutoAssignGateway bool `json:"auto_assign_gateway"` + IsGateway bool `json:"is_gateway"` + IsAutoRelay bool `json:"is_auto_relay"` + IsInternetGateway bool `json:"is_internet_gateway"` + RelayedClients datatypes.JSONMap `json:"relayed_clients"` + RelayedIGWClients datatypes.JSONMap `json:"relayed_igw_clients"` + RelayingNodeID *string `json:"relaying_node_id"` + IsIGWClient bool `json:"is_igw_client"` + AutoRelayedPeers datatypes.JSONType[map[string]string] `json:"auto_relayed_peers"` + Tags datatypes.JSONMap `json:"tags"` + PostureCheckSeverity schema.Severity `json:"posture_check_severity"` + PostureCheckLastEvaluationCycleID string `json:"posture_check_last_evaluation_cycle_id"` + Metadata string `json:"metadata"` + LastCheckIn time.Time `json:"last_check_in"` + ExpirationDateTime time.Time `json:"expiration_date_time"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (n *NodeWithHost) Fill(_node *schema.Node) { + n.ID = _node.ID + n.HostID = _node.HostID + n.Host = NewApiHostFromSchemaHost(_node.Host) + n.NetworkID = _node.NetworkID + n.Address = _node.Address + n.Address6 = _node.Address6 + n.Connected = _node.Connected + n.Action = _node.Action + n.Status = _node.Status + n.PendingDelete = _node.PendingDelete + n.AutoAssignGateway = _node.AutoAssignGateway + n.IsGateway = _node.IsGateway + n.IsAutoRelay = _node.IsAutoRelay + n.IsInternetGateway = _node.IsInternetGateway + n.RelayedClients = _node.RelayedClients + n.RelayedIGWClients = _node.RelayedIGWClients + n.RelayingNodeID = _node.RelayingNodeID + n.IsIGWClient = _node.IsIGWClient + n.AutoRelayedPeers = _node.AutoRelayedPeers + n.Tags = _node.Tags + n.PostureCheckSeverity = _node.PostureCheckSeverity + n.PostureCheckLastEvaluationCycleID = _node.PostureCheckLastEvaluationCycleID + n.Metadata = _node.Metadata + n.LastCheckIn = _node.LastCheckIn + n.ExpirationDateTime = _node.ExpirationDateTime + n.CreatedAt = _node.CreatedAt + n.UpdatedAt = _node.UpdatedAt } diff --git a/models/structs.go b/models/structs.go index b8f278d7e..d5cb9819b 100644 --- a/models/structs.go +++ b/models/structs.go @@ -4,18 +4,11 @@ import ( "net" "time" - jwt "github.com/golang-jwt/jwt/v4" + "github.com/golang-jwt/jwt/v4" "github.com/gravitl/netmaker/schema" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -const ( - // PLACEHOLDER_KEY_TEXT - access key placeholder text if option turned off - PLACEHOLDER_KEY_TEXT = "ACCESS_KEY" - // PLACEHOLDER_TOKEN_TEXT - access key token placeholder text if option turned off - PLACEHOLDER_TOKEN_TEXT = "ACCESS_TOKEN" -) - type FeatureFlags struct { EnableEgressHA bool `json:"enable_egress_ha"` EnableNetworkActivity bool `json:"enable_network_activity"` @@ -46,24 +39,24 @@ type IngressGwUsers struct { // UserRemoteGws - struct to hold user's remote gws type UserRemoteGws struct { - GwID string `json:"remote_access_gw_id"` - GWName string `json:"gw_name"` - Network string `json:"network"` - Connected bool `json:"connected"` - IsInternetGateway bool `json:"is_internet_gateway"` - GwClient ExtClient `json:"gw_client"` - GwPeerPublicKey string `json:"gw_peer_public_key"` - GwListenPort int `json:"gw_listen_port"` - Metadata string `json:"metadata"` - AllowedEndpoints []string `json:"allowed_endpoints"` - NetworkAddresses []string `json:"network_addresses"` - Status NodeStatus `json:"status"` - ManageDNS bool `json:"manage_dns"` - DnsAddress string `json:"dns_address"` - Addresses string `json:"addresses"` - MatchDomains []string `json:"match_domains"` - SearchDomains []string `json:"search_domains"` - Nameservers []Nameserver `json:"nameservers"` + GwID string `json:"remote_access_gw_id"` + GWName string `json:"gw_name"` + Network string `json:"network"` + Connected bool `json:"connected"` + IsInternetGateway bool `json:"is_internet_gateway"` + GwClient ExtClient `json:"gw_client"` + GwPeerPublicKey string `json:"gw_peer_public_key"` + GwListenPort int `json:"gw_listen_port"` + Metadata string `json:"metadata"` + AllowedEndpoints []string `json:"allowed_endpoints"` + NetworkAddresses []string `json:"network_addresses"` + Status schema.NodeStatus `json:"status"` + ManageDNS bool `json:"manage_dns"` + DnsAddress string `json:"dns_address"` + Addresses string `json:"addresses"` + MatchDomains []string `json:"match_domains"` + SearchDomains []string `json:"search_domains"` + Nameservers []Nameserver `json:"nameservers"` } // UserRAGs - struct for user access gws @@ -246,13 +239,6 @@ type InetNodeReq struct { InetNodeClientIDs []string `json:"inet_node_client_ids"` } -// ServerUpdateData - contains data to configure server -// and if it should set peers -type ServerUpdateData struct { - UpdatePeers bool `json:"updatepeers" bson:"updatepeers"` - Node LegacyNode `json:"servernode" bson:"servernode"` -} - // Telemetry - contains UUID of the server and timestamp of last send to posthog // also contains assymetrical encryption pub/priv keys for any server traffic type Telemetry struct { @@ -472,17 +458,17 @@ type IDPSyncTestRequest struct { } type PostureCheckDeviceInfo struct { - ClientLocation string - ClientVersion string - OS string - OSFamily string - OSVersion string - KernelVersion string - AutoUpdate bool - SkipAutoUpdate bool - Tags map[TagID]struct{} - IsUser bool - UserGroups map[schema.UserGroupID]struct{} + ClientLocation string + ClientVersion string + OS string + OSFamily string + OSVersion string + KernelVersion string + AutoUpdate bool + SkipAutoUpdate bool + Tags map[TagID]struct{} + IsUser bool + UserGroups map[schema.UserGroupID]struct{} } type Violation struct { @@ -512,7 +498,6 @@ type BulkUserStatusUpdate struct { Disable bool `json:"disable"` } - type BulkStatusResponse struct { Updated []string `json:"updated"` Failed []BulkDeleteError `json:"failed,omitempty"` diff --git a/mq/handlers.go b/mq/handlers.go index 69fe06b31..d8a477680 100644 --- a/mq/handlers.go +++ b/mq/handlers.go @@ -191,9 +191,6 @@ func DeleteAndCleanupHost(h *schema.Host) { slog.Error("failed to delete host", "id", h.ID, "error", err) return } - if servercfg.IsDNSMode() { - logic.SetDNS() - } } func SignalPeer(signal models.Signal) { @@ -274,7 +271,7 @@ func HandleHostCheckin(h, currentHost *schema.Host) bool { if database.IsEmptyRecord(err) { fakeNode := models.Node{} fakeNode.ID, _ = uuid.Parse(currNodeID) - fakeNode.Action = models.NODE_DELETE + fakeNode.Action = schema.NODE_DELETE fakeNode.PendingDelete = true if err := NodeUpdate(&fakeNode); err != nil { slog.Warn("failed to inform host to remove node", "host", currentHost.Name, "hostid", currentHost.ID, "nodeid", currNodeID, "error", err) @@ -282,7 +279,7 @@ func HandleHostCheckin(h, currentHost *schema.Host) bool { } continue } - if err := logic.UpdateNodeCheckin(&node); err != nil { + if err := logic.UpdateNodeCheckin(node.ID.String()); err != nil { slog.Warn("failed to update node on checkin", "nodeid", node.ID, "error", err) } } diff --git a/mq/publishers.go b/mq/publishers.go index b68fc4de4..bb21acdd7 100644 --- a/mq/publishers.go +++ b/mq/publishers.go @@ -300,7 +300,7 @@ func ServerStartNotify() error { return err } for i := range nodes { - nodes[i].Action = models.NODE_FORCE_UPDATE + nodes[i].Action = schema.NODE_FORCE_UPDATE if err = NodeUpdate(&nodes[i]); err != nil { logger.Log(1, "error when notifying node", nodes[i].ID.String(), "of a server startup") } @@ -312,7 +312,7 @@ func ServerStartNotify() error { func PublishMqUpdatesForDeletedNode(node models.Node, sendNodeUpdate bool) { // notify of peer change node.PendingDelete = true - node.Action = models.NODE_DELETE + node.Action = schema.NODE_DELETE if sendNodeUpdate { if err := NodeUpdate(&node); err != nil { slog.Error("error publishing node update to node", "node", node.ID, "error", err) @@ -321,10 +321,6 @@ func PublishMqUpdatesForDeletedNode(node models.Node, sendNodeUpdate bool) { if err := PublishDeletedNodePeerUpdate(&node); err != nil { logger.Log(1, "error publishing peer update ", err.Error()) } - if servercfg.IsDNSMode() { - logic.SetDNS() - } - } // PushAllMetricsToExporter fetches all node metrics from the database diff --git a/orchestrator/extensions/factory.go b/orchestrator/extensions/factory.go new file mode 100644 index 000000000..c24bf243a --- /dev/null +++ b/orchestrator/extensions/factory.go @@ -0,0 +1,19 @@ +package extensions + +type Factory struct { + nodeExt NodeExtensions +} + +func NewFactory(nodeExt NodeExtensions) *Factory { + return &Factory{ + nodeExt: nodeExt, + } +} + +func NewCEFactory() *Factory { + return NewFactory(&CENodeExtensions{}) +} + +func (f *Factory) NodeExtensions() NodeExtensions { + return f.nodeExt +} diff --git a/orchestrator/extensions/node.go b/orchestrator/extensions/node.go new file mode 100644 index 000000000..cdc4dc199 --- /dev/null +++ b/orchestrator/extensions/node.go @@ -0,0 +1,21 @@ +package extensions + +import ( + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" +) + +type NodeExtensions interface { + ConfigureAutoRelay(node *schema.Node) + ConfigureAutoAssignGateway(node *schema.Node, key *models.EnrollmentKey) +} + +type CENodeExtensions struct{} + +func (c *CENodeExtensions) ConfigureAutoRelay(_ *schema.Node) { + return +} + +func (c *CENodeExtensions) ConfigureAutoAssignGateway(node *schema.Node, _ *models.EnrollmentKey) { + node.AutoAssignGateway = false +} diff --git a/orchestrator/network.go b/orchestrator/network.go new file mode 100644 index 000000000..a4a0af9c8 --- /dev/null +++ b/orchestrator/network.go @@ -0,0 +1,160 @@ +package orchestrator + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + + "github.com/c-robinson/iplib" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/schema" +) + +type NetworkOrchestrator struct { + addressLock sync.Mutex + address6Lock sync.Mutex +} + +func (n *NetworkOrchestrator) AllocateNodeIP(ctx context.Context, network *schema.Network) (net.IP, error) { + return n.allocateIPv4(ctx, network, false) +} + +func (n *NetworkOrchestrator) AllocateExtclientIP(ctx context.Context, network *schema.Network) (net.IP, error) { + return n.allocateIPv4(ctx, network, true) +} + +func (n *NetworkOrchestrator) AllocateNodeIPv6(ctx context.Context, network *schema.Network) (net.IP, error) { + return n.allocateIPv6(ctx, network, false) +} + +func (n *NetworkOrchestrator) AllocateExtclientIPv6(ctx context.Context, network *schema.Network) (net.IP, error) { + return n.allocateIPv6(ctx, network, true) +} + +func (n *NetworkOrchestrator) allocateIPv4(ctx context.Context, network *schema.Network, reverse bool) (net.IP, error) { + n.addressLock.Lock() + defer n.addressLock.Unlock() + + if network.AddressRange == "" { + return nil, fmt.Errorf("IPv4 not configured on network %s", network.Name) + } + if _, _, err := net.ParseCIDR(network.AddressRange); err != nil { + return nil, err + } + return n.findUniqueIPv4DB(ctx, network, reverse) +} + +func (n *NetworkOrchestrator) allocateIPv6(ctx context.Context, network *schema.Network, reverse bool) (net.IP, error) { + n.address6Lock.Unlock() + defer n.address6Lock.Lock() + + if network.AddressRange6 == "" { + return nil, fmt.Errorf("IPv6 not configured on network %s", network.Name) + } + if _, _, err := net.ParseCIDR(network.AddressRange6); err != nil { + return nil, err + } + return n.findUniqueIPv6DB(ctx, network, reverse) +} + +func (n *NetworkOrchestrator) findUniqueIPv4DB(ctx context.Context, network *schema.Network, reverse bool) (net.IP, error) { + net4 := iplib.Net4FromStr(network.AddressRange) + addr := net4.FirstAddress() + if reverse { + addr = net4.LastAddress() + } + + for { + if n.IsIPv4Unique(ctx, network, addr.String()) { + return addr, nil + } + var err error + if reverse { + addr, err = net4.PreviousIP(addr) + } else { + addr, err = net4.NextIP(addr) + } + if err != nil { + return nil, errors.New("no unique IPv4 addresses available") + } + } +} + +func (n *NetworkOrchestrator) findUniqueIPv6DB(ctx context.Context, network *schema.Network, reverse bool) (net.IP, error) { + net6 := iplib.Net6FromStr(network.AddressRange6) + + var ( + addr net.IP + err error + ) + if reverse { + addr, err = net6.PreviousIP(net6.LastAddress()) + } else { + addr, err = net6.NextIP(net6.FirstAddress()) + } + if err != nil { + return nil, err + } + + for { + if n.IsIPv6Unique(ctx, network, addr.String()) { + return addr, nil + } + if reverse { + addr, err = net6.PreviousIP(addr) + } else { + addr, err = net6.NextIP(addr) + } + if err != nil { + return nil, errors.New("no unique IPv6 addresses available") + } + } +} + +func (n *NetworkOrchestrator) IsIPv4Unique(ctx context.Context, network *schema.Network, ip string) bool { + _, cidr, err := net.ParseCIDR(network.AddressRange) + if err != nil { + return true + } + cidr.IP = net.ParseIP(ip) + node := &schema.Node{NetworkID: network.ID, Address: cidr.String()} + if err := node.GetByNetworkAndAddress(ctx); err == nil { + return false + } + + extClients, err := logic.GetNetworkExtClients(network.Name) + if err != nil { + return true + } + for _, ec := range extClients { + if ec.Address == ip { + return false + } + } + return true +} + +func (n *NetworkOrchestrator) IsIPv6Unique(ctx context.Context, network *schema.Network, ip string) bool { + _, cidr, err := net.ParseCIDR(network.AddressRange6) + if err != nil { + return true + } + cidr.IP = net.ParseIP(ip) + node := &schema.Node{NetworkID: network.ID, Address6: cidr.String()} + if err := node.GetByNetworkAndAddress6(ctx); err == nil { + return false + } + + extClients, err := logic.GetNetworkExtClients(network.Name) + if err != nil { + return true + } + for _, ec := range extClients { + if ec.Address6 == ip { + return false + } + } + return true +} diff --git a/orchestrator/node.go b/orchestrator/node.go new file mode 100644 index 000000000..7b9e7f0c7 --- /dev/null +++ b/orchestrator/node.go @@ -0,0 +1,208 @@ +package orchestrator + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + "github.com/google/uuid" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/mq" + "github.com/gravitl/netmaker/orchestrator/extensions" + "github.com/gravitl/netmaker/schema" + "gorm.io/datatypes" + "gorm.io/gorm" +) + +type NodeOrchestrator struct { + nodeExt extensions.NodeExtensions +} + +type NodeOrchestratorOptions struct { + useKey bool + key *models.EnrollmentKey + skipPublishPeerUpdate bool +} + +type NodeOrchestratorOption func(options *NodeOrchestratorOptions) *NodeOrchestratorOptions + +func UseKey(key *models.EnrollmentKey) NodeOrchestratorOption { + return func(o *NodeOrchestratorOptions) *NodeOrchestratorOptions { + o.useKey = true + o.key = key + return o + } +} + +func SkipPublishPeerUpdate() NodeOrchestratorOption { + return func(o *NodeOrchestratorOptions) *NodeOrchestratorOptions { + o.skipPublishPeerUpdate = true + return o + } +} + +func (n *NodeOrchestrator) CreateNode(ctx context.Context, host *schema.Host, network *schema.Network, options ...NodeOrchestratorOption) (*schema.Node, error) { + var ops NodeOrchestratorOptions + for _, option := range options { + option(&ops) + } + + node := &schema.Node{ + ID: uuid.NewString(), + HostID: host.ID.String(), + Host: host, + NetworkID: network.ID, + Network: network, + Connected: true, + LastCheckIn: time.Now(), + ExpirationDateTime: time.Now().AddDate(100, 1, 0), + AutoRelayedPeers: datatypes.NewJSONType(make(map[string]string)), + Tags: make(datatypes.JSONMap), + } + + if ops.useKey { + n.nodeExt.ConfigureAutoAssignGateway(node, ops.key) + + for _, tag := range ops.key.Groups { + node.Tags[string(tag)] = true + } + } + + // TODO: Ensure concurrency safe ip allocation. + if network.AddressRange != "" { + ip, err := GetRepository().NetworkOrchestrator().AllocateNodeIP(ctx, network) + if err != nil { + return nil, err + } + _, cidr, err := net.ParseCIDR(network.AddressRange) + if err != nil { + return nil, err + } + cidr.IP = ip + node.Address = cidr.String() + } + + if network.AddressRange6 != "" { + ip, err := GetRepository().NetworkOrchestrator().AllocateNodeIPv6(ctx, network) + if err != nil { + return nil, err + } + _, cidr, err := net.ParseCIDR(network.AddressRange6) + if err != nil { + return nil, err + } + cidr.IP = ip + node.Address6 = cidr.String() + } + + err := node.Create(ctx) + if err != nil { + return nil, err + } + + host.Nodes = append(host.Nodes, node.ID) + err = host.Upsert(ctx) + if err != nil { + return nil, err + } + + go logic.CheckZombies(node) + + go func() { + err := logic.UpdateMetrics(node.ID, &models.Metrics{Connectivity: make(map[string]models.Metric)}) + if err != nil { + logger.Log(1, fmt.Sprintf("failed to initialize metrics for node (%s): %v", node.ID, err)) + } + }() + + if host.IsDefault { + err = n.CreateGateway(ctx, node) + if err != nil { + return nil, err + } + } else if ops.useKey && ops.key.Relay != uuid.Nil { + gateway := &schema.Node{ + ID: ops.key.Relay.String(), + } + err = gateway.Get(ctx) + if err == nil { + relayID := ops.key.Relay.String() + node.RelayingNodeID = &relayID + err = node.UpdateRelayingNode(ctx) + if err != nil { + return nil, err + } + + gateway.RelayedClients[node.ID] = struct{}{} + err = gateway.UpdateRelayedClients(ctx) + if err != nil { + return nil, err + } + } else if !errors.Is(err, gorm.ErrRecordNotFound) { + return nil, err + } + } + + action := models.JoinHostToNetwork + if len(host.Nodes) == 1 { + action = models.RequestPull + } + + // TODO: figure out mq placement. + go func() { + if err := mq.HostUpdate(&models.HostUpdate{ + Action: action, + Host: *host, + Node: *logic.ConvertSchemaNodeToModelsNode(node), + }); err != nil { + logger.Log(1, "failed to send host update for node", node.ID, err.Error()) + } + }() + + if !ops.skipPublishPeerUpdate { + go func() { + if err := mq.PublishPeerUpdate(false); err != nil { + logger.Log(1, "failed to publish peer update for node", node.ID, err.Error()) + } + }() + } + + return node, nil +} + +func (n *NodeOrchestrator) CreateGateway(ctx context.Context, node *schema.Node) error { + if node.Host.OS != "linux" { + return errors.New("gateway can only be created on linux based node") + } + + if node.IsGateway { + return errors.New("node is already a gateway") + } + + if node.RelayingNodeID != nil { + return errors.New("gateway cannot be created on a relayed node") + } + + node.IsGateway = true + node.IsInternetGateway = false + + n.nodeExt.ConfigureAutoRelay(node) + + err := node.Update(ctx) + if err != nil { + return err + } + + node.Tags[fmt.Sprintf("%s.%s", node.NetworkID, models.GwTagName)] = struct{}{} + err = node.UpdateTags(ctx) + if err != nil { + return err + } + + node.Network.NodesUpdatedAt = time.Now() + return node.Network.UpdateNodesUpdatedAt(ctx) +} diff --git a/orchestrator/repository.go b/orchestrator/repository.go new file mode 100644 index 000000000..fc9e76fe1 --- /dev/null +++ b/orchestrator/repository.go @@ -0,0 +1,38 @@ +package orchestrator + +import ( + "sync" + + "github.com/gravitl/netmaker/orchestrator/extensions" +) + +var repo *Repository +var once sync.Once + +type Repository struct { + network *NetworkOrchestrator + node *NodeOrchestrator +} + +func InitializeRepository(extFactory *extensions.Factory) { + once.Do(func() { + repo = &Repository{ + network: &NetworkOrchestrator{}, + node: &NodeOrchestrator{ + nodeExt: extFactory.NodeExtensions(), + }, + } + }) +} + +func GetRepository() *Repository { + return repo +} + +func (r *Repository) NetworkOrchestrator() *NetworkOrchestrator { + return r.network +} + +func (r *Repository) NodeOrchestrator() *NodeOrchestrator { + return r.node +} diff --git a/pro/auth/sync.go b/pro/auth/sync.go index b710a93cf..afd8e42f7 100644 --- a/pro/auth/sync.go +++ b/pro/auth/sync.go @@ -466,9 +466,6 @@ func deleteAndCleanUpUser(user *schema.User) error { go logic.DeleteUserInvite(user.Username) go mq.PublishPeerUpdate(false) - if servercfg.IsDNSMode() { - go logic.SetDNS() - } }() return nil diff --git a/pro/controllers/failover.go b/pro/controllers/failover.go deleted file mode 100644 index 83fe3828a..000000000 --- a/pro/controllers/failover.go +++ /dev/null @@ -1,443 +0,0 @@ -package controllers - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net/http" - - "github.com/google/uuid" - "github.com/gorilla/mux" - controller "github.com/gravitl/netmaker/controllers" - "github.com/gravitl/netmaker/db" - "github.com/gravitl/netmaker/logger" - "github.com/gravitl/netmaker/logic" - "github.com/gravitl/netmaker/models" - "github.com/gravitl/netmaker/mq" - proLogic "github.com/gravitl/netmaker/pro/logic" - "github.com/gravitl/netmaker/schema" - "golang.org/x/exp/slog" -) - -// FailOverHandlers - handlers for FailOver -func FailOverHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/node/{nodeid}/failover", controller.AuthorizeHost(http.HandlerFunc(getfailOver))). - Methods(http.MethodGet) - r.HandleFunc("/api/v1/node/{nodeid}/failover", logic.SecurityCheck(true, http.HandlerFunc(createfailOver))). - Methods(http.MethodPost) - r.HandleFunc("/api/v1/node/{nodeid}/failover", logic.SecurityCheck(true, http.HandlerFunc(deletefailOver))). - Methods(http.MethodDelete) - r.HandleFunc("/api/v1/node/{network}/failover/reset", logic.SecurityCheck(true, http.HandlerFunc(resetFailOver))). - Methods(http.MethodPost) - r.HandleFunc("/api/v1/node/{nodeid}/failover_me", controller.AuthorizeHost(http.HandlerFunc(failOverME))). - Methods(http.MethodPost) - r.HandleFunc("/api/v1/node/{nodeid}/failover_check", controller.AuthorizeHost(http.HandlerFunc(checkfailOverCtx))). - Methods(http.MethodGet) -} - -func getfailOver(w http.ResponseWriter, r *http.Request) { - var params = mux.Vars(r) - nodeid := params["nodeid"] - // confirm host exists - node, err := logic.GetNodeByID(nodeid) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - return - } - - failOverNode, exists := proLogic.FailOverExists(node.Network) - if !exists { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError(errors.New("failover node not found"), "notfound"), - ) - return - } - w.Header().Set("Content-Type", "application/json") - logic.ReturnSuccessResponseWithJson(w, r, failOverNode, "get failover node successfully") -} - -func createfailOver(w http.ResponseWriter, r *http.Request) { - var params = mux.Vars(r) - nodeid := params["nodeid"] - // confirm host exists - node, err := logic.GetNodeByID(nodeid) - if err != nil { - slog.Error("failed to get node:", "error", err.Error()) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - return - } - err = proLogic.CreateFailOver(node) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - return - } - go mq.PublishPeerUpdate(false) - w.Header().Set("Content-Type", "application/json") - logic.ReturnSuccessResponseWithJson(w, r, node, "created failover successfully") -} - -func resetFailOver(w http.ResponseWriter, r *http.Request) { - var params = mux.Vars(r) - net := params["network"] - nodes, err := logic.GetNetworkNodes(net) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - return - } - for _, node := range nodes { - if node.FailedOverBy != uuid.Nil { - node.FailedOverBy = uuid.Nil - if node.Mutex != nil { - node.Mutex.Lock() - } - node.FailOverPeers = make(map[string]struct{}) - if node.Mutex != nil { - node.Mutex.Unlock() - } - logic.UpsertNode(&node) - } - } - go mq.PublishPeerUpdate(false) - w.Header().Set("Content-Type", "application/json") - logic.ReturnSuccessResponse(w, r, "failover has been reset successfully") -} - -func deletefailOver(w http.ResponseWriter, r *http.Request) { - var params = mux.Vars(r) - nodeid := params["nodeid"] - // confirm host exists - node, err := logic.GetNodeByID(nodeid) - if err != nil { - slog.Error("failed to get node:", "error", err.Error()) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - return - } - node.IsFailOver = false - // Reset FailOvered Peers - err = logic.UpsertNode(&node) - if err != nil { - slog.Error("failed to upsert node", "node", node.ID.String(), "error", err) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) - return - } - proLogic.RemoveFailOverFromCache(node.Network) - go func() { - proLogic.ResetFailOver(&node) - mq.PublishPeerUpdate(false) - }() - w.Header().Set("Content-Type", "application/json") - logic.ReturnSuccessResponseWithJson(w, r, node, "deleted failover successfully") -} - -func failOverME(w http.ResponseWriter, r *http.Request) { - var params = mux.Vars(r) - nodeid := params["nodeid"] - // confirm host exists - node, err := logic.GetNodeByID(nodeid) - if err != nil { - logger.Log(0, r.Header.Get("user"), "failed to get node:", err.Error()) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - return - } - host := &schema.Host{ - ID: node.HostID, - } - err = host.Get(r.Context()) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - return - } - - failOverNode, exists := proLogic.FailOverExists(node.Network) - if !exists { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError( - fmt.Errorf("req-from: %s, failover node doesn't exist in the network", host.Name), - "badrequest", - ), - ) - return - } - var failOverReq models.FailOverMeReq - err = json.NewDecoder(r.Body).Decode(&failOverReq) - if err != nil { - logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - return - } - var sendPeerUpdate bool - peerNode, err := logic.GetNodeByID(failOverReq.NodeID) - if err != nil { - slog.Error("peer not found: ", "nodeid", failOverReq.NodeID, "error", err) - logic.ReturnErrorResponse( - w, - r, - logic.FormatError(errors.New("peer not found"), "badrequest"), - ) - return - } - eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO())) - acls, _ := logic.ListAclsByNetwork(schema.NetworkID(node.Network)) - logic.GetNodeEgressInfo(&node, eli, acls) - logic.GetNodeEgressInfo(&peerNode, eli, acls) - logic.GetNodeEgressInfo(&failOverNode, eli, acls) - if peerNode.IsFailOver { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError(errors.New("peer is acting as failover"), "badrequest"), - ) - return - } - if node.IsFailOver { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError(errors.New("node is acting as failover"), "badrequest"), - ) - return - } - if peerNode.IsFailOver { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError(errors.New("peer is acting as failover"), "badrequest"), - ) - return - } - if node.IsRelayed && node.RelayedBy == peerNode.ID.String() { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError(errors.New("node is relayed by peer node"), "badrequest"), - ) - return - } - if node.IsRelay && peerNode.RelayedBy == node.ID.String() { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError(errors.New("node acting as relay for the peer node"), "badrequest"), - ) - return - } - if (node.InternetGwID != "" && failOverNode.IsInternetGateway && node.InternetGwID != failOverNode.ID.String()) || - (peerNode.InternetGwID != "" && failOverNode.IsInternetGateway && peerNode.InternetGwID != failOverNode.ID.String()) { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError( - errors.New("node using a internet gw by the peer node"), - "badrequest", - ), - ) - return - } - if node.IsInternetGateway && peerNode.InternetGwID == node.ID.String() { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError( - errors.New("node acting as internet gw for the peer node"), - "badrequest", - ), - ) - return - } - if node.InternetGwID != "" && node.InternetGwID == peerNode.ID.String() { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError( - errors.New("node using a internet gw by the peer node"), - "badrequest", - ), - ) - return - } - err = proLogic.SetFailOverCtx(failOverNode, node, peerNode) - if err != nil { - slog.Debug("failed to create failover", "id", node.ID.String(), - "network", node.Network, "error", err) - logic.ReturnErrorResponse( - w, - r, - logic.FormatError(fmt.Errorf("failed to create failover: %v", err), "internal"), - ) - return - } - slog.Info( - "[auto-relay] created relay on node", - "node", - node.ID.String(), - "network", - node.Network, - ) - sendPeerUpdate = true - - if sendPeerUpdate { - go mq.PublishPeerUpdate(false) - } - - w.Header().Set("Content-Type", "application/json") - logic.ReturnSuccessResponse(w, r, "relayed successfully") -} - -func checkfailOverCtx(w http.ResponseWriter, r *http.Request) { - var params = mux.Vars(r) - nodeid := params["nodeid"] - // confirm host exists - node, err := logic.GetNodeByID(nodeid) - if err != nil { - logger.Log(0, r.Header.Get("user"), "failed to get node:", err.Error()) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - return - } - host := &schema.Host{ - ID: node.HostID, - } - err = host.Get(r.Context()) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - return - } - - failOverNode, exists := proLogic.FailOverExists(node.Network) - if !exists { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError( - fmt.Errorf("req-from: %s, failover node doesn't exist in the network", host.Name), - "badrequest", - ), - ) - return - } - var failOverReq models.FailOverMeReq - err = json.NewDecoder(r.Body).Decode(&failOverReq) - if err != nil { - logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) - return - } - peerNode, err := logic.GetNodeByID(failOverReq.NodeID) - if err != nil { - slog.Error("peer not found: ", "nodeid", failOverReq.NodeID, "error", err) - logic.ReturnErrorResponse( - w, - r, - logic.FormatError(errors.New("peer not found"), "badrequest"), - ) - return - } - eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO())) - acls, _ := logic.ListAclsByNetwork(schema.NetworkID(node.Network)) - logic.GetNodeEgressInfo(&node, eli, acls) - logic.GetNodeEgressInfo(&peerNode, eli, acls) - logic.GetNodeEgressInfo(&failOverNode, eli, acls) - if peerNode.IsFailOver { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError(errors.New("peer is acting as failover"), "badrequest"), - ) - return - } - if node.IsFailOver { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError(errors.New("node is acting as failover"), "badrequest"), - ) - return - } - if peerNode.IsFailOver { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError(errors.New("peer is acting as failover"), "badrequest"), - ) - return - } - if node.IsRelayed && node.RelayedBy == peerNode.ID.String() { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError(errors.New("node is relayed by peer node"), "badrequest"), - ) - return - } - if node.IsRelay && peerNode.RelayedBy == node.ID.String() { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError(errors.New("node acting as relay for the peer node"), "badrequest"), - ) - return - } - if (node.InternetGwID != "" && failOverNode.IsInternetGateway && node.InternetGwID != failOverNode.ID.String()) || - (peerNode.InternetGwID != "" && failOverNode.IsInternetGateway && peerNode.InternetGwID != failOverNode.ID.String()) { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError( - errors.New("node using a internet gw by the peer node"), - "badrequest", - ), - ) - return - } - if node.IsInternetGateway && peerNode.InternetGwID == node.ID.String() { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError( - errors.New("node acting as internet gw for the peer node"), - "badrequest", - ), - ) - return - } - if node.InternetGwID != "" && node.InternetGwID == peerNode.ID.String() { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError( - errors.New("node using a internet gw by the peer node"), - "badrequest", - ), - ) - return - } - if ok := logic.IsPeerAllowed(node, peerNode, true); !ok { - logic.ReturnErrorResponse( - w, - r, - logic.FormatError( - errors.New("peers are not allowed to communicate"), - "badrequest", - ), - ) - return - } - - err = proLogic.CheckFailOverCtx(failOverNode, node, peerNode) - if err != nil { - slog.Error("failover ctx cannot be set ", "error", err) - logic.ReturnErrorResponse( - w, - r, - logic.FormatError(fmt.Errorf("failover ctx cannot be set: %v", err), "internal"), - ) - return - } - - w.Header().Set("Content-Type", "application/json") - logic.ReturnSuccessResponse(w, r, "failover can be set") -} diff --git a/pro/controllers/posture_check.go b/pro/controllers/posture_check.go index dd9b21985..bba0c3330 100644 --- a/pro/controllers/posture_check.go +++ b/pro/controllers/posture_check.go @@ -4,18 +4,21 @@ import ( "context" "encoding/json" "errors" + "fmt" "net/http" "time" "github.com/google/uuid" "github.com/gorilla/mux" "github.com/gravitl/netmaker/db" + dbtypes "github.com/gravitl/netmaker/db/types" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" proLogic "github.com/gravitl/netmaker/pro/logic" "github.com/gravitl/netmaker/schema" + "gorm.io/gorm" ) func PostureCheckHandlers(r *mux.Router) { @@ -320,15 +323,15 @@ func deletePostureCheck(w http.ResponseWriter, r *http.Request) { // @Failure 500 {object} models.ErrorResponse func listPostureCheckViolatedNodes(w http.ResponseWriter, r *http.Request) { - network := r.URL.Query().Get("network") - if network == "" { + networkName := r.URL.Query().Get("network") + if networkName == "" { logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network is required"), logic.BadReq)) return } listViolatedusers := r.URL.Query().Get("users") == "true" violatedNodes := []models.Node{} if listViolatedusers { - extclients, err := logic.GetNetworkExtClients(network) + extclients, err := logic.GetNetworkExtClients(networkName) if err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq)) return @@ -341,16 +344,37 @@ func listPostureCheckViolatedNodes(w http.ResponseWriter, r *http.Request) { } } } else { - nodes, err := logic.GetNetworkNodes(network) + network := &schema.Network{ + Name: networkName, + } + err := network.Get(r.Context()) if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq)) + errType := logic.Internal + if errors.Is(err, gorm.ErrRecordNotFound) { + errType = logic.BadReq + } + + err = fmt.Errorf("failed to list posture check violated nodes in network %s: error fetching network: %v", networkName, err) + logger.Log(0, r.Header.Get("user"), err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, errType)) return } - for _, node := range nodes { - if len(node.PostureChecksViolations) > 0 { - violatedNodes = append(violatedNodes, node) - } + _nodes, err := (&schema.Node{}).ListAll( + r.Context(), + dbtypes.WithPreloads("Host"), + dbtypes.WithFilter("network_id", network.ID), + dbtypes.WithNotFilter("posture_check_severity", schema.SeverityUnknown), + ) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) + return + } + + for _, _node := range _nodes { + _node.Network = network + node := logic.ConvertSchemaNodeToModelsNode(&_node) + violatedNodes = append(violatedNodes, *node) } } apiNodes := logic.GetAllNodesAPI(violatedNodes) diff --git a/pro/controllers/users.go b/pro/controllers/users.go index c0f149436..b758f1d12 100644 --- a/pro/controllers/users.go +++ b/pro/controllers/users.go @@ -9,13 +9,16 @@ import ( "net/url" "strconv" "strings" + "time" "github.com/gorilla/mux" + "github.com/gravitl/netmaker/db" dbtypes "github.com/gravitl/netmaker/db/types" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" + "github.com/gravitl/netmaker/orchestrator" proAuth "github.com/gravitl/netmaker/pro/auth" "github.com/gravitl/netmaker/pro/email" "github.com/gravitl/netmaker/pro/idp" @@ -27,6 +30,7 @@ import ( "github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/utils" "golang.org/x/exp/slog" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "gorm.io/datatypes" "gorm.io/gorm" ) @@ -1334,9 +1338,6 @@ func removeUserFromRemoteAccessGW(w http.ResponseWriter, r *http.Request) { } } } - if servercfg.IsDNSMode() { - logic.SetDNS() - } }(user, remoteGwID) err = logic.UpsertUser(*user) @@ -1516,7 +1517,89 @@ func getRemoteAccessGatewayConf(w http.ResponseWriter, r *http.Request) { userConf.Tags = make(map[models.TagID]struct{}) // userConf.Tags[models.TagID(fmt.Sprintf("%s.%s", userConf.Network, // models.RemoteAccessTagName))] = struct{}{} - if err = logic.CreateExtClient(&userConf); err != nil { + if len(userConf.PublicKey) == 0 { + privateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + slog.Error( + "failed to create extclient", + "user", + r.Header.Get("user"), + "network", + node.Network, + "error", + err, + ) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + userConf.PrivateKey = privateKey.String() + userConf.PublicKey = privateKey.PublicKey().String() + } else if len(userConf.PrivateKey) == 0 && len(userConf.PublicKey) > 0 { + userConf.PrivateKey = "[ENTER PRIVATE KEY]" + } + if userConf.ExtraAllowedIPs == nil { + userConf.ExtraAllowedIPs = []string{} + } + + if userConf.Address == "" { + if network.AddressRange != "" { + newAddress, err := orchestrator.GetRepository().NetworkOrchestrator().AllocateExtclientIP(r.Context(), network) + if err != nil { + slog.Error( + "failed to create extclient", + "user", + r.Header.Get("user"), + "network", + node.Network, + "error", + err, + ) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + userConf.Address = newAddress.String() + } + } + + if userConf.Address6 == "" { + if network.AddressRange6 != "" { + addr6, err := orchestrator.GetRepository().NetworkOrchestrator().AllocateExtclientIPv6(db.WithContext(context.TODO()), network) + if err != nil { + slog.Error( + "failed to create extclient", + "user", + r.Header.Get("user"), + "network", + node.Network, + "error", + err, + ) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + userConf.Address6 = addr6.String() + } + } + + if userConf.ClientID == "" { + userConf.ClientID, err = logic.GenerateNodeName(userConf.Network) + if err != nil { + slog.Error( + "failed to create extclient", + "user", + r.Header.Get("user"), + "network", + node.Network, + "error", + err, + ) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + } + + userConf.LastModified = time.Now().Unix() + if err = logic.SaveExtClient(&userConf); err != nil { slog.Error( "failed to create extclient", "user", diff --git a/pro/initialize.go b/pro/initialize.go index b0ab56966..f959a5660 100644 --- a/pro/initialize.go +++ b/pro/initialize.go @@ -15,11 +15,13 @@ import ( "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" + "github.com/gravitl/netmaker/orchestrator" "github.com/gravitl/netmaker/pro/auth" proControllers "github.com/gravitl/netmaker/pro/controllers" "github.com/gravitl/netmaker/pro/email" "github.com/gravitl/netmaker/pro/license" proLogic "github.com/gravitl/netmaker/pro/logic" + "github.com/gravitl/netmaker/pro/orchestrator/extensions" "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/servercfg" "golang.org/x/exp/slog" @@ -28,6 +30,7 @@ import ( // InitPro - Initialize Pro Logic func InitPro() { servercfg.IsPro = true + orchestrator.InitializeRepository(extensions.NewProFactory()) models.SetLogo(retrieveProLogo()) controller.HttpMiddlewares = append( controller.HttpMiddlewares, @@ -39,7 +42,6 @@ func InitPro() { controller.HttpHandlers, proControllers.MetricHandlers, proControllers.UserHandlers, - proControllers.FailOverHandlers, proControllers.RacHandlers, proControllers.EventHandlers, proControllers.TagHandlers, @@ -106,7 +108,6 @@ func InitPro() { slog.Error("no OAuth provider found or not configured, continuing without OAuth") } proLogic.LoadNodeMetricsToCache() - proLogic.InitFailOverCache() if servercfg.CacheEnabled() { proLogic.InitAutoRelayCache() } @@ -142,11 +143,6 @@ func InitPro() { go proLogic.EventWatcher() logic.GetMetricsMonitor().Start() }) - logic.ResetFailOver = proLogic.ResetFailOver - logic.ResetFailedOverPeer = proLogic.ResetFailedOverPeer - logic.FailOverExists = proLogic.FailOverExists - logic.CreateFailOver = proLogic.CreateFailOver - logic.GetFailOverPeerIps = proLogic.GetFailOverPeerIps logic.ResetAutoRelay = proLogic.ResetAutoRelay logic.ResetAutoRelayedPeer = proLogic.ResetAutoRelayedPeer diff --git a/pro/logic/failover.go b/pro/logic/failover.go deleted file mode 100644 index 5c995d526..000000000 --- a/pro/logic/failover.go +++ /dev/null @@ -1,274 +0,0 @@ -package logic - -import ( - "context" - "errors" - "net" - "sync" - - "github.com/google/uuid" - "github.com/gravitl/netmaker/db" - "github.com/gravitl/netmaker/logger" - "github.com/gravitl/netmaker/logic" - "github.com/gravitl/netmaker/models" - "github.com/gravitl/netmaker/schema" -) - -var failOverCtxMutex = &sync.RWMutex{} -var failOverCacheMutex = &sync.RWMutex{} -var failOverCache = make(map[schema.NetworkID]string) - -func InitFailOverCache() { - failOverCacheMutex.Lock() - defer failOverCacheMutex.Unlock() - networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO())) - if err != nil { - return - } - allNodes, err := logic.GetAllNodes() - if err != nil { - return - } - - for _, network := range networks { - networkNodes := logic.GetNetworkNodesMemory(allNodes, network.Name) - for _, node := range networkNodes { - if node.IsFailOver { - failOverCache[schema.NetworkID(network.Name)] = node.ID.String() - break - } - } - } -} - -func CheckFailOverCtx(failOverNode, victimNode, peerNode models.Node) error { - failOverCtxMutex.RLock() - defer failOverCtxMutex.RUnlock() - if peerNode.FailOverPeers == nil { - return nil - } - if victimNode.FailOverPeers == nil { - return nil - } - if peerNode.Mutex != nil { - peerNode.Mutex.Lock() - } - _, peerHasFailovered := peerNode.FailOverPeers[victimNode.ID.String()] - if peerNode.Mutex != nil { - peerNode.Mutex.Unlock() - } - if victimNode.Mutex != nil { - victimNode.Mutex.Lock() - } - _, victimHasFailovered := victimNode.FailOverPeers[peerNode.ID.String()] - if victimNode.Mutex != nil { - victimNode.Mutex.Unlock() - } - if peerHasFailovered && victimHasFailovered && - victimNode.FailedOverBy == failOverNode.ID && peerNode.FailedOverBy == failOverNode.ID { - return errors.New("failover ctx is already set") - } - return nil -} -func SetFailOverCtx(failOverNode, victimNode, peerNode models.Node) error { - failOverCtxMutex.Lock() - defer failOverCtxMutex.Unlock() - if peerNode.FailOverPeers == nil { - peerNode.FailOverPeers = make(map[string]struct{}) - } - if victimNode.FailOverPeers == nil { - victimNode.FailOverPeers = make(map[string]struct{}) - } - if peerNode.Mutex != nil { - peerNode.Mutex.Lock() - } - _, peerHasFailovered := peerNode.FailOverPeers[victimNode.ID.String()] - if peerNode.Mutex != nil { - peerNode.Mutex.Unlock() - } - if victimNode.Mutex != nil { - victimNode.Mutex.Lock() - } - _, victimHasFailovered := victimNode.FailOverPeers[peerNode.ID.String()] - if victimNode.Mutex != nil { - victimNode.Mutex.Unlock() - } - if peerHasFailovered && victimHasFailovered && - victimNode.FailedOverBy == failOverNode.ID && peerNode.FailedOverBy == failOverNode.ID { - return errors.New("failover ctx is already set") - } - if peerNode.Mutex != nil { - peerNode.Mutex.Lock() - } - peerNode.FailOverPeers[victimNode.ID.String()] = struct{}{} - if peerNode.Mutex != nil { - peerNode.Mutex.Unlock() - } - if victimNode.Mutex != nil { - victimNode.Mutex.Lock() - } - victimNode.FailOverPeers[peerNode.ID.String()] = struct{}{} - if victimNode.Mutex != nil { - victimNode.Mutex.Unlock() - } - victimNode.FailedOverBy = failOverNode.ID - peerNode.FailedOverBy = failOverNode.ID - if err := logic.UpsertNode(&victimNode); err != nil { - return err - } - if err := logic.UpsertNode(&peerNode); err != nil { - return err - } - return nil -} - -// GetFailOverNode - gets the host acting as failOver -func GetFailOverNode(network string, allNodes []models.Node) (models.Node, error) { - nodes := logic.GetNetworkNodesMemory(allNodes, network) - for _, node := range nodes { - if node.IsFailOver { - return node, nil - } - } - return models.Node{}, errors.New("auto relay not found") -} - -func RemoveFailOverFromCache(network string) { - failOverCacheMutex.Lock() - defer failOverCacheMutex.Unlock() - delete(failOverCache, schema.NetworkID(network)) -} - -func SetFailOverInCache(node models.Node) { - failOverCacheMutex.Lock() - defer failOverCacheMutex.Unlock() - failOverCache[schema.NetworkID(node.Network)] = node.ID.String() -} - -// FailOverExists - checks if failOver exists already in the network -func FailOverExists(network string) (failOverNode models.Node, exists bool) { - failOverCacheMutex.RLock() - defer failOverCacheMutex.RUnlock() - if nodeID, ok := failOverCache[schema.NetworkID(network)]; ok { - failOverNode, err := logic.GetNodeByID(nodeID) - if err == nil { - return failOverNode, true - } - } - return -} - -// ResetFailedOverPeer - removes failed over node from network peers -func ResetFailedOverPeer(failedOveredNode *models.Node) error { - if failedOveredNode.FailedOverBy == uuid.Nil && len(failedOveredNode.FailOverPeers) == 0 { - return nil - } - nodes, err := logic.GetNetworkNodes(failedOveredNode.Network) - if err != nil { - return err - } - failedOveredNode.FailedOverBy = uuid.Nil - failedOveredNode.FailOverPeers = make(map[string]struct{}) - err = logic.UpsertNode(failedOveredNode) - if err != nil { - return err - } - nodeIDStr := failedOveredNode.ID.String() - for _, node := range nodes { - if node.FailOverPeers == nil || node.ID == failedOveredNode.ID { - continue - } - if _, exists := node.FailOverPeers[nodeIDStr]; !exists { - continue - } - delete(node.FailOverPeers, nodeIDStr) - logic.UpsertNode(&node) - } - return nil -} - -// ResetFailOver - reset failovered peers -func ResetFailOver(failOverNode *models.Node) error { - // Unset FailedOverPeers - nodes, err := logic.GetNetworkNodes(failOverNode.Network) - if err != nil { - return err - } - for _, node := range nodes { - if node.FailedOverBy == failOverNode.ID { - node.FailedOverBy = uuid.Nil - node.FailOverPeers = make(map[string]struct{}) - logic.UpsertNode(&node) - } - } - return nil -} - -// GetFailOverPeerIps - adds the failedOvered peerIps by the peer -func GetFailOverPeerIps(peer, node *models.Node) []net.IPNet { - allowedips := []net.IPNet{} - eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO())) - acls, _ := logic.ListAclsByNetwork(schema.NetworkID(node.Network)) - for failOverpeerID := range node.FailOverPeers { - failOverpeer, err := logic.GetNodeByID(failOverpeerID) - if err == nil && failOverpeer.FailedOverBy == peer.ID { - logic.GetNodeEgressInfo(&failOverpeer, eli, acls) - if failOverpeer.Address.IP != nil { - allowed := net.IPNet{ - IP: failOverpeer.Address.IP, - Mask: net.CIDRMask(32, 32), - } - allowedips = append(allowedips, allowed) - } - if failOverpeer.Address6.IP != nil { - allowed := net.IPNet{ - IP: failOverpeer.Address6.IP, - Mask: net.CIDRMask(128, 128), - } - allowedips = append(allowedips, allowed) - } - if failOverpeer.EgressDetails.IsEgressGateway { - allowedips = append(allowedips, logic.GetEgressIPs(&failOverpeer)...) - } - if failOverpeer.IsRelay { - for _, id := range failOverpeer.RelayedNodes { - rNode, _ := logic.GetNodeByID(id) - logic.GetNodeEgressInfo(&rNode, eli, acls) - if rNode.Address.IP != nil { - allowed := net.IPNet{ - IP: rNode.Address.IP, - Mask: net.CIDRMask(32, 32), - } - allowedips = append(allowedips, allowed) - } - if rNode.Address6.IP != nil { - allowed := net.IPNet{ - IP: rNode.Address6.IP, - Mask: net.CIDRMask(128, 128), - } - allowedips = append(allowedips, allowed) - } - if rNode.EgressDetails.IsEgressGateway { - allowedips = append(allowedips, logic.GetEgressIPs(&rNode)...) - } - } - } - // handle ingress gateway peers - if failOverpeer.IsIngressGateway { - extPeers, _, _, err := logic.GetExtPeers(&failOverpeer, node, make(map[string]models.PeerIdentity)) - if err != nil { - logger.Log(2, "could not retrieve ext peers for ", peer.ID.String(), err.Error()) - } - for _, extPeer := range extPeers { - allowedips = append(allowedips, extPeer.AllowedIPs...) - } - } - } - } - return allowedips -} - -func CreateFailOver(node models.Node) error { - - return nil -} diff --git a/pro/logic/status.go b/pro/logic/status.go index 7b7c67209..63ee8d80a 100644 --- a/pro/logic/status.go +++ b/pro/logic/status.go @@ -14,68 +14,68 @@ func getNodeStatusOld(node *models.Node) { // On CE check only last check-in time if node.IsStatic { if !node.StaticNode.Enabled { - node.Status = models.OfflineSt + node.Status = schema.OfflineSt return } - node.Status = models.OnlineSt + node.Status = schema.OnlineSt return } if !node.Connected { - node.Status = models.Disconnected + node.Status = schema.Disconnected return } if time.Since(node.LastCheckIn) > time.Minute*10 { - node.Status = models.OfflineSt + node.Status = schema.OfflineSt return } - node.Status = models.OnlineSt + node.Status = schema.OnlineSt } func GetNodeStatus(node *models.Node, defaultEnabledPolicy bool) { if node.IsStatic { if !node.StaticNode.Enabled { - node.Status = models.OfflineSt + node.Status = schema.OfflineSt return } ingNode, err := logic.GetNodeByID(node.StaticNode.IngressGatewayID) if err != nil { - node.Status = models.OfflineSt + node.Status = schema.OfflineSt return } if !defaultEnabledPolicy { allowed, _ := logic.IsNodeAllowedToCommunicate(*node, ingNode, false) if !allowed { - node.Status = models.OnlineSt + node.Status = schema.OnlineSt return } } // check extclient connection from metrics ingressMetrics, err := GetMetrics(node.StaticNode.IngressGatewayID) if err != nil || ingressMetrics == nil || ingressMetrics.Connectivity == nil { - node.Status = models.UnKnown + node.Status = schema.UnKnown return } if metric, ok := ingressMetrics.Connectivity[node.StaticNode.ClientID]; ok { if metric.Connected { - node.Status = models.OnlineSt + node.Status = schema.OnlineSt return } else { - node.Status = models.OfflineSt + node.Status = schema.OfflineSt return } } - node.Status = models.UnKnown + node.Status = schema.UnKnown return } if !node.Connected { - node.Status = models.Disconnected + node.Status = schema.Disconnected return } if time.Since(node.LastCheckIn) > models.LastCheckInThreshold { - node.Status = models.OfflineSt + node.Status = schema.OfflineSt return } host := &schema.Host{ @@ -83,12 +83,12 @@ func GetNodeStatus(node *models.Node, defaultEnabledPolicy bool) { } err := host.Get(db.WithContext(context.TODO())) if err != nil { - node.Status = models.UnKnown + node.Status = schema.UnKnown return } vlt, err := logic.VersionLessThan(host.Version, "v0.30.0") if err != nil { - node.Status = models.UnKnown + node.Status = schema.UnKnown return } if vlt { @@ -101,11 +101,11 @@ func GetNodeStatus(node *models.Node, defaultEnabledPolicy bool) { } if metrics == nil || metrics.Connectivity == nil || len(metrics.Connectivity) == 0 { if time.Since(node.LastCheckIn) < models.LastCheckInThreshold { - node.Status = models.OnlineSt + node.Status = schema.OnlineSt return } if node.LastCheckIn.IsZero() { - node.Status = models.OfflineSt + node.Status = schema.OfflineSt return } } @@ -156,7 +156,7 @@ func checkPeerStatus(node *models.Node, defaultAclPolicy bool) { } if metrics == nil || metrics.Connectivity == nil { if time.Since(node.LastCheckIn) < models.LastCheckInThreshold { - node.Status = models.OnlineSt + node.Status = schema.OnlineSt return } } @@ -179,21 +179,21 @@ func checkPeerStatus(node *models.Node, defaultAclPolicy bool) { if metric.Connected { continue } - if peer.Status == models.ErrorSt { + if peer.Status == schema.ErrorSt { continue } peerNotConnectedCnt++ } if peerNotConnectedCnt == 0 { - node.Status = models.OnlineSt + node.Status = schema.OnlineSt return } if len(metrics.Connectivity) > 0 && peerNotConnectedCnt == len(metrics.Connectivity) { - node.Status = models.ErrorSt + node.Status = schema.ErrorSt return } - node.Status = models.WarningSt + node.Status = schema.WarningSt } func checkPeerConnectivity(node *models.Node, metrics *models.Metrics, defaultAclPolicy bool) { @@ -219,7 +219,7 @@ func checkPeerConnectivity(node *models.Node, metrics *models.Metrics, defaultAc } // check if peer is in error state checkPeerStatus(&peer, defaultAclPolicy) - if peer.Status == models.ErrorSt || peer.Status == models.WarningSt { + if peer.Status == schema.ErrorSt || peer.Status == schema.WarningSt { continue } peerNotConnectedCnt++ @@ -227,15 +227,15 @@ func checkPeerConnectivity(node *models.Node, metrics *models.Metrics, defaultAc } if peerNotConnectedCnt > len(metrics.Connectivity)/2 { - node.Status = models.WarningSt + node.Status = schema.WarningSt return } if len(metrics.Connectivity) > 0 && peerNotConnectedCnt == len(metrics.Connectivity) { - node.Status = models.ErrorSt + node.Status = schema.ErrorSt return } - node.Status = models.OnlineSt + node.Status = schema.OnlineSt } diff --git a/pro/logic/user_mgmt.go b/pro/logic/user_mgmt.go index 64d03b4f2..0dd21b1c7 100644 --- a/pro/logic/user_mgmt.go +++ b/pro/logic/user_mgmt.go @@ -15,7 +15,6 @@ import ( "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" - "github.com/gravitl/netmaker/servercfg" "golang.org/x/exp/slog" ) @@ -919,9 +918,6 @@ func UpdatesUserGwAccessOnRoleUpdates(currNetworkAccess, } } - if servercfg.IsDNSMode() { - logic.SetDNS() - } } func UpdatesUserGwAccessOnGrpUpdates(groupID schema.UserGroupID, oldNetworkRoles, newNetworkRoles map[schema.NetworkID]map[schema.UserRoleID]struct{}) { @@ -986,11 +982,6 @@ func UpdatesUserGwAccessOnGrpUpdates(groupID schema.UserGroupID, oldNetworkRoles } } } - - if servercfg.IsDNSMode() { - logic.SetDNS() - } - } func UpdateUserGwAccess(currentUser, changeUser *schema.User) { @@ -1041,10 +1032,6 @@ func UpdateUserGwAccess(currentUser, changeUser *schema.User) { } } - if servercfg.IsDNSMode() { - logic.SetDNS() - } - } func EnsureDefaultUserGroupNetworkPolicies(old, new *schema.UserGroup) error { diff --git a/pro/orchestrator/extensions/factory.go b/pro/orchestrator/extensions/factory.go new file mode 100644 index 000000000..a1855206c --- /dev/null +++ b/pro/orchestrator/extensions/factory.go @@ -0,0 +1,7 @@ +package extensions + +import "github.com/gravitl/netmaker/orchestrator/extensions" + +func NewProFactory() *extensions.Factory { + return extensions.NewFactory(&ProNodeExtensions{}) +} diff --git a/pro/orchestrator/extensions/node.go b/pro/orchestrator/extensions/node.go new file mode 100644 index 000000000..454b1c5e2 --- /dev/null +++ b/pro/orchestrator/extensions/node.go @@ -0,0 +1,16 @@ +package extensions + +import ( + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" +) + +type ProNodeExtensions struct{} + +func (p *ProNodeExtensions) ConfigureAutoRelay(node *schema.Node) { + node.IsAutoRelay = true +} + +func (p *ProNodeExtensions) ConfigureAutoAssignGateway(node *schema.Node, key *models.EnrollmentKey) { + node.AutoAssignGateway = key.AutoAssignGateway +} diff --git a/schema/hosts.go b/schema/hosts.go index 65936362a..6f2d16f20 100644 --- a/schema/hosts.go +++ b/schema/hosts.go @@ -16,6 +16,15 @@ import ( "gorm.io/datatypes" ) +const ( + // FIREWALL_IPTABLES - indicates that iptables is the firewall in use + FIREWALL_IPTABLES = "iptables" + // FIREWALL_NFTABLES - indicates nftables is in use (Linux only) + FIREWALL_NFTABLES = "nftables" + // FIREWALL_NONE - indicates that no supported firewall in use + FIREWALL_NONE = "none" +) + // Iface struct for local interfaces of a node type Iface struct { Name string `json:"name"` diff --git a/schema/models.go b/schema/models.go index d6e6ecf98..76866f043 100644 --- a/schema/models.go +++ b/schema/models.go @@ -19,5 +19,7 @@ func ListModels() []interface{} { &Host{}, &PendingUser{}, &UserInvite{}, + &Node{}, + &PostureCheckViolation{}, } } diff --git a/schema/nodes.go b/schema/nodes.go new file mode 100644 index 000000000..26210b018 --- /dev/null +++ b/schema/nodes.go @@ -0,0 +1,225 @@ +package schema + +import ( + "context" + "time" + + "github.com/gravitl/netmaker/db" + dbtypes "github.com/gravitl/netmaker/db/types" + "gorm.io/datatypes" +) + +const nodesTable = "nodes_v1" + +const ( + // NODE_DELETE - delete node action + NODE_DELETE = "delete" + // NODE_IS_PENDING - node pending status + NODE_IS_PENDING = "pending" + // NODE_NOOP - node no op action + NODE_NOOP = "noop" + // NODE_FORCE_UPDATE - indicates a node should pull all changes + NODE_FORCE_UPDATE = "force" +) + +type NodeStatus string + +const ( + OnlineSt NodeStatus = "online" + OfflineSt NodeStatus = "offline" + WarningSt NodeStatus = "warning" + ErrorSt NodeStatus = "error" + UnKnown NodeStatus = "unknown" + Disconnected NodeStatus = "disconnected" +) + +// TODO: check network and host delete cascade issues. +// TODO: Add gateways list API. +// TODO: Add gateway configs list API. + +type Node struct { + ID string `gorm:"primaryKey" json:"id"` + HostID string `gorm:"not null;index" json:"host_id"` + Host *Host `gorm:"foreignKey:HostID;constraint:OnDelete:CASCADE" json:"host,omitempty"` + NetworkID string `gorm:"not null;index" json:"network_id"` + Network *Network `gorm:"foreignKey:NetworkID;constraint:OnDelete:CASCADE" json:"network,omitempty"` + Address string `json:"address"` + Address6 string `json:"address6"` + Connected bool `json:"connected"` + Action string `json:"action"` + Status NodeStatus `json:"status"` + PendingDelete bool `json:"pending_delete"` + AutoAssignGateway bool `json:"auto_assign_gateway"` + IsGateway bool `json:"is_gateway"` + IsAutoRelay bool `json:"is_auto_relay"` + IsInternetGateway bool `json:"is_internet_gateway"` + RelayedClients datatypes.JSONMap `json:"relayed_clients"` + RelayedIGWClients datatypes.JSONMap `json:"relayed_igw_clients"` + RelayingNodeID *string `json:"relaying_node_id"` + IsIGWClient bool `json:"is_igw_client"` + AutoRelayedPeers datatypes.JSONType[map[string]string] `json:"auto_relayed_peers"` + Tags datatypes.JSONMap `json:"tags"` + PostureCheckSeverity Severity `json:"posture_check_severity"` + PostureCheckLastEvaluationCycleID string `json:"posture_check_last_evaluation_cycle_id"` + Metadata string `json:"metadata"` + LastCheckIn time.Time `json:"last_check_in"` + ExpirationDateTime time.Time `json:"expiration_date_time"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (n *Node) TableName() string { + return nodesTable +} + +func (n *Node) Create(ctx context.Context) error { + return db.FromContext(ctx).Model(&Node{}).Create(n).Error +} + +func (n *Node) Get(ctx context.Context, options ...dbtypes.Option) error { + query := db.FromContext(ctx).Model(&Node{}) + for _, opt := range options { + query = opt(query) + } + return query.Where("id = ?", n.ID).First(n).Error +} + +func (n *Node) Exists(ctx context.Context) (bool, error) { + var exists bool + err := db.FromContext(ctx).Raw( + "SELECT EXISTS (SELECT 1 FROM nodes_v1 WHERE id = ?)", + n.ID, + ).Scan(&exists).Error + return exists, err +} + +func (n *Node) GetByHostAndNetwork(ctx context.Context) error { + return db.FromContext(ctx).Model(&Node{}). + Where("host_id = ? AND network_id = ?", n.HostID, n.NetworkID). + First(n). + Error +} + +func (n *Node) GetByNetworkAndAddress(ctx context.Context) error { + return db.FromContext(ctx).Model(&Node{}). + Where("network_id = ? AND address = ?", n.NetworkID, n.Address). + First(n). + Error +} + +func (n *Node) GetByNetworkAndAddress6(ctx context.Context) error { + return db.FromContext(ctx).Model(&Node{}). + Where("network_id = ? AND address6 = ?", n.NetworkID, n.Address6). + First(n). + Error +} + +func (n *Node) Update(ctx context.Context) error { + return db.FromContext(ctx).Model(&Node{}).Where("id = ?", n.ID).Updates(n).Error +} + +func (n *Node) Upsert(ctx context.Context) error { + return db.FromContext(ctx).Save(n).Error +} + +func (n *Node) Delete(ctx context.Context) error { + return db.FromContext(ctx).Model(&Node{}).Where("id = ?", n.ID).Delete(n).Error +} + +func (n *Node) ListAll(ctx context.Context, options ...dbtypes.Option) ([]Node, error) { + var nodes []Node + query := db.FromContext(ctx).Model(&Node{}) + for _, opt := range options { + query = opt(query) + } + err := query.Find(&nodes).Error + return nodes, err +} + +func (n *Node) Count(ctx context.Context, options ...dbtypes.Option) (int, error) { + var count int64 + query := db.FromContext(ctx).Model(&Node{}) + for _, opt := range options { + query = opt(query) + } + err := query.Count(&count).Error + return int(count), err +} + +func (n *Node) UpsertViolations(ctx context.Context, violations []PostureCheckViolation) error { + if len(violations) > 0 { + err := db.FromContext(ctx).Model(&PostureCheckViolation{}).Create(&violations).Error + if err != nil { + return err + } + } + + return db.FromContext(ctx).Model(&Node{}). + Where("id = ?", n.ID). + Update("posture_check_last_evaluation_cycle_id", n.PostureCheckLastEvaluationCycleID). + Update("posture_check_severity", n.PostureCheckSeverity). + Error +} + +func (n *Node) ListViolations(ctx context.Context) ([]PostureCheckViolation, error) { + var violations []PostureCheckViolation + err := db.FromContext(ctx).Model(&PostureCheckViolation{}). + Where("node_id = ? AND evaluation_cycle_id = ?", n.ID, n.PostureCheckLastEvaluationCycleID). + Find(&violations). + Error + return violations, err +} + +func (n *Node) DeleteViolations(ctx context.Context) error { + return db.FromContext(ctx).Model(&PostureCheckViolation{}). + Where("node_id = ?", n.ID). + Delete(&PostureCheckViolation{}). + Error +} + +func (n *Node) UpdateConnectedStatus(ctx context.Context, options ...dbtypes.Option) error { + query := db.FromContext(ctx).Model(&Node{}) + for _, opt := range options { + query = opt(query) + } + if n.ID != "" { + query = query.Where("id = ?", n.ID) + } + return query.Update("connected", n.Connected).Error +} + +func (n *Node) MarkForDeletion(ctx context.Context) error { + return db.FromContext(ctx).Model(&Node{}). + Where("id = ?", n.ID). + Update("pending_delete", true). + Update("action", NODE_DELETE). + Error +} + +func (n *Node) UpdateRelayingNode(ctx context.Context) error { + return db.FromContext(ctx).Model(&Node{}). + Where("id = ?", n.ID). + Update("relaying_node_id", n.RelayingNodeID). + Error +} + +func (n *Node) UpdateRelayedClients(ctx context.Context) error { + return db.FromContext(ctx).Model(&Node{}). + Where("id = ?", n.ID). + Update("relayed_clients", n.RelayedClients). + Error +} + +func (n *Node) UpdateTags(ctx context.Context) error { + return db.FromContext(ctx).Model(&Node{}). + Where("id = ?", n.ID). + Update("tags", n.Tags). + Error +} + +func (n *Node) UpdateLastCheckIn(ctx context.Context) error { + return db.FromContext(ctx).Model(&Node{}). + Where("id = ?", n.ID). + Update("last_check_in", n.LastCheckIn). + Error +} diff --git a/schema/posture_check_violations.go b/schema/posture_check_violations.go new file mode 100644 index 000000000..0a4b9dc0d --- /dev/null +++ b/schema/posture_check_violations.go @@ -0,0 +1,22 @@ +package schema + +import ( + "time" +) + +const postureCheckViolationsTable = "posture_check_violations_v1" + +type PostureCheckViolation struct { + EvaluationCycleID string `gorm:"primaryKey" json:"evaluation_cycle_id"` + CheckID string `gorm:"primaryKey" json:"check_id"` + NodeID string `gorm:"primaryKey" json:"node_id"` + Name string `json:"name"` + Attribute string `json:"attribute"` + Message string `json:"message"` + Severity Severity `json:"severity"` + EvaluatedAt time.Time `json:"evaluated_at"` +} + +func (v *PostureCheckViolation) TableName() string { + return postureCheckViolationsTable +}