Merge pull request #84 from kradalby/integration-tests-ci

Improve logic to keep nodes up to date with the network state
This commit is contained in:
Juan Font 2021-08-23 09:42:07 +02:00 committed by GitHub
commit 74d2fe1baa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 656 additions and 267 deletions

23
.github/workflows/test-integration.yml vendored Normal file
View file

@ -0,0 +1,23 @@
name: CI
on: [pull_request]
jobs:
# The "build" workflow
integration-test:
# The type of runner that the job will run on
runs-on: ubuntu-latest
# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2
# Setup Go
- name: Setup Go
uses: actions/setup-go@v2
with:
go-version: "1.16.3"
- name: Run Integration tests
run: go test -tags integration -timeout 30m

2
.gitignore vendored
View file

@ -19,3 +19,5 @@ config.json
*.key
/db.sqlite
*.sqlite3
test_output/

195
api.go
View file

@ -13,7 +13,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"gorm.io/datatypes"
"gorm.io/gorm"
"inet.af/netaddr"
"tailscale.com/tailcfg"
@ -82,6 +81,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
return
}
now := time.Now().UTC()
var m Machine
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
@ -90,6 +90,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
MachineKey: mKey.HexString(),
Name: req.Hostinfo.Hostname,
NodeKey: wgkey.Key(req.NodeKey).HexString(),
LastSuccessfulUpdate: &now,
}
if err := h.db.Create(&m).Error; err != nil {
log.Error().
@ -215,196 +216,6 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
c.Data(200, "application/json; charset=utf-8", respBody)
}
// PollNetMapHandler takes care of /machine/:id/map
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(c *gin.Context) {
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(c.Request.Body)
mKeyStr := c.Param("id")
mKey, err := wgkey.ParseHex(mKeyStr)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot parse client key")
c.String(http.StatusBadRequest, "")
return
}
req := tailcfg.MapRequest{}
err = decode(body, &req, &mKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot decode message")
c.String(http.StatusBadRequest, "")
return
}
var m Machine
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
c.String(http.StatusUnauthorized, "")
return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Found machine in database")
hostinfo, _ := json.Marshal(req.Hostinfo)
m.Name = req.Hostinfo.Hostname
m.HostInfo = datatypes.JSON(hostinfo)
m.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
now := time.Now().UTC()
// From Tailscale client:
//
// ReadOnly is whether the client just wants to fetch the MapResponse,
// without updating their Endpoints. The Endpoints field will be ignored and
// LastSeen will not be updated and peers will not be notified of changes.
//
// The intended use is for clients to discover the DERP map at start-up
// before their first real endpoint update.
if !req.ReadOnly {
endpoints, _ := json.Marshal(req.Endpoints)
m.Endpoints = datatypes.JSON(endpoints)
m.LastSeen = &now
}
h.db.Save(&m)
data, err := h.getMapResponse(mKey, req, m)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Err(err).
Msg("Failed to get Map response")
c.String(http.StatusInternalServerError, ":(")
return
}
// We update our peers if the client is not sending ReadOnly in the MapRequest
// so we don't distribute its initial request (it comes with
// empty endpoints to peers)
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream).
Msg("Client map request processed")
if req.ReadOnly {
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client is starting up. Asking for DERP map")
c.Data(200, "application/json; charset=utf-8", *data)
return
}
if req.OmitPeers && !req.Stream {
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client sent endpoint update and is ok with a response without peer list")
c.Data(200, "application/json; charset=utf-8", *data)
return
} else if req.OmitPeers && req.Stream {
log.Warn().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Ignoring request, don't know how to handle it")
c.String(http.StatusBadRequest, "")
return
}
// Only create update channel if it has not been created
var update chan []byte
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Creating or loading update channel")
if result, ok := h.clientsPolling.LoadOrStore(m.ID, make(chan []byte, 1)); ok {
update = result.(chan []byte)
}
pollData := make(chan []byte, 1)
defer close(pollData)
cancelKeepAlive := make(chan []byte, 1)
defer close(cancelKeepAlive)
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client is ready to access the tailnet")
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Sending initial map")
pollData <- *data
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Notifying peers")
// TODO: Why does this block?
go h.notifyChangesToPeers(&m)
h.PollNetMapStream(c, m, req, mKey, pollData, update, cancelKeepAlive)
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Finished stream, closing PollNetMap session")
}
func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgkey.Key, req tailcfg.MapRequest, m Machine) {
for {
select {
case <-cancel:
return
default:
data, err := h.getMapKeepAliveResponse(mKey, req, m)
if err != nil {
log.Error().
Str("func", "keepAlive").
Err(err).
Msg("Error generating the keep alive msg")
return
}
log.Debug().
Str("func", "keepAlive").
Str("machine", m.Name).
Msg("Sending keepalive")
pollData <- *data
time.Sleep(60 * time.Second)
}
}
}
func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) {
log.Trace().
Str("func", "getMapResponse").
@ -542,7 +353,7 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key,
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("ip", ip.String()).
Msgf("Assining %s to %s", ip, m.Name)
Msgf("Assigning %s to %s", ip, m.Name)
m.AuthKeyID = uint(pak.ID)
m.IPAddress = ip.String()

37
app.go
View file

@ -58,7 +58,10 @@ type Headscale struct {
aclPolicy *ACLPolicy
aclRules *[]tailcfg.FilterRule
clientsPolling sync.Map
clientsUpdateChannels sync.Map
clientsUpdateChannelMutex sync.Mutex
lastStateChange sync.Map
}
// NewHeadscale returns the Headscale app
@ -165,9 +168,18 @@ func (h *Headscale) Serve() error {
r.POST("/machine/:id", h.RegistrationHandler)
var err error
timeout := 30 * time.Second
go h.watchForKVUpdates(5000)
go h.expireEphemeralNodes(5000)
s := &http.Server{
Addr: h.cfg.Addr,
Handler: r,
ReadTimeout: timeout,
WriteTimeout: timeout,
}
if h.cfg.TLSLetsEncryptHostname != "" {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
@ -182,6 +194,8 @@ func (h *Headscale) Serve() error {
Addr: h.cfg.Addr,
TLSConfig: m.TLSConfig(),
Handler: r,
ReadTimeout: timeout,
WriteTimeout: timeout,
}
if h.cfg.TLSLetsEncryptChallengeType == "TLS-ALPN-01" {
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
@ -206,12 +220,29 @@ func (h *Headscale) Serve() error {
if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
}
err = r.Run(h.cfg.Addr)
err = s.ListenAndServe()
} else {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
}
err = r.RunTLS(h.cfg.Addr, h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
err = s.ListenAndServeTLS(h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
}
return err
}
func (h *Headscale) setLastStateChangeToNow(namespace string) {
now := time.Now().UTC()
h.lastStateChange.Store(namespace, now)
}
func (h *Headscale) getLastStateChange(namespace string) time.Time {
if wrapped, ok := h.lastStateChange.Load(namespace); ok {
lastChange, _ := wrapped.(time.Time)
return lastChange
}
now := time.Now().UTC()
h.lastStateChange.Store(namespace, now)
return now
}

View file

@ -39,7 +39,7 @@ func LoadConfig(path string) error {
viper.SetDefault("ip_prefix", "100.64.0.0/10")
viper.SetDefault("log_level", "debug")
viper.SetDefault("log_level", "info")
err := viper.ReadInConfig()
if err != nil {

View file

@ -4,10 +4,13 @@ package headscale
import (
"bytes"
"context"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
"path"
"strings"
"testing"
"time"
@ -20,23 +23,48 @@ import (
"inet.af/netaddr"
)
type IntegrationTestSuite struct {
suite.Suite
}
func TestIntegrationTestSuite(t *testing.T) {
suite.Run(t, new(IntegrationTestSuite))
}
var integrationTmpDir string
var ih Headscale
var pool dockertest.Pool
var network dockertest.Network
var headscale dockertest.Resource
var tailscaleCount int = 5
var tailscaleCount int = 25
var tailscales map[string]dockertest.Resource
type IntegrationTestSuite struct {
suite.Suite
stats *suite.SuiteInformation
}
func TestIntegrationTestSuite(t *testing.T) {
s := new(IntegrationTestSuite)
suite.Run(t, s)
// HandleStats, which allows us to check if we passed and save logs
// is called after TearDown, so we cannot tear down containers before
// we have potentially saved the logs.
for _, tailscale := range tailscales {
if err := pool.Purge(&tailscale); err != nil {
log.Printf("Could not purge resource: %s\n", err)
}
}
if !s.stats.Passed() {
err := saveLog(&headscale, "test_output")
if err != nil {
log.Printf("Could not save log: %s\n", err)
}
}
if err := pool.Purge(&headscale); err != nil {
log.Printf("Could not purge resource: %s\n", err)
}
if err := network.Close(); err != nil {
log.Printf("Could not close network: %s\n", err)
}
}
func executeCommand(resource *dockertest.Resource, cmd []string) (string, error) {
var stdout bytes.Buffer
var stderr bytes.Buffer
@ -62,6 +90,48 @@ func executeCommand(resource *dockertest.Resource, cmd []string) (string, error)
return stdout.String(), nil
}
func saveLog(resource *dockertest.Resource, basePath string) error {
err := os.MkdirAll(basePath, os.ModePerm)
if err != nil {
return err
}
var stdout bytes.Buffer
var stderr bytes.Buffer
err = pool.Client.Logs(
docker.LogsOptions{
Context: context.TODO(),
Container: resource.Container.ID,
OutputStream: &stdout,
ErrorStream: &stderr,
Tail: "all",
RawTerminal: false,
Stdout: true,
Stderr: true,
Follow: false,
Timestamps: false,
},
)
if err != nil {
return err
}
fmt.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stdout.log"), []byte(stdout.String()), 0644)
if err != nil {
return err
}
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stderr.log"), []byte(stdout.String()), 0644)
if err != nil {
return err
}
return nil
}
func dockerRestartPolicy(config *docker.HostConfig) {
// set AutoRemove to true so that stopped container goes away by itself
config.AutoRemove = true
@ -115,7 +185,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
PortBindings: map[docker.Port][]docker.PortBinding{
"8080/tcp": []docker.PortBinding{{HostPort: "8080"}},
},
Env: []string{},
}
fmt.Println("Creating headscale container")
@ -134,7 +203,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
Name: hostname,
Networks: []*dockertest.Network{&network},
Cmd: []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"},
Env: []string{},
}
if pts, err := pool.BuildAndRunWithBuildOptions(tailscaleBuildOptions, tailscaleOptions, dockerRestartPolicy); err == nil {
@ -145,7 +213,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
fmt.Printf("Created %s container\n", hostname)
}
// TODO: Replace this logic with something that can be detected on Github Actions
fmt.Println("Waiting for headscale to be ready")
hostEndpoint := fmt.Sprintf("localhost:%s", headscale.GetPort("8080/tcp"))
@ -197,23 +264,14 @@ func (s *IntegrationTestSuite) SetupSuite() {
// The nodes need a bit of time to get their updated maps from headscale
// TODO: See if we can have a more deterministic wait here.
time.Sleep(20 * time.Second)
time.Sleep(60 * time.Second)
}
func (s *IntegrationTestSuite) TearDownSuite() {
if err := pool.Purge(&headscale); err != nil {
log.Printf("Could not purge resource: %s\n", err)
}
}
for _, tailscale := range tailscales {
if err := pool.Purge(&tailscale); err != nil {
log.Printf("Could not purge resource: %s\n", err)
}
}
if err := network.Close(); err != nil {
log.Printf("Could not close network: %s\n", err)
}
func (s *IntegrationTestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) {
s.stats = stats
}
func (s *IntegrationTestSuite) TestListNodes() {
@ -295,7 +353,15 @@ func (s *IntegrationTestSuite) TestPingAllPeers() {
s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
// We currently cant ping ourselves, so skip that.
if peername != hostname {
command := []string{"tailscale", "ping", "--timeout=1s", "--c=1", ip.String()}
// We are only interested in "direct ping" which means what we
// might need a couple of more attempts before reaching the node.
command := []string{
"tailscale", "ping",
"--timeout=1s",
"--c=20",
"--until-direct=true",
ip.String(),
}
fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
result, err := executeCommand(

View file

@ -7,5 +7,5 @@
"db_type": "sqlite3",
"db_path": "/tmp/integration_test_db.sqlite3",
"acl_policy_path": "",
"log_level": "trace"
"log_level": "debug"
}

View file

@ -2,6 +2,7 @@ package headscale
import (
"encoding/json"
"errors"
"fmt"
"sort"
"strconv"
@ -32,6 +33,7 @@ type Machine struct {
AuthKey *PreAuthKey
LastSeen *time.Time
LastSuccessfulUpdate *time.Time
Expiry *time.Time
HostInfo datatypes.JSON
@ -211,6 +213,15 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
return &m, nil
}
// UpdateMachine takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database.
func (h *Headscale) UpdateMachine(m *Machine) error {
if result := h.db.Find(m).First(&m); result.Error != nil {
return result.Error
}
return nil
}
// DeleteMachine softs deletes a Machine from the database
func (h *Headscale) DeleteMachine(m *Machine) error {
m.Registered = false
@ -251,21 +262,110 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
func (h *Headscale) notifyChangesToPeers(m *Machine) {
peers, _ := h.getPeers(*m)
for _, p := range *peers {
pUp, ok := h.clientsPolling.Load(uint64(p.ID))
if ok {
log.Info().
Str("func", "notifyChangesToPeers").
Str("machine", m.Name).
Str("peer", m.Name).
Str("peer", p.Name).
Str("address", p.Addresses[0].String()).
Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
pUp.(chan []byte) <- []byte{}
} else {
err := h.sendRequestOnUpdateChannel(p)
if err != nil {
log.Info().
Str("func", "notifyChangesToPeers").
Str("machine", m.Name).
Str("peer", m.Name).
Str("peer", p.Name).
Msgf("Peer %s does not appear to be polling", p.Name)
}
log.Trace().
Str("func", "notifyChangesToPeers").
Str("machine", m.Name).
Str("peer", p.Name).
Str("address", p.Addresses[0].String()).
Msgf("Notified peer %s (%s)", p.Name, p.Addresses[0])
}
}
func (h *Headscale) getOrOpenUpdateChannel(m *Machine) <-chan struct{} {
var updateChan chan struct{}
if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
if unwrapped, ok := storedChan.(chan struct{}); ok {
updateChan = unwrapped
} else {
log.Error().
Str("handler", "openUpdateChannel").
Str("machine", m.Name).
Msg("Failed to convert update channel to struct{}")
}
} else {
log.Debug().
Str("handler", "openUpdateChannel").
Str("machine", m.Name).
Msg("Update channel not found, creating")
updateChan = make(chan struct{})
h.clientsUpdateChannels.Store(m.ID, updateChan)
}
return updateChan
}
func (h *Headscale) closeUpdateChannel(m *Machine) {
h.clientsUpdateChannelMutex.Lock()
defer h.clientsUpdateChannelMutex.Unlock()
if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
if unwrapped, ok := storedChan.(chan struct{}); ok {
close(unwrapped)
}
}
h.clientsUpdateChannels.Delete(m.ID)
}
func (h *Headscale) sendRequestOnUpdateChannel(m *tailcfg.Node) error {
h.clientsUpdateChannelMutex.Lock()
defer h.clientsUpdateChannelMutex.Unlock()
pUp, ok := h.clientsUpdateChannels.Load(uint64(m.ID))
if ok {
log.Info().
Str("func", "requestUpdate").
Str("machine", m.Name).
Msgf("Notifying peer %s", m.Name)
if update, ok := pUp.(chan struct{}); ok {
log.Trace().
Str("func", "requestUpdate").
Str("machine", m.Name).
Msgf("Update channel is %#v", update)
update <- struct{}{}
log.Trace().
Str("func", "requestUpdate").
Str("machine", m.Name).
Msgf("Notified machine %s", m.Name)
}
} else {
log.Info().
Str("func", "requestUpdate").
Str("machine", m.Name).
Msgf("Machine %s does not appear to be polling", m.Name)
return errors.New("machine does not seem to be polling")
}
return nil
}
func (h *Headscale) isOutdated(m *Machine) bool {
err := h.UpdateMachine(m)
if err != nil {
return true
}
lastChange := h.getLastStateChange(m.Namespace.Name)
log.Trace().
Str("func", "keepAlive").
Str("machine", m.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate).
Time("last_state_change", lastChange).
Msgf("Checking if %s is missing updates", m.Name)
return m.LastSuccessfulUpdate.Before(lastChange)
}

382
poll.go
View file

@ -1,38 +1,225 @@
package headscale
import (
"encoding/json"
"errors"
"io"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"gorm.io/datatypes"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/wgkey"
)
// PollNetMapHandler takes care of /machine/:id/map
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(c *gin.Context) {
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(c.Request.Body)
mKeyStr := c.Param("id")
mKey, err := wgkey.ParseHex(mKeyStr)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot parse client key")
c.String(http.StatusBadRequest, "")
return
}
req := tailcfg.MapRequest{}
err = decode(body, &req, &mKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot decode message")
c.String(http.StatusBadRequest, "")
return
}
var m Machine
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
c.String(http.StatusUnauthorized, "")
return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Found machine in database")
hostinfo, _ := json.Marshal(req.Hostinfo)
m.Name = req.Hostinfo.Hostname
m.HostInfo = datatypes.JSON(hostinfo)
m.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
now := time.Now().UTC()
// From Tailscale client:
//
// ReadOnly is whether the client just wants to fetch the MapResponse,
// without updating their Endpoints. The Endpoints field will be ignored and
// LastSeen will not be updated and peers will not be notified of changes.
//
// The intended use is for clients to discover the DERP map at start-up
// before their first real endpoint update.
if !req.ReadOnly {
endpoints, _ := json.Marshal(req.Endpoints)
m.Endpoints = datatypes.JSON(endpoints)
m.LastSeen = &now
}
h.db.Save(&m)
data, err := h.getMapResponse(mKey, req, m)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Err(err).
Msg("Failed to get Map response")
c.String(http.StatusInternalServerError, ":(")
return
}
// We update our peers if the client is not sending ReadOnly in the MapRequest
// so we don't distribute its initial request (it comes with
// empty endpoints to peers)
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream).
Msg("Client map request processed")
if req.ReadOnly {
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client is starting up. Probably interested in a DERP map")
c.Data(200, "application/json; charset=utf-8", *data)
return
}
// There has been an update to _any_ of the nodes that the other nodes would
// need to know about
h.setLastStateChangeToNow(m.Namespace.Name)
// The request is not ReadOnly, so we need to set up channels for updating
// peers via longpoll
// Only create update channel if it has not been created
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Loading or creating update channel")
updateChan := h.getOrOpenUpdateChannel(&m)
pollDataChan := make(chan []byte)
// defer close(pollData)
keepAliveChan := make(chan []byte)
cancelKeepAlive := make(chan struct{})
defer close(cancelKeepAlive)
if req.OmitPeers && !req.Stream {
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client sent endpoint update and is ok with a response without peer list")
c.Data(200, "application/json; charset=utf-8", *data)
// It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so.
go h.notifyChangesToPeers(&m)
return
} else if req.OmitPeers && req.Stream {
log.Warn().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Ignoring request, don't know how to handle it")
c.String(http.StatusBadRequest, "")
return
}
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client is ready to access the tailnet")
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Sending initial map")
go func() { pollDataChan <- *data }()
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Notifying peers")
go h.notifyChangesToPeers(&m)
h.PollNetMapStream(c, m, req, mKey, pollDataChan, keepAliveChan, updateChan, cancelKeepAlive)
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Finished stream, closing PollNetMap session")
}
// PollNetMapStream takes care of /machine/:id/map
// stream logic, ensuring we communicate updates and data
// to the connected clients.
func (h *Headscale) PollNetMapStream(
c *gin.Context,
m Machine,
req tailcfg.MapRequest,
mKey wgkey.Key,
pollData chan []byte,
update chan []byte,
cancelKeepAlive chan []byte,
pollDataChan chan []byte,
keepAliveChan chan []byte,
updateChan <-chan struct{},
cancelKeepAlive chan struct{},
) {
go h.keepAlive(cancelKeepAlive, pollData, mKey, req, m)
go h.scheduledPollWorker(cancelKeepAlive, keepAliveChan, mKey, req, m)
c.Stream(func(w io.Writer) bool {
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Msg("Waiting for data to stream...")
select {
case data := <-pollData:
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
select {
case data := <-pollDataChan:
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Sending data received via pollData channel")
_, err := w.Write(data)
@ -40,34 +227,104 @@ func (h *Headscale) PollNetMapStream(
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "pollData").
Err(err).
Msg("Cannot write data")
}
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Data from pollData channel written successfully")
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachine(&m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "pollData").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
m.LastSeen = &now
m.LastSuccessfulUpdate = &now
h.db.Save(&m)
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Machine updated successfully after sending pollData")
return true
case data := <-keepAliveChan:
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Sending keep alive message")
_, err := w.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot write keep alive message")
}
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Keep alive sent successfully")
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachine(&m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
m.LastSeen = &now
h.db.Save(&m)
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Machine updated successfully after sending pollData")
Msg("Machine updated successfully after sending keep alive")
return true
case <-update:
case <-updateChan:
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "update").
Msg("Received a request for update")
if h.isOutdated(&m) {
log.Debug().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Msg("Received a request for update")
Time("last_successful_update", *m.LastSuccessfulUpdate).
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
Msgf("There has been updates since the last successful update to %s", m.Name)
data, err := h.getMapResponse(mKey, req, m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "update").
Err(err).
Msg("Could not get the map update")
}
@ -76,9 +333,43 @@ func (h *Headscale) PollNetMapStream(
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "update").
Err(err).
Msg("Could not write the map response")
}
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "update").
Msg("Updated Map has been sent")
// Keep track of the last successful update,
// we sometimes end in a state were the update
// is not picked up by a client and we use this
// to determine if we should "force" an update.
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachine(&m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "update").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
m.LastSuccessfulUpdate = &now
h.db.Save(&m)
} else {
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate).
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
Msgf("%s is up to date", m.Name)
}
return true
case <-c.Request.Context().Done():
@ -86,13 +377,78 @@ func (h *Headscale) PollNetMapStream(
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Msg("The client has closed the connection")
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err := h.UpdateMachine(&m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "Done").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
m.LastSeen = &now
h.db.Save(&m)
cancelKeepAlive <- []byte{}
h.clientsPolling.Delete(m.ID)
close(update)
cancelKeepAlive <- struct{}{}
h.closeUpdateChannel(&m)
close(pollDataChan)
close(keepAliveChan)
return false
}
})
}
func (h *Headscale) scheduledPollWorker(
cancelChan <-chan struct{},
keepAliveChan chan<- []byte,
mKey wgkey.Key,
req tailcfg.MapRequest,
m Machine,
) {
keepAliveTicker := time.NewTicker(60 * time.Second)
updateCheckerTicker := time.NewTicker(30 * time.Second)
for {
select {
case <-cancelChan:
return
case <-keepAliveTicker.C:
data, err := h.getMapKeepAliveResponse(mKey, req, m)
if err != nil {
log.Error().
Str("func", "keepAlive").
Err(err).
Msg("Error generating the keep alive msg")
return
}
log.Debug().
Str("func", "keepAlive").
Str("machine", m.Name).
Msg("Sending keepalive")
keepAliveChan <- *data
case <-updateCheckerTicker.C:
// Send an update request regardless of outdated or not, if data is sent
// to the node is determined in the updateChan consumer block
n, _ := m.toNode()
err := h.sendRequestOnUpdateChannel(n)
if err != nil {
log.Error().
Str("func", "keepAlive").
Str("machine", m.Name).
Err(err).
Msgf("Failed to send update request to %s", m.Name)
}
}
}
}