diff --git a/.github/workflows/test-integration.yml b/.github/workflows/test-integration.yml
new file mode 100644
index 0000000..e939df2
--- /dev/null
+++ b/.github/workflows/test-integration.yml
@@ -0,0 +1,23 @@
+name: CI
+
+on: [pull_request]
+
+jobs:
+  # The "build" workflow
+  integration-test:
+    # The type of runner that the job will run on
+    runs-on: ubuntu-latest
+
+    # Steps represent a sequence of tasks that will be executed as part of the job
+    steps:
+      # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
+      - uses: actions/checkout@v2
+
+      # Setup Go
+      - name: Setup Go
+        uses: actions/setup-go@v2
+        with:
+          go-version: "1.16.3"
+
+      - name: Run Integration tests
+        run: go test -tags integration -timeout 30m
diff --git a/.gitignore b/.gitignore
index 3a64648..44bec69 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,3 +19,5 @@ config.json
 *.key
 /db.sqlite
 *.sqlite3
+
+test_output/
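The workflow runs `go test -tags integration`, which only compiles test files guarded by the matching build tag; a plain `go test ./...` skips them. A minimal sketch of such a guard, assuming the convention used by integration_test.go (on Go 1.16, as pinned above, only the `// +build` form is recognized; `//go:build` arrived with Go 1.17):

```go
// +build integration

// This file is compiled only when the "integration" build tag is set,
// e.g. via `go test -tags integration -timeout 30m`.
package headscale
```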
- Msg("Found machine in database") - - hostinfo, _ := json.Marshal(req.Hostinfo) - m.Name = req.Hostinfo.Hostname - m.HostInfo = datatypes.JSON(hostinfo) - m.DiscoKey = wgkey.Key(req.DiscoKey).HexString() - now := time.Now().UTC() - - // From Tailscale client: - // - // ReadOnly is whether the client just wants to fetch the MapResponse, - // without updating their Endpoints. The Endpoints field will be ignored and - // LastSeen will not be updated and peers will not be notified of changes. - // - // The intended use is for clients to discover the DERP map at start-up - // before their first real endpoint update. - if !req.ReadOnly { - endpoints, _ := json.Marshal(req.Endpoints) - m.Endpoints = datatypes.JSON(endpoints) - m.LastSeen = &now - } - h.db.Save(&m) - - data, err := h.getMapResponse(mKey, req, m) - if err != nil { - log.Error(). - Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). - Err(err). - Msg("Failed to get Map response") - c.String(http.StatusInternalServerError, ":(") - return - } - - // We update our peers if the client is not sending ReadOnly in the MapRequest - // so we don't distribute its initial request (it comes with - // empty endpoints to peers) - - // Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696 - log.Debug(). - Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). - Bool("readOnly", req.ReadOnly). - Bool("omitPeers", req.OmitPeers). - Bool("stream", req.Stream). - Msg("Client map request processed") - - if req.ReadOnly { - log.Info(). - Str("handler", "PollNetMap"). - Str("machine", m.Name). - Msg("Client is starting up. Asking for DERP map") - c.Data(200, "application/json; charset=utf-8", *data) - return - } - if req.OmitPeers && !req.Stream { - log.Info(). - Str("handler", "PollNetMap"). - Str("machine", m.Name). - Msg("Client sent endpoint update and is ok with a response without peer list") - c.Data(200, "application/json; charset=utf-8", *data) - return - } else if req.OmitPeers && req.Stream { - log.Warn(). - Str("handler", "PollNetMap"). - Str("machine", m.Name). - Msg("Ignoring request, don't know how to handle it") - c.String(http.StatusBadRequest, "") - return - } - - // Only create update channel if it has not been created - var update chan []byte - log.Trace(). - Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). - Msg("Creating or loading update channel") - if result, ok := h.clientsPolling.LoadOrStore(m.ID, make(chan []byte, 1)); ok { - update = result.(chan []byte) - } - - pollData := make(chan []byte, 1) - defer close(pollData) - - cancelKeepAlive := make(chan []byte, 1) - defer close(cancelKeepAlive) - - log.Info(). - Str("handler", "PollNetMap"). - Str("machine", m.Name). - Msg("Client is ready to access the tailnet") - log.Info(). - Str("handler", "PollNetMap"). - Str("machine", m.Name). - Msg("Sending initial map") - pollData <- *data - - log.Info(). - Str("handler", "PollNetMap"). - Str("machine", m.Name). - Msg("Notifying peers") - // TODO: Why does this block? - go h.notifyChangesToPeers(&m) - - h.PollNetMapStream(c, m, req, mKey, pollData, update, cancelKeepAlive) - log.Trace(). - Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). 
- Msg("Finished stream, closing PollNetMap session") -} - -func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgkey.Key, req tailcfg.MapRequest, m Machine) { - for { - select { - case <-cancel: - return - - default: - data, err := h.getMapKeepAliveResponse(mKey, req, m) - if err != nil { - log.Error(). - Str("func", "keepAlive"). - Err(err). - Msg("Error generating the keep alive msg") - return - } - - log.Debug(). - Str("func", "keepAlive"). - Str("machine", m.Name). - Msg("Sending keepalive") - pollData <- *data - - time.Sleep(60 * time.Second) - } - } -} - func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) { log.Trace(). Str("func", "getMapResponse"). @@ -542,7 +353,7 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key, Str("func", "handleAuthKey"). Str("machine", m.Name). Str("ip", ip.String()). - Msgf("Assining %s to %s", ip, m.Name) + Msgf("Assigning %s to %s", ip, m.Name) m.AuthKeyID = uint(pak.ID) m.IPAddress = ip.String() diff --git a/app.go b/app.go index fcf287f..e5f4410 100644 --- a/app.go +++ b/app.go @@ -58,7 +58,10 @@ type Headscale struct { aclPolicy *ACLPolicy aclRules *[]tailcfg.FilterRule - clientsPolling sync.Map + clientsUpdateChannels sync.Map + clientsUpdateChannelMutex sync.Mutex + + lastStateChange sync.Map } // NewHeadscale returns the Headscale app @@ -165,9 +168,18 @@ func (h *Headscale) Serve() error { r.POST("/machine/:id", h.RegistrationHandler) var err error + timeout := 30 * time.Second + go h.watchForKVUpdates(5000) go h.expireEphemeralNodes(5000) + s := &http.Server{ + Addr: h.cfg.Addr, + Handler: r, + ReadTimeout: timeout, + WriteTimeout: timeout, + } + if h.cfg.TLSLetsEncryptHostname != "" { if !strings.HasPrefix(h.cfg.ServerURL, "https://") { log.Warn().Msg("Listening with TLS but ServerURL does not start with https://") @@ -179,9 +191,11 @@ func (h *Headscale) Serve() error { Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir), } s := &http.Server{ - Addr: h.cfg.Addr, - TLSConfig: m.TLSConfig(), - Handler: r, + Addr: h.cfg.Addr, + TLSConfig: m.TLSConfig(), + Handler: r, + ReadTimeout: timeout, + WriteTimeout: timeout, } if h.cfg.TLSLetsEncryptChallengeType == "TLS-ALPN-01" { // Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737) @@ -206,12 +220,29 @@ func (h *Headscale) Serve() error { if !strings.HasPrefix(h.cfg.ServerURL, "http://") { log.Warn().Msg("Listening without TLS but ServerURL does not start with http://") } - err = r.Run(h.cfg.Addr) + err = s.ListenAndServe() } else { if !strings.HasPrefix(h.cfg.ServerURL, "https://") { log.Warn().Msg("Listening with TLS but ServerURL does not start with https://") } - err = r.RunTLS(h.cfg.Addr, h.cfg.TLSCertPath, h.cfg.TLSKeyPath) + err = s.ListenAndServeTLS(h.cfg.TLSCertPath, h.cfg.TLSKeyPath) } return err } + +func (h *Headscale) setLastStateChangeToNow(namespace string) { + now := time.Now().UTC() + h.lastStateChange.Store(namespace, now) +} + +func (h *Headscale) getLastStateChange(namespace string) time.Time { + if wrapped, ok := h.lastStateChange.Load(namespace); ok { + lastChange, _ := wrapped.(time.Time) + return lastChange + + } + + now := time.Now().UTC() + h.lastStateChange.Store(namespace, now) + return now +} diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 4ada640..7e7e8f9 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -39,7 +39,7 @@ func LoadConfig(path string) error { viper.SetDefault("ip_prefix", 
"100.64.0.0/10") - viper.SetDefault("log_level", "debug") + viper.SetDefault("log_level", "info") err := viper.ReadInConfig() if err != nil { diff --git a/integration_test.go b/integration_test.go index dd96fb8..8cdc191 100644 --- a/integration_test.go +++ b/integration_test.go @@ -4,10 +4,13 @@ package headscale import ( "bytes" + "context" "fmt" + "io/ioutil" "log" "net/http" "os" + "path" "strings" "testing" "time" @@ -20,23 +23,48 @@ import ( "inet.af/netaddr" ) -type IntegrationTestSuite struct { - suite.Suite -} - -func TestIntegrationTestSuite(t *testing.T) { - suite.Run(t, new(IntegrationTestSuite)) -} - var integrationTmpDir string var ih Headscale var pool dockertest.Pool var network dockertest.Network var headscale dockertest.Resource -var tailscaleCount int = 5 +var tailscaleCount int = 25 var tailscales map[string]dockertest.Resource +type IntegrationTestSuite struct { + suite.Suite + stats *suite.SuiteInformation +} + +func TestIntegrationTestSuite(t *testing.T) { + s := new(IntegrationTestSuite) + suite.Run(t, s) + + // HandleStats, which allows us to check if we passed and save logs + // is called after TearDown, so we cannot tear down containers before + // we have potentially saved the logs. + for _, tailscale := range tailscales { + if err := pool.Purge(&tailscale); err != nil { + log.Printf("Could not purge resource: %s\n", err) + } + } + + if !s.stats.Passed() { + err := saveLog(&headscale, "test_output") + if err != nil { + log.Printf("Could not save log: %s\n", err) + } + } + if err := pool.Purge(&headscale); err != nil { + log.Printf("Could not purge resource: %s\n", err) + } + + if err := network.Close(); err != nil { + log.Printf("Could not close network: %s\n", err) + } +} + func executeCommand(resource *dockertest.Resource, cmd []string) (string, error) { var stdout bytes.Buffer var stderr bytes.Buffer @@ -62,6 +90,48 @@ func executeCommand(resource *dockertest.Resource, cmd []string) (string, error) return stdout.String(), nil } +func saveLog(resource *dockertest.Resource, basePath string) error { + err := os.MkdirAll(basePath, os.ModePerm) + if err != nil { + return err + } + + var stdout bytes.Buffer + var stderr bytes.Buffer + + err = pool.Client.Logs( + docker.LogsOptions{ + Context: context.TODO(), + Container: resource.Container.ID, + OutputStream: &stdout, + ErrorStream: &stderr, + Tail: "all", + RawTerminal: false, + Stdout: true, + Stderr: true, + Follow: false, + Timestamps: false, + }, + ) + if err != nil { + return err + } + + fmt.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath) + + err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stdout.log"), []byte(stdout.String()), 0644) + if err != nil { + return err + } + + err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stderr.log"), []byte(stdout.String()), 0644) + if err != nil { + return err + } + + return nil +} + func dockerRestartPolicy(config *docker.HostConfig) { // set AutoRemove to true so that stopped container goes away by itself config.AutoRemove = true @@ -115,7 +185,6 @@ func (s *IntegrationTestSuite) SetupSuite() { PortBindings: map[docker.Port][]docker.PortBinding{ "8080/tcp": []docker.PortBinding{{HostPort: "8080"}}, }, - Env: []string{}, } fmt.Println("Creating headscale container") @@ -134,7 +203,6 @@ func (s *IntegrationTestSuite) SetupSuite() { Name: hostname, Networks: []*dockertest.Network{&network}, Cmd: []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"}, - Env: []string{}, } if pts, 
diff --git a/integration_test/etc/config.json b/integration_test/etc/config.json
index 5454f2f..8a6fd96 100644
--- a/integration_test/etc/config.json
+++ b/integration_test/etc/config.json
@@ -7,5 +7,5 @@
   "db_type": "sqlite3",
   "db_path": "/tmp/integration_test_db.sqlite3",
   "acl_policy_path": "",
-  "log_level": "trace"
+  "log_level": "debug"
 }
diff --git a/machine.go b/machine.go
index 69de453..4cdadd9 100644
--- a/machine.go
+++ b/machine.go
@@ -2,6 +2,7 @@ package headscale
 
 import (
 	"encoding/json"
+	"errors"
 	"fmt"
 	"sort"
 	"strconv"
@@ -31,8 +32,9 @@ type Machine struct {
 	AuthKeyID uint
 	AuthKey   *PreAuthKey
 
-	LastSeen *time.Time
-	Expiry   *time.Time
+	LastSeen             *time.Time
+	LastSuccessfulUpdate *time.Time
+	Expiry               *time.Time
 
 	HostInfo  datatypes.JSON
 	Endpoints datatypes.JSON
@@ -211,6 +213,15 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
 	return &m, nil
}
 
+// UpdateMachine takes a Machine struct pointer (typically already loaded from database)
+// and updates it with the latest data from the database.
+func (h *Headscale) UpdateMachine(m *Machine) error {
+	if result := h.db.Find(m).First(&m); result.Error != nil {
+		return result.Error
+	}
+	return nil
+}
+
 // DeleteMachine softs deletes a Machine from the database
 func (h *Headscale) DeleteMachine(m *Machine) error {
 	m.Registered = false
@@ -251,21 +262,110 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
 func (h *Headscale) notifyChangesToPeers(m *Machine) {
 	peers, _ := h.getPeers(*m)
 	for _, p := range *peers {
-		pUp, ok := h.clientsPolling.Load(uint64(p.ID))
-		if ok {
+		log.Info().
+			Str("func", "notifyChangesToPeers").
+			Str("machine", m.Name).
+			Str("peer", p.Name).
+			Str("address", p.Addresses[0].String()).
+			Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
+		err := h.sendRequestOnUpdateChannel(p)
+		if err != nil {
 			log.Info().
 				Str("func", "notifyChangesToPeers").
 				Str("machine", m.Name).
-				Str("peer", m.Name).
-				Str("address", p.Addresses[0].String()).
-				Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
-			pUp.(chan []byte) <- []byte{}
-		} else {
-			log.Info().
-				Str("func", "notifyChangesToPeers").
-				Str("machine", m.Name).
-				Str("peer", m.Name).
+				Str("peer", p.Name).
 				Msgf("Peer %s does not appear to be polling", p.Name)
 		}
+		log.Trace().
+			Str("func", "notifyChangesToPeers").
+			Str("machine", m.Name).
+			Str("peer", p.Name).
+			Str("address", p.Addresses[0].String()).
+			Msgf("Notified peer %s (%s)", p.Name, p.Addresses[0])
 	}
 }
+ Str("func", "notifyChangesToPeers"). + Str("machine", m.Name). + Str("peer", p.Name). + Str("address", p.Addresses[0].String()). + Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0]) + err := h.sendRequestOnUpdateChannel(p) + if err != nil { log.Info(). Str("func", "notifyChangesToPeers"). Str("machine", m.Name). - Str("peer", m.Name). - Str("address", p.Addresses[0].String()). - Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0]) - pUp.(chan []byte) <- []byte{} - } else { - log.Info(). - Str("func", "notifyChangesToPeers"). - Str("machine", m.Name). - Str("peer", m.Name). + Str("peer", p.Name). Msgf("Peer %s does not appear to be polling", p.Name) } + log.Trace(). + Str("func", "notifyChangesToPeers"). + Str("machine", m.Name). + Str("peer", p.Name). + Str("address", p.Addresses[0].String()). + Msgf("Notified peer %s (%s)", p.Name, p.Addresses[0]) } } + +func (h *Headscale) getOrOpenUpdateChannel(m *Machine) <-chan struct{} { + var updateChan chan struct{} + if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok { + if unwrapped, ok := storedChan.(chan struct{}); ok { + updateChan = unwrapped + } else { + log.Error(). + Str("handler", "openUpdateChannel"). + Str("machine", m.Name). + Msg("Failed to convert update channel to struct{}") + } + } else { + log.Debug(). + Str("handler", "openUpdateChannel"). + Str("machine", m.Name). + Msg("Update channel not found, creating") + + updateChan = make(chan struct{}) + h.clientsUpdateChannels.Store(m.ID, updateChan) + } + return updateChan +} + +func (h *Headscale) closeUpdateChannel(m *Machine) { + h.clientsUpdateChannelMutex.Lock() + defer h.clientsUpdateChannelMutex.Unlock() + + if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok { + if unwrapped, ok := storedChan.(chan struct{}); ok { + close(unwrapped) + } + } + h.clientsUpdateChannels.Delete(m.ID) +} + +func (h *Headscale) sendRequestOnUpdateChannel(m *tailcfg.Node) error { + h.clientsUpdateChannelMutex.Lock() + defer h.clientsUpdateChannelMutex.Unlock() + + pUp, ok := h.clientsUpdateChannels.Load(uint64(m.ID)) + if ok { + log.Info(). + Str("func", "requestUpdate"). + Str("machine", m.Name). + Msgf("Notifying peer %s", m.Name) + + if update, ok := pUp.(chan struct{}); ok { + log.Trace(). + Str("func", "requestUpdate"). + Str("machine", m.Name). + Msgf("Update channel is %#v", update) + + update <- struct{}{} + + log.Trace(). + Str("func", "requestUpdate"). + Str("machine", m.Name). + Msgf("Notified machine %s", m.Name) + } + } else { + log.Info(). + Str("func", "requestUpdate"). + Str("machine", m.Name). + Msgf("Machine %s does not appear to be polling", m.Name) + return errors.New("machine does not seem to be polling") + } + return nil +} + +func (h *Headscale) isOutdated(m *Machine) bool { + err := h.UpdateMachine(m) + if err != nil { + return true + } + + lastChange := h.getLastStateChange(m.Namespace.Name) + log.Trace(). + Str("func", "keepAlive"). + Str("machine", m.Name). + Time("last_successful_update", *m.LastSuccessfulUpdate). + Time("last_state_change", lastChange). 
+ Msgf("Checking if %s is missing updates", m.Name) + return m.LastSuccessfulUpdate.Before(lastChange) +} diff --git a/poll.go b/poll.go index f0bfe70..bea1616 100644 --- a/poll.go +++ b/poll.go @@ -1,38 +1,225 @@ package headscale import ( + "encoding/json" + "errors" "io" + "net/http" "time" "github.com/gin-gonic/gin" "github.com/rs/zerolog/log" + "gorm.io/datatypes" + "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/wgkey" ) +// PollNetMapHandler takes care of /machine/:id/map +// +// This is the busiest endpoint, as it keeps the HTTP long poll that updates +// the clients when something in the network changes. +// +// The clients POST stuff like HostInfo and their Endpoints here, but +// only after their first request (marked with the ReadOnly field). +// +// At this moment the updates are sent in a quite horrendous way, but they kinda work. +func (h *Headscale) PollNetMapHandler(c *gin.Context) { + log.Trace(). + Str("handler", "PollNetMap"). + Str("id", c.Param("id")). + Msg("PollNetMapHandler called") + body, _ := io.ReadAll(c.Request.Body) + mKeyStr := c.Param("id") + mKey, err := wgkey.ParseHex(mKeyStr) + if err != nil { + log.Error(). + Str("handler", "PollNetMap"). + Err(err). + Msg("Cannot parse client key") + c.String(http.StatusBadRequest, "") + return + } + req := tailcfg.MapRequest{} + err = decode(body, &req, &mKey, h.privateKey) + if err != nil { + log.Error(). + Str("handler", "PollNetMap"). + Err(err). + Msg("Cannot decode message") + c.String(http.StatusBadRequest, "") + return + } + + var m Machine + if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { + log.Warn(). + Str("handler", "PollNetMap"). + Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString()) + c.String(http.StatusUnauthorized, "") + return + } + log.Trace(). + Str("handler", "PollNetMap"). + Str("id", c.Param("id")). + Str("machine", m.Name). + Msg("Found machine in database") + + hostinfo, _ := json.Marshal(req.Hostinfo) + m.Name = req.Hostinfo.Hostname + m.HostInfo = datatypes.JSON(hostinfo) + m.DiscoKey = wgkey.Key(req.DiscoKey).HexString() + now := time.Now().UTC() + + // From Tailscale client: + // + // ReadOnly is whether the client just wants to fetch the MapResponse, + // without updating their Endpoints. The Endpoints field will be ignored and + // LastSeen will not be updated and peers will not be notified of changes. + // + // The intended use is for clients to discover the DERP map at start-up + // before their first real endpoint update. + if !req.ReadOnly { + endpoints, _ := json.Marshal(req.Endpoints) + m.Endpoints = datatypes.JSON(endpoints) + m.LastSeen = &now + } + h.db.Save(&m) + + data, err := h.getMapResponse(mKey, req, m) + if err != nil { + log.Error(). + Str("handler", "PollNetMap"). + Str("id", c.Param("id")). + Str("machine", m.Name). + Err(err). + Msg("Failed to get Map response") + c.String(http.StatusInternalServerError, ":(") + return + } + + // We update our peers if the client is not sending ReadOnly in the MapRequest + // so we don't distribute its initial request (it comes with + // empty endpoints to peers) + + // Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696 + log.Debug(). + Str("handler", "PollNetMap"). + Str("id", c.Param("id")). + Str("machine", m.Name). + Bool("readOnly", req.ReadOnly). + Bool("omitPeers", req.OmitPeers). + Bool("stream", req.Stream). 
+ Msg("Client map request processed") + + if req.ReadOnly { + log.Info(). + Str("handler", "PollNetMap"). + Str("machine", m.Name). + Msg("Client is starting up. Probably interested in a DERP map") + c.Data(200, "application/json; charset=utf-8", *data) + return + } + + // There has been an update to _any_ of the nodes that the other nodes would + // need to know about + h.setLastStateChangeToNow(m.Namespace.Name) + + // The request is not ReadOnly, so we need to set up channels for updating + // peers via longpoll + + // Only create update channel if it has not been created + log.Trace(). + Str("handler", "PollNetMap"). + Str("id", c.Param("id")). + Str("machine", m.Name). + Msg("Loading or creating update channel") + updateChan := h.getOrOpenUpdateChannel(&m) + + pollDataChan := make(chan []byte) + // defer close(pollData) + + keepAliveChan := make(chan []byte) + + cancelKeepAlive := make(chan struct{}) + defer close(cancelKeepAlive) + + if req.OmitPeers && !req.Stream { + log.Info(). + Str("handler", "PollNetMap"). + Str("machine", m.Name). + Msg("Client sent endpoint update and is ok with a response without peer list") + c.Data(200, "application/json; charset=utf-8", *data) + + // It sounds like we should update the nodes when we have received a endpoint update + // even tho the comments in the tailscale code dont explicitly say so. + go h.notifyChangesToPeers(&m) + return + } else if req.OmitPeers && req.Stream { + log.Warn(). + Str("handler", "PollNetMap"). + Str("machine", m.Name). + Msg("Ignoring request, don't know how to handle it") + c.String(http.StatusBadRequest, "") + return + } + + log.Info(). + Str("handler", "PollNetMap"). + Str("machine", m.Name). + Msg("Client is ready to access the tailnet") + log.Info(). + Str("handler", "PollNetMap"). + Str("machine", m.Name). + Msg("Sending initial map") + go func() { pollDataChan <- *data }() + + log.Info(). + Str("handler", "PollNetMap"). + Str("machine", m.Name). + Msg("Notifying peers") + go h.notifyChangesToPeers(&m) + + h.PollNetMapStream(c, m, req, mKey, pollDataChan, keepAliveChan, updateChan, cancelKeepAlive) + log.Trace(). + Str("handler", "PollNetMap"). + Str("id", c.Param("id")). + Str("machine", m.Name). + Msg("Finished stream, closing PollNetMap session") +} + +// PollNetMapStream takes care of /machine/:id/map +// stream logic, ensuring we communicate updates and data +// to the connected clients. func (h *Headscale) PollNetMapStream( c *gin.Context, m Machine, req tailcfg.MapRequest, mKey wgkey.Key, - pollData chan []byte, - update chan []byte, - cancelKeepAlive chan []byte, + pollDataChan chan []byte, + keepAliveChan chan []byte, + updateChan <-chan struct{}, + cancelKeepAlive chan struct{}, ) { - - go h.keepAlive(cancelKeepAlive, pollData, mKey, req, m) + go h.scheduledPollWorker(cancelKeepAlive, keepAliveChan, mKey, req, m) c.Stream(func(w io.Writer) bool { log.Trace(). Str("handler", "PollNetMapStream"). Str("machine", m.Name). Msg("Waiting for data to stream...") - select { - case data := <-pollData: + log.Trace(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan) + + select { + case data := <-pollDataChan: log.Trace(). Str("handler", "PollNetMapStream"). Str("machine", m.Name). + Str("channel", "pollData"). Int("bytes", len(data)). Msg("Sending data received via pollData channel") _, err := w.Write(data) @@ -40,44 +227,148 @@ func (h *Headscale) PollNetMapStream( log.Error(). 
Str("handler", "PollNetMapStream"). Str("machine", m.Name). + Str("channel", "pollData"). Err(err). Msg("Cannot write data") } log.Trace(). Str("handler", "PollNetMapStream"). Str("machine", m.Name). + Str("channel", "pollData"). Int("bytes", len(data)). Msg("Data from pollData channel written successfully") + // TODO: Abstract away all the database calls, this can cause race conditions + // when an outdated machine object is kept alive, e.g. db is update from + // command line, but then overwritten. + err = h.UpdateMachine(&m) + if err != nil { + log.Error(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Str("channel", "pollData"). + Err(err). + Msg("Cannot update machine from database") + } + now := time.Now().UTC() + m.LastSeen = &now + m.LastSuccessfulUpdate = &now + h.db.Save(&m) + log.Trace(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Str("channel", "pollData"). + Int("bytes", len(data)). + Msg("Machine updated successfully after sending pollData") + return true + + case data := <-keepAliveChan: + log.Trace(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Str("channel", "keepAlive"). + Int("bytes", len(data)). + Msg("Sending keep alive message") + _, err := w.Write(data) + if err != nil { + log.Error(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Str("channel", "keepAlive"). + Err(err). + Msg("Cannot write keep alive message") + } + log.Trace(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Str("channel", "keepAlive"). + Int("bytes", len(data)). + Msg("Keep alive sent successfully") + // TODO: Abstract away all the database calls, this can cause race conditions + // when an outdated machine object is kept alive, e.g. db is update from + // command line, but then overwritten. + err = h.UpdateMachine(&m) + if err != nil { + log.Error(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Str("channel", "keepAlive"). + Err(err). + Msg("Cannot update machine from database") + } now := time.Now().UTC() m.LastSeen = &now h.db.Save(&m) log.Trace(). Str("handler", "PollNetMapStream"). Str("machine", m.Name). + Str("channel", "keepAlive"). Int("bytes", len(data)). - Msg("Machine updated successfully after sending pollData") + Msg("Machine updated successfully after sending keep alive") return true - case <-update: - log.Debug(). + case <-updateChan: + log.Trace(). Str("handler", "PollNetMapStream"). Str("machine", m.Name). + Str("channel", "update"). Msg("Received a request for update") - data, err := h.getMapResponse(mKey, req, m) - if err != nil { - log.Error(). + if h.isOutdated(&m) { + log.Debug(). Str("handler", "PollNetMapStream"). Str("machine", m.Name). - Err(err). - Msg("Could not get the map update") - } - _, err = w.Write(*data) - if err != nil { - log.Error(). + Time("last_successful_update", *m.LastSuccessfulUpdate). + Time("last_state_change", h.getLastStateChange(m.Namespace.Name)). + Msgf("There has been updates since the last successful update to %s", m.Name) + data, err := h.getMapResponse(mKey, req, m) + if err != nil { + log.Error(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Str("channel", "update"). + Err(err). + Msg("Could not get the map update") + } + _, err = w.Write(*data) + if err != nil { + log.Error(). + Str("handler", "PollNetMapStream"). + Str("machine", m.Name). + Str("channel", "update"). + Err(err). + Msg("Could not write the map response") + } + log.Trace(). Str("handler", "PollNetMapStream"). Str("machine", m.Name). 
-					Err(err).
-					Msg("Could not write the map response")
+					Str("channel", "update").
+					Msg("Updated Map has been sent")
+
+				// Keep track of the last successful update;
+				// we sometimes end in a state where the update
+				// is not picked up by a client and we use this
+				// to determine if we should "force" an update.
+				// TODO: Abstract away all the database calls, this can cause race conditions
+				// when an outdated machine object is kept alive, e.g. db is updated from
+				// the command line, but then overwritten.
+				err = h.UpdateMachine(&m)
+				if err != nil {
+					log.Error().
+						Str("handler", "PollNetMapStream").
+						Str("machine", m.Name).
+						Str("channel", "update").
+						Err(err).
+						Msg("Cannot update machine from database")
+				}
+				now := time.Now().UTC()
+				m.LastSuccessfulUpdate = &now
+				h.db.Save(&m)
+			} else {
+				log.Trace().
+					Str("handler", "PollNetMapStream").
+					Str("machine", m.Name).
+					Time("last_successful_update", *m.LastSuccessfulUpdate).
+					Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
+					Msgf("%s is up to date", m.Name)
 			}
 			return true
 
@@ -86,13 +377,78 @@ func (h *Headscale) PollNetMapStream(
 				Str("handler", "PollNetMapStream").
 				Str("machine", m.Name).
 				Msg("The client has closed the connection")
+			// TODO: Abstract away all the database calls, this can cause race conditions
+			// when an outdated machine object is kept alive, e.g. db is updated from
+			// the command line, but then overwritten.
+			err := h.UpdateMachine(&m)
+			if err != nil {
+				log.Error().
+					Str("handler", "PollNetMapStream").
+					Str("machine", m.Name).
+					Str("channel", "Done").
+					Err(err).
+					Msg("Cannot update machine from database")
+			}
 			now := time.Now().UTC()
 			m.LastSeen = &now
 			h.db.Save(&m)
-			cancelKeepAlive <- []byte{}
-			h.clientsPolling.Delete(m.ID)
-			close(update)
+
+			cancelKeepAlive <- struct{}{}
+
+			h.closeUpdateChannel(&m)
+
+			close(pollDataChan)
+
+			close(keepAliveChan)
+
 			return false
 		}
 	})
 }
+
+func (h *Headscale) scheduledPollWorker(
+	cancelChan <-chan struct{},
+	keepAliveChan chan<- []byte,
+	mKey wgkey.Key,
+	req tailcfg.MapRequest,
+	m Machine,
+) {
+	keepAliveTicker := time.NewTicker(60 * time.Second)
+	updateCheckerTicker := time.NewTicker(30 * time.Second)
+
+	for {
+		select {
+		case <-cancelChan:
+			return
+
+		case <-keepAliveTicker.C:
+			data, err := h.getMapKeepAliveResponse(mKey, req, m)
+			if err != nil {
+				log.Error().
+					Str("func", "keepAlive").
+					Err(err).
+					Msg("Error generating the keep alive msg")
+				return
+			}
+
+			log.Debug().
+				Str("func", "keepAlive").
+				Str("machine", m.Name).
+				Msg("Sending keepalive")
+			keepAliveChan <- *data
+
+		case <-updateCheckerTicker.C:
+			// Send an update request regardless of whether the machine is
+			// outdated; whether data is actually sent to the node is decided
+			// by the updateChan consumer block.
+			n, _ := m.toNode()
+			err := h.sendRequestOnUpdateChannel(n)
+			if err != nil {
+				log.Error().
+					Str("func", "keepAlive").
+					Str("machine", m.Name).
+					Err(err).
+					Msgf("Failed to send update request to %s", m.Name)
+			}
+		}
+	}
+}
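One small detail worth flagging: scheduledPollWorker never stops its two tickers, so each finished poll session leaves their runtime timers ticking until garbage collection. A sketch of the same loop with cleanup, not part of this PR (the select body is elided; only the deferred Stop calls differ):

```go
// Sketch of scheduledPollWorker's loop with ticker cleanup (assumes
// "time" is imported). Stopping the tickers on return releases their
// timers as soon as the client disconnects.
func tickerLoop(cancelChan <-chan struct{}) {
	keepAliveTicker := time.NewTicker(60 * time.Second)
	defer keepAliveTicker.Stop()

	updateCheckerTicker := time.NewTicker(30 * time.Second)
	defer updateCheckerTicker.Stop()

	for {
		select {
		case <-cancelChan:
			return
		case <-keepAliveTicker.C:
			// send keep-alive, as in scheduledPollWorker
		case <-updateCheckerTicker.C:
			// request an update, as in scheduledPollWorker
		}
	}
}
```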