From 68a8ecee7ad4d9886f52d5b210d3fc88e8dd1e75 Mon Sep 17 00:00:00 2001
From: Kristoffer Dalby <kristoffer@dalby.cc>
Date: Mon, 12 Feb 2024 09:11:17 +0100
Subject: [PATCH] Prepare notify channel before sending first update (#1730)

* create channel before sending first update

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* do not notify on register, wait for connect

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
---
 hscontrol/auth.go   | 27 ++++++++-------------------
 hscontrol/grpcv1.go | 10 ----------
 hscontrol/poll.go   | 38 ++++++++++++++++++++++----------------
 3 files changed, 30 insertions(+), 45 deletions(-)

diff --git a/hscontrol/auth.go b/hscontrol/auth.go
index 3e9557a..ff858dc 100644
--- a/hscontrol/auth.go
+++ b/hscontrol/auth.go
@@ -311,9 +311,6 @@ func (h *Headscale) handleAuthKey(
 
 	nodeKey := registerRequest.NodeKey
 
-	var update types.StateUpdate
-	var mkey key.MachinePublic
-
 	// retrieve node information if it exist
 	// The error is not important, because if it does not
 	// exist, then this is a new node and we will move
@@ -338,9 +335,6 @@ func (h *Headscale) handleAuthKey(
 			return
 		}
 
-		mkey = node.MachineKey
-		update = types.StateUpdateExpire(node.ID, registerRequest.Expiry)
-
 		aclTags := pak.Proto().GetAclTags()
 		if len(aclTags) > 0 {
 			// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
@@ -357,6 +351,14 @@ func (h *Headscale) handleAuthKey(
 				return
 			}
 		}
+
+		mkey := node.MachineKey
+		update := types.StateUpdateExpire(node.ID, registerRequest.Expiry)
+
+		if update.Valid() {
+			ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na")
+			h.nodeNotifier.NotifyWithIgnore(ctx, update, mkey.String())
+		}
 	} else {
 		now := time.Now().UTC()
 
@@ -400,13 +402,6 @@ func (h *Headscale) handleAuthKey(
 
 			return
 		}
-
-		mkey = node.MachineKey
-		update = types.StateUpdate{
-			Type:        types.StatePeerChanged,
-			ChangeNodes: types.Nodes{node},
-			Message:     "called from auth.handleAuthKey",
-		}
 	}
 
 	err = h.db.DB.Transaction(func(tx *gorm.DB) error {
@@ -456,12 +451,6 @@ func (h *Headscale) handleAuthKey(
 		return
 	}
 
-	// TODO(kradalby): if notifying after register make sense.
-	if update.Valid() {
-		ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na")
-		h.nodeNotifier.NotifyWithIgnore(ctx, update, mkey.String())
-	}
-
 	log.Info().
 		Str("node", registerRequest.Hostinfo.Hostname).
 		Str("ips", strings.Join(node.IPAddresses.StringSlice(), ", ")).
diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go
index c12ba73..47edc78 100644
--- a/hscontrol/grpcv1.go
+++ b/hscontrol/grpcv1.go
@@ -200,16 +200,6 @@ func (api headscaleV1APIServer) RegisterNode(
 		return nil, err
 	}
 
-	stateUpdate := types.StateUpdate{
-		Type:        types.StatePeerChanged,
-		ChangeNodes: types.Nodes{node},
-		Message:     "called from api.RegisterNode",
-	}
-	if stateUpdate.Valid() {
-		ctx := types.NotifyCtx(ctx, "cli-registernode", node.Hostname)
-		api.h.nodeNotifier.NotifyWithIgnore(ctx, stateUpdate, node.MachineKey.String())
-	}
-
 	return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
 }
 
diff --git a/hscontrol/poll.go b/hscontrol/poll.go
index 03f52ed..b7e3dad 100644
--- a/hscontrol/poll.go
+++ b/hscontrol/poll.go
@@ -13,6 +13,7 @@ import (
 	"github.com/rs/zerolog/log"
 	xslices "golang.org/x/exp/slices"
 	"gorm.io/gorm"
+	"tailscale.com/envknob"
 	"tailscale.com/tailcfg"
 )
 
@@ -277,6 +278,25 @@ func (h *Headscale) handlePoll(
 		return
 	}
 
+	// Set up the client stream
+	h.pollNetMapStreamWG.Add(1)
+	defer h.pollNetMapStreamWG.Done()
+
+	// Use a buffered channel in case a node is not fully ready
+	// to receive a message to make sure we don't block the entire
+	// notifier.
+	// The default of 3 is arbitrarily chosen; override it with HEADSCALE_TUNING_POLL_QUEUE_SIZE.
+	chanSize := 3
+	if size, ok := envknob.LookupInt("HEADSCALE_TUNING_POLL_QUEUE_SIZE"); ok {
+		chanSize = size
+	}
+	updateChan := make(chan types.StateUpdate, chanSize)
+	defer closeChanWithLog(updateChan, node.Hostname, "updateChan")
+
+	// Register the node's update channel
+	h.nodeNotifier.AddNode(node.MachineKey, updateChan)
+	defer h.nodeNotifier.RemoveNode(node.MachineKey)
+
 	// When a node connects to control, list the peers it has at
 	// that given point, further updates are kept in memory in
 	// the Mapper, which lives for the duration of the polling
@@ -289,8 +309,9 @@ func (h *Headscale) handlePoll(
 		return
 	}
 
+	isConnected := h.nodeNotifier.ConnectedMap()
 	for _, peer := range peers {
-		online := h.nodeNotifier.IsConnected(peer.MachineKey)
+		online := isConnected[peer.MachineKey]
 		peer.IsOnline = &online
 	}
 
@@ -357,21 +378,6 @@ func (h *Headscale) handlePoll(
 		go h.pollFailoverRoutes(logErr, "new node", node)
 	}
 
-	// Set up the client stream
-	h.pollNetMapStreamWG.Add(1)
-	defer h.pollNetMapStreamWG.Done()
-
-	// Use a buffered channel in case a node is not fully ready
-	// to receive a message to make sure we dont block the entire
-	// notifier.
-	// 12 is arbitrarily chosen.
-	updateChan := make(chan types.StateUpdate, 12)
-	defer closeChanWithLog(updateChan, node.Hostname, "updateChan")
-
-	// Register the node's update channel
-	h.nodeNotifier.AddNode(node.MachineKey, updateChan)
-	defer h.nodeNotifier.RemoveNode(node.MachineKey)
-
 	keepAliveTicker := time.NewTicker(keepAliveInterval)
 
 	ctx, cancel := context.WithCancel(context.WithValue(ctx, nodeNameContextKey, node.Hostname))