diff --git a/app.go b/app.go index d83254a..a6c10ec 100644 --- a/app.go +++ b/app.go @@ -4,6 +4,7 @@ import ( "fmt" "log" "os" + "sync" "github.com/gin-gonic/gin" "tailscale.com/tailcfg" @@ -30,6 +31,9 @@ type Headscale struct { dbString string publicKey *wgcfg.Key privateKey *wgcfg.PrivateKey + + pollMu sync.Mutex + clientsPolling map[uint64]chan []byte // this is by all means a hackity hack } // NewHeadscale returns the Headscale app @@ -54,6 +58,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) { if err != nil { return nil, err } + h.clientsPolling = make(map[uint64]chan []byte) return &h, nil } @@ -64,9 +69,6 @@ func (h *Headscale) Serve() error { r.GET("/register", h.RegisterWebAPI) r.POST("/machine/:id/map", h.PollNetMapHandler) r.POST("/machine/:id", h.RegistrationHandler) - - // r.LoadHTMLFiles("./frontend/build/index.html") - // r.Use(static.Serve("/", static.LocalFile("./frontend/build", true))) err := r.Run(h.cfg.Addr) return err } diff --git a/handlers.go b/handlers.go index 7fd7004..7b29d9d 100644 --- a/handlers.go +++ b/handlers.go @@ -57,7 +57,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { // We do have the updated key! if m.NodeKey == wgcfg.Key(req.NodeKey).HexString() { if m.Registered { - log.Println("Registered and we have the updated key! Lets move to map") + log.Println("Client is registered and we have the current key. All clear to /map") resp.AuthURL = "" respBody, err := encode(resp, &mKey, h.privateKey) if err != nil { @@ -102,85 +102,147 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { log.Println("We dont know anything about the new key. WTF") } +// PollNetMapHandler takes care of /machine/:id/map +// +// This is the busiest endpoint, as it keeps the HTTP long poll that updates +// the clients when something in the network changes. +// +// The clients POST stuff like HostInfo and their Endpoints here, but +// only after their first request (marked with the ReadOnly field). +// +// At this moment the updates are sent in a quite horrendous way, but they kinda work. func (h *Headscale) PollNetMapHandler(c *gin.Context) { body, _ := io.ReadAll(c.Request.Body) mKeyStr := c.Param("id") mKey, err := wgcfg.ParseHexKey(mKeyStr) if err != nil { log.Printf("Cannot parse client key: %s", err) - c.String(http.StatusOK, "Sad!") return } req := tailcfg.MapRequest{} err = decode(body, &req, &mKey, h.privateKey) if err != nil { log.Printf("Cannot decode message: %s", err) - c.String(http.StatusOK, "Very sad!") - // return + return } db, err := h.db() if err != nil { log.Printf("Cannot open DB: %s", err) - c.String(http.StatusInternalServerError, ":(") return } defer db.Close() var m Machine if db.First(&m, "machine_key = ?", mKey.HexString()).RecordNotFound() { - log.Printf("Cannot encode message: %s", err) - c.String(http.StatusOK, "Extremely sad!") + log.Printf("Cannot find machine: %s", err) return } - endpoints, _ := json.Marshal(req.Endpoints) hostinfo, _ := json.Marshal(req.Hostinfo) - m.Endpoints = postgres.Jsonb{RawMessage: json.RawMessage(endpoints)} + m.Name = req.Hostinfo.Hostname m.HostInfo = postgres.Jsonb{RawMessage: json.RawMessage(hostinfo)} m.DiscoKey = wgcfg.Key(req.DiscoKey).HexString() now := time.Now().UTC() - m.LastSeen = &now + + // From Tailscale client: + // + // ReadOnly is whether the client just wants to fetch the MapResponse, + // without updating their Endpoints. The Endpoints field will be ignored and + // LastSeen will not be updated and peers will not be notified of changes. + // + // The intended use is for clients to discover the DERP map at start-up + // before their first real endpoint update. + if !req.ReadOnly { + endpoints, _ := json.Marshal(req.Endpoints) + m.Endpoints = postgres.Jsonb{RawMessage: json.RawMessage(endpoints)} + m.LastSeen = &now + } db.Save(&m) db.Close() - chanStream := make(chan []byte, 1) - go func() { - defer close(chanStream) + pollData := make(chan []byte, 1) + update := make(chan []byte, 1) + cancelKeepAlive := make(chan []byte, 1) + defer close(pollData) + defer close(update) + defer close(cancelKeepAlive) + h.pollMu.Lock() + h.clientsPolling[m.ID] = update + h.pollMu.Unlock() - data, err := h.getMapResponse(mKey, req, m) - if err != nil { - c.String(http.StatusInternalServerError, ":(") - return - } + data, err := h.getMapResponse(mKey, req, m) + if err != nil { + c.String(http.StatusInternalServerError, ":(") + return + } - //send initial dump - chanStream <- *data - for { + log.Printf("[%s] sending initial map", m.Name) + pollData <- *data - data, err := h.getMapKeepAliveResponse(mKey, req, m) - if err != nil { - c.String(http.StatusInternalServerError, ":(") - return + // We update our peers if the client is not sending ReadOnly in the MapRequest + // so we don't distribute its initial request (it comes with + // empty endpoints to peers) + if !req.ReadOnly { + peers, _ := h.getPeers(m) + h.pollMu.Lock() + for _, p := range *peers { + log.Printf("[%s] notifying peer %s (%s)", m.Name, p.Name, p.Addresses[0]) + if pUp, ok := h.clientsPolling[uint64(p.ID)]; ok { + pUp <- []byte{} + } else { + log.Printf("[%s] Peer %s does not appear to be polling", m.Name, p.Name) } - chanStream <- *data - // keep the node entertained - time.Sleep(time.Second * 180) - break } + h.pollMu.Unlock() + } + + go h.keepAlive(cancelKeepAlive, pollData, mKey, req, m) - }() c.Stream(func(w io.Writer) bool { - if msg, ok := <-chanStream; ok { - log.Printf("🦀 Sending data to %s: %d bytes", c.Request.RemoteAddr, len(msg)) - w.Write(msg) + select { + case data := <-pollData: + log.Printf("[%s] Sending data (%d bytes)", m.Name, len(data)) + w.Write(data) return true - } else { - log.Printf("🦄 Closing connection to %s", c.Request.RemoteAddr) - c.AbortWithStatus(200) + + case <-update: + log.Printf("[%s] Received a request for update", m.Name) + data, err := h.getMapResponse(mKey, req, m) + if err != nil { + fmt.Printf("[%s] 🤮 Cannot get the poll response: %s", m.Name, err) + } + w.Write(*data) + return true + + case <-c.Request.Context().Done(): + log.Printf("[%s] 😥 The client has closed the connection", m.Name) + h.pollMu.Lock() + cancelKeepAlive <- []byte{} + delete(h.clientsPolling, m.ID) + h.pollMu.Unlock() + return false + } }) +} +func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgcfg.Key, req tailcfg.MapRequest, m Machine) { + for { + select { + case <-cancel: + return + + default: + data, err := h.getMapKeepAliveResponse(mKey, req, m) + if err != nil { + log.Printf("Error generating the keep alive msg: %s", err) + return + } + pollData <- *data + time.Sleep(60 * time.Second) + } + } } func (h *Headscale) getMapResponse(mKey wgcfg.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) { @@ -221,7 +283,7 @@ func (h *Headscale) getMapResponse(mKey wgcfg.Key, req tailcfg.MapRequest, m Mac return nil, err } } - + // spew.Dump(resp) // declare the incoming size on the first 4 bytes data := make([]byte, 4) binary.LittleEndian.PutUint32(data, uint32(len(respBody))) @@ -289,6 +351,7 @@ func (h *Headscale) handleNewServer(c *gin.Context, db *gorm.DB, idKey wgcfg.Key MachineKey: idKey.HexString(), NodeKey: wgcfg.Key(req.NodeKey).HexString(), Expiry: &req.Expiry, + Name: req.Hostinfo.Hostname, } if err := db.Create(&mNew).Error; err != nil { log.Printf("Could not create row: %s", err)