diff --git a/app.go b/app.go index e397821..70ea2cd 100644 --- a/app.go +++ b/app.go @@ -141,6 +141,20 @@ func (h *Headscale) expireEphemeralNodesWorker() { } } +// WatchForKVUpdates checks the KV DB table for requests to perform tailnet upgrades +// This is a way to communitate the CLI with the headscale server +func (h *Headscale) watchForKVUpdates(milliSeconds int64) { + ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) + for range ticker.C { + h.watchForKVUpdatesWorker() + } +} + +func (h *Headscale) watchForKVUpdatesWorker() { + h.checkForNamespacesPendingUpdates() + // more functions will come here in the future +} + // Serve launches a GIN server with the Headscale API func (h *Headscale) Serve() error { r := gin.Default() @@ -149,6 +163,9 @@ func (h *Headscale) Serve() error { r.POST("/machine/:id/map", h.PollNetMapHandler) r.POST("/machine/:id", h.RegistrationHandler) var err error + + go h.watchForKVUpdates(5000) + if h.cfg.TLSLetsEncryptHostname != "" { if !strings.HasPrefix(h.cfg.ServerURL, "https://") { log.Println("WARNING: listening with TLS but ServerURL does not start with https://") diff --git a/db.go b/db.go index 6a057e1..0630252 100644 --- a/db.go +++ b/db.go @@ -79,6 +79,7 @@ func (h *Headscale) openDB() (*gorm.DB, error) { return db, nil } +// getValue returns the value for the given key in KV func (h *Headscale) getValue(key string) (string, error) { var row KV if result := h.db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -87,6 +88,7 @@ func (h *Headscale) getValue(key string) (string, error) { return row.Value, nil } +// setValue sets value for the given key in KV func (h *Headscale) setValue(key string, value string) error { kv := KV{ Key: key, diff --git a/machine.go b/machine.go index 6f88e8d..1895e46 100644 --- a/machine.go +++ b/machine.go @@ -200,19 +200,22 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { // DeleteMachine softs deletes a Machine from the database func (h *Headscale) DeleteMachine(m *Machine) error { m.Registered = false + namespaceID := m.NamespaceID h.db.Save(&m) // we mark it as unregistered, just in case if err := h.db.Delete(&m).Error; err != nil { return err } - return nil + + return h.RequestMapUpdates(namespaceID) } // HardDeleteMachine hard deletes a Machine from the database func (h *Headscale) HardDeleteMachine(m *Machine) error { + namespaceID := m.NamespaceID if err := h.db.Unscoped().Delete(&m).Error; err != nil { return err } - return nil + return h.RequestMapUpdates(namespaceID) } // GetHostInfo returns a Hostinfo struct for the machine diff --git a/machine_test.go b/machine_test.go index 1bd29a9..d535be5 100644 --- a/machine_test.go +++ b/machine_test.go @@ -1,6 +1,8 @@ package headscale import ( + "encoding/json" + "gopkg.in/check.v1" ) @@ -81,6 +83,15 @@ func (s *Suite) TestDeleteMachine(c *check.C) { h.db.Save(&m) err = h.DeleteMachine(&m) c.Assert(err, check.IsNil) + v, err := h.getValue("namespaces_pending_updates") + c.Assert(err, check.IsNil) + names := []string{} + err = json.Unmarshal([]byte(v), &names) + c.Assert(err, check.IsNil) + c.Assert(names, check.DeepEquals, []string{n.Name}) + h.checkForNamespacesPendingUpdates() + v, _ = h.getValue("namespaces_pending_updates") + c.Assert(v, check.Equals, "") _, err = h.GetMachine(n.Name, "testmachine") c.Assert(err, check.NotNil) } diff --git a/namespaces.go b/namespaces.go index 9897640..840f872 100644 --- a/namespaces.go +++ b/namespaces.go @@ -1,7 +1,9 @@ package headscale import ( + "encoding/json" "errors" + "fmt" "log" "time" @@ -103,6 +105,88 @@ func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error return nil } +// RequestMapUpdates signals the KV worker to update the maps for this namespace +func (h *Headscale) RequestMapUpdates(namespaceID uint) error { + namespace := Namespace{} + if err := h.db.First(&namespace, namespaceID).Error; err != nil { + return err + } + + v, err := h.getValue("namespaces_pending_updates") + if err != nil || v == "" { + err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name)) + if err != nil { + return err + } + return nil + } + names := []string{} + err = json.Unmarshal([]byte(v), &names) + if err != nil { + err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name)) + if err != nil { + return err + } + return nil + } + + names = append(names, namespace.Name) + data, err := json.Marshal(names) + if err != nil { + log.Printf("Could not marshal namespaces_pending_updates: %s", err) + return err + } + return h.setValue("namespaces_pending_updates", string(data)) +} + +func (h *Headscale) checkForNamespacesPendingUpdates() { + v, err := h.getValue("namespaces_pending_updates") + if err != nil { + return + } + if v == "" { + return + } + + names := []string{} + err = json.Unmarshal([]byte(v), &names) + if err != nil { + return + } + for _, name := range names { + log.Printf("Sending updates to nodes in namespace %s", name) + machines, err := h.ListMachinesInNamespace(name) + if err != nil { + continue + } + for _, m := range *machines { + peers, _ := h.getPeers(m) + h.pollMu.Lock() + for _, p := range *peers { + pUp, ok := h.clientsPolling[uint64(p.ID)] + if ok { + log.Printf("[%s] Notifying peer %s (%s)", m.Name, p.Name, p.Addresses[0]) + pUp <- []byte{} + } else { + log.Printf("[%s] Peer %s does not appear to be polling", m.Name, p.Name) + } + } + h.pollMu.Unlock() + } + } + newV, err := h.getValue("namespaces_pending_updates") + if err != nil { + return + } + if v == newV { // only clear when no changes, so we notified everybody + err = h.setValue("namespaces_pending_updates", "") + if err != nil { + log.Printf("Could not save to KV: %s", err) + return + } + } +} + func (n *Namespace) toUser() *tailcfg.User { u := tailcfg.User{ ID: tailcfg.UserID(n.ID),