Merge pull request #84 from kradalby/integration-tests-ci

Improve logic to keep nodes up to date with the network state
This commit is contained in:
Juan Font 2021-08-23 09:42:07 +02:00 committed by GitHub
commit 74d2fe1baa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 656 additions and 267 deletions

23
.github/workflows/test-integration.yml vendored Normal file
View file

@ -0,0 +1,23 @@
name: CI
on: [pull_request]
jobs:
# The "build" workflow
integration-test:
# The type of runner that the job will run on
runs-on: ubuntu-latest
# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2
# Setup Go
- name: Setup Go
uses: actions/setup-go@v2
with:
go-version: "1.16.3"
- name: Run Integration tests
run: go test -tags integration -timeout 30m

2
.gitignore vendored
View file

@ -19,3 +19,5 @@ config.json
*.key *.key
/db.sqlite /db.sqlite
*.sqlite3 *.sqlite3
test_output/

203
api.go
View file

@ -13,7 +13,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
"gorm.io/datatypes"
"gorm.io/gorm" "gorm.io/gorm"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@ -82,14 +81,16 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
return return
} }
now := time.Now().UTC()
var m Machine var m Machine
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
m = Machine{ m = Machine{
Expiry: &req.Expiry, Expiry: &req.Expiry,
MachineKey: mKey.HexString(), MachineKey: mKey.HexString(),
Name: req.Hostinfo.Hostname, Name: req.Hostinfo.Hostname,
NodeKey: wgkey.Key(req.NodeKey).HexString(), NodeKey: wgkey.Key(req.NodeKey).HexString(),
LastSuccessfulUpdate: &now,
} }
if err := h.db.Create(&m).Error; err != nil { if err := h.db.Create(&m).Error; err != nil {
log.Error(). log.Error().
@ -215,196 +216,6 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
c.Data(200, "application/json; charset=utf-8", respBody) c.Data(200, "application/json; charset=utf-8", respBody)
} }
// PollNetMapHandler takes care of /machine/:id/map
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(c *gin.Context) {
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(c.Request.Body)
mKeyStr := c.Param("id")
mKey, err := wgkey.ParseHex(mKeyStr)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot parse client key")
c.String(http.StatusBadRequest, "")
return
}
req := tailcfg.MapRequest{}
err = decode(body, &req, &mKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot decode message")
c.String(http.StatusBadRequest, "")
return
}
var m Machine
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
c.String(http.StatusUnauthorized, "")
return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Found machine in database")
hostinfo, _ := json.Marshal(req.Hostinfo)
m.Name = req.Hostinfo.Hostname
m.HostInfo = datatypes.JSON(hostinfo)
m.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
now := time.Now().UTC()
// From Tailscale client:
//
// ReadOnly is whether the client just wants to fetch the MapResponse,
// without updating their Endpoints. The Endpoints field will be ignored and
// LastSeen will not be updated and peers will not be notified of changes.
//
// The intended use is for clients to discover the DERP map at start-up
// before their first real endpoint update.
if !req.ReadOnly {
endpoints, _ := json.Marshal(req.Endpoints)
m.Endpoints = datatypes.JSON(endpoints)
m.LastSeen = &now
}
h.db.Save(&m)
data, err := h.getMapResponse(mKey, req, m)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Err(err).
Msg("Failed to get Map response")
c.String(http.StatusInternalServerError, ":(")
return
}
// We update our peers if the client is not sending ReadOnly in the MapRequest
// so we don't distribute its initial request (it comes with
// empty endpoints to peers)
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream).
Msg("Client map request processed")
if req.ReadOnly {
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client is starting up. Asking for DERP map")
c.Data(200, "application/json; charset=utf-8", *data)
return
}
if req.OmitPeers && !req.Stream {
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client sent endpoint update and is ok with a response without peer list")
c.Data(200, "application/json; charset=utf-8", *data)
return
} else if req.OmitPeers && req.Stream {
log.Warn().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Ignoring request, don't know how to handle it")
c.String(http.StatusBadRequest, "")
return
}
// Only create update channel if it has not been created
var update chan []byte
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Creating or loading update channel")
if result, ok := h.clientsPolling.LoadOrStore(m.ID, make(chan []byte, 1)); ok {
update = result.(chan []byte)
}
pollData := make(chan []byte, 1)
defer close(pollData)
cancelKeepAlive := make(chan []byte, 1)
defer close(cancelKeepAlive)
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client is ready to access the tailnet")
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Sending initial map")
pollData <- *data
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Notifying peers")
// TODO: Why does this block?
go h.notifyChangesToPeers(&m)
h.PollNetMapStream(c, m, req, mKey, pollData, update, cancelKeepAlive)
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Finished stream, closing PollNetMap session")
}
func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgkey.Key, req tailcfg.MapRequest, m Machine) {
for {
select {
case <-cancel:
return
default:
data, err := h.getMapKeepAliveResponse(mKey, req, m)
if err != nil {
log.Error().
Str("func", "keepAlive").
Err(err).
Msg("Error generating the keep alive msg")
return
}
log.Debug().
Str("func", "keepAlive").
Str("machine", m.Name).
Msg("Sending keepalive")
pollData <- *data
time.Sleep(60 * time.Second)
}
}
}
func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) { func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) {
log.Trace(). log.Trace().
Str("func", "getMapResponse"). Str("func", "getMapResponse").
@ -542,7 +353,7 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key,
Str("func", "handleAuthKey"). Str("func", "handleAuthKey").
Str("machine", m.Name). Str("machine", m.Name).
Str("ip", ip.String()). Str("ip", ip.String()).
Msgf("Assining %s to %s", ip, m.Name) Msgf("Assigning %s to %s", ip, m.Name)
m.AuthKeyID = uint(pak.ID) m.AuthKeyID = uint(pak.ID)
m.IPAddress = ip.String() m.IPAddress = ip.String()

43
app.go
View file

@ -58,7 +58,10 @@ type Headscale struct {
aclPolicy *ACLPolicy aclPolicy *ACLPolicy
aclRules *[]tailcfg.FilterRule aclRules *[]tailcfg.FilterRule
clientsPolling sync.Map clientsUpdateChannels sync.Map
clientsUpdateChannelMutex sync.Mutex
lastStateChange sync.Map
} }
// NewHeadscale returns the Headscale app // NewHeadscale returns the Headscale app
@ -165,9 +168,18 @@ func (h *Headscale) Serve() error {
r.POST("/machine/:id", h.RegistrationHandler) r.POST("/machine/:id", h.RegistrationHandler)
var err error var err error
timeout := 30 * time.Second
go h.watchForKVUpdates(5000) go h.watchForKVUpdates(5000)
go h.expireEphemeralNodes(5000) go h.expireEphemeralNodes(5000)
s := &http.Server{
Addr: h.cfg.Addr,
Handler: r,
ReadTimeout: timeout,
WriteTimeout: timeout,
}
if h.cfg.TLSLetsEncryptHostname != "" { if h.cfg.TLSLetsEncryptHostname != "" {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") { if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://") log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
@ -179,9 +191,11 @@ func (h *Headscale) Serve() error {
Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir), Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir),
} }
s := &http.Server{ s := &http.Server{
Addr: h.cfg.Addr, Addr: h.cfg.Addr,
TLSConfig: m.TLSConfig(), TLSConfig: m.TLSConfig(),
Handler: r, Handler: r,
ReadTimeout: timeout,
WriteTimeout: timeout,
} }
if h.cfg.TLSLetsEncryptChallengeType == "TLS-ALPN-01" { if h.cfg.TLSLetsEncryptChallengeType == "TLS-ALPN-01" {
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737) // Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
@ -206,12 +220,29 @@ func (h *Headscale) Serve() error {
if !strings.HasPrefix(h.cfg.ServerURL, "http://") { if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
log.Warn().Msg("Listening without TLS but ServerURL does not start with http://") log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
} }
err = r.Run(h.cfg.Addr) err = s.ListenAndServe()
} else { } else {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") { if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://") log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
} }
err = r.RunTLS(h.cfg.Addr, h.cfg.TLSCertPath, h.cfg.TLSKeyPath) err = s.ListenAndServeTLS(h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
} }
return err return err
} }
func (h *Headscale) setLastStateChangeToNow(namespace string) {
now := time.Now().UTC()
h.lastStateChange.Store(namespace, now)
}
func (h *Headscale) getLastStateChange(namespace string) time.Time {
if wrapped, ok := h.lastStateChange.Load(namespace); ok {
lastChange, _ := wrapped.(time.Time)
return lastChange
}
now := time.Now().UTC()
h.lastStateChange.Store(namespace, now)
return now
}

View file

@ -39,7 +39,7 @@ func LoadConfig(path string) error {
viper.SetDefault("ip_prefix", "100.64.0.0/10") viper.SetDefault("ip_prefix", "100.64.0.0/10")
viper.SetDefault("log_level", "debug") viper.SetDefault("log_level", "info")
err := viper.ReadInConfig() err := viper.ReadInConfig()
if err != nil { if err != nil {

View file

@ -4,10 +4,13 @@ package headscale
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"io/ioutil"
"log" "log"
"net/http" "net/http"
"os" "os"
"path"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -20,23 +23,48 @@ import (
"inet.af/netaddr" "inet.af/netaddr"
) )
type IntegrationTestSuite struct {
suite.Suite
}
func TestIntegrationTestSuite(t *testing.T) {
suite.Run(t, new(IntegrationTestSuite))
}
var integrationTmpDir string var integrationTmpDir string
var ih Headscale var ih Headscale
var pool dockertest.Pool var pool dockertest.Pool
var network dockertest.Network var network dockertest.Network
var headscale dockertest.Resource var headscale dockertest.Resource
var tailscaleCount int = 5 var tailscaleCount int = 25
var tailscales map[string]dockertest.Resource var tailscales map[string]dockertest.Resource
type IntegrationTestSuite struct {
suite.Suite
stats *suite.SuiteInformation
}
func TestIntegrationTestSuite(t *testing.T) {
s := new(IntegrationTestSuite)
suite.Run(t, s)
// HandleStats, which allows us to check if we passed and save logs
// is called after TearDown, so we cannot tear down containers before
// we have potentially saved the logs.
for _, tailscale := range tailscales {
if err := pool.Purge(&tailscale); err != nil {
log.Printf("Could not purge resource: %s\n", err)
}
}
if !s.stats.Passed() {
err := saveLog(&headscale, "test_output")
if err != nil {
log.Printf("Could not save log: %s\n", err)
}
}
if err := pool.Purge(&headscale); err != nil {
log.Printf("Could not purge resource: %s\n", err)
}
if err := network.Close(); err != nil {
log.Printf("Could not close network: %s\n", err)
}
}
func executeCommand(resource *dockertest.Resource, cmd []string) (string, error) { func executeCommand(resource *dockertest.Resource, cmd []string) (string, error) {
var stdout bytes.Buffer var stdout bytes.Buffer
var stderr bytes.Buffer var stderr bytes.Buffer
@ -62,6 +90,48 @@ func executeCommand(resource *dockertest.Resource, cmd []string) (string, error)
return stdout.String(), nil return stdout.String(), nil
} }
func saveLog(resource *dockertest.Resource, basePath string) error {
err := os.MkdirAll(basePath, os.ModePerm)
if err != nil {
return err
}
var stdout bytes.Buffer
var stderr bytes.Buffer
err = pool.Client.Logs(
docker.LogsOptions{
Context: context.TODO(),
Container: resource.Container.ID,
OutputStream: &stdout,
ErrorStream: &stderr,
Tail: "all",
RawTerminal: false,
Stdout: true,
Stderr: true,
Follow: false,
Timestamps: false,
},
)
if err != nil {
return err
}
fmt.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stdout.log"), []byte(stdout.String()), 0644)
if err != nil {
return err
}
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stderr.log"), []byte(stdout.String()), 0644)
if err != nil {
return err
}
return nil
}
func dockerRestartPolicy(config *docker.HostConfig) { func dockerRestartPolicy(config *docker.HostConfig) {
// set AutoRemove to true so that stopped container goes away by itself // set AutoRemove to true so that stopped container goes away by itself
config.AutoRemove = true config.AutoRemove = true
@ -115,7 +185,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
PortBindings: map[docker.Port][]docker.PortBinding{ PortBindings: map[docker.Port][]docker.PortBinding{
"8080/tcp": []docker.PortBinding{{HostPort: "8080"}}, "8080/tcp": []docker.PortBinding{{HostPort: "8080"}},
}, },
Env: []string{},
} }
fmt.Println("Creating headscale container") fmt.Println("Creating headscale container")
@ -134,7 +203,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
Name: hostname, Name: hostname,
Networks: []*dockertest.Network{&network}, Networks: []*dockertest.Network{&network},
Cmd: []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"}, Cmd: []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"},
Env: []string{},
} }
if pts, err := pool.BuildAndRunWithBuildOptions(tailscaleBuildOptions, tailscaleOptions, dockerRestartPolicy); err == nil { if pts, err := pool.BuildAndRunWithBuildOptions(tailscaleBuildOptions, tailscaleOptions, dockerRestartPolicy); err == nil {
@ -145,7 +213,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
fmt.Printf("Created %s container\n", hostname) fmt.Printf("Created %s container\n", hostname)
} }
// TODO: Replace this logic with something that can be detected on Github Actions
fmt.Println("Waiting for headscale to be ready") fmt.Println("Waiting for headscale to be ready")
hostEndpoint := fmt.Sprintf("localhost:%s", headscale.GetPort("8080/tcp")) hostEndpoint := fmt.Sprintf("localhost:%s", headscale.GetPort("8080/tcp"))
@ -197,23 +264,14 @@ func (s *IntegrationTestSuite) SetupSuite() {
// The nodes need a bit of time to get their updated maps from headscale // The nodes need a bit of time to get their updated maps from headscale
// TODO: See if we can have a more deterministic wait here. // TODO: See if we can have a more deterministic wait here.
time.Sleep(20 * time.Second) time.Sleep(60 * time.Second)
} }
func (s *IntegrationTestSuite) TearDownSuite() { func (s *IntegrationTestSuite) TearDownSuite() {
if err := pool.Purge(&headscale); err != nil { }
log.Printf("Could not purge resource: %s\n", err)
}
for _, tailscale := range tailscales { func (s *IntegrationTestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) {
if err := pool.Purge(&tailscale); err != nil { s.stats = stats
log.Printf("Could not purge resource: %s\n", err)
}
}
if err := network.Close(); err != nil {
log.Printf("Could not close network: %s\n", err)
}
} }
func (s *IntegrationTestSuite) TestListNodes() { func (s *IntegrationTestSuite) TestListNodes() {
@ -295,7 +353,15 @@ func (s *IntegrationTestSuite) TestPingAllPeers() {
s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) { s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
// We currently cant ping ourselves, so skip that. // We currently cant ping ourselves, so skip that.
if peername != hostname { if peername != hostname {
command := []string{"tailscale", "ping", "--timeout=1s", "--c=1", ip.String()} // We are only interested in "direct ping" which means what we
// might need a couple of more attempts before reaching the node.
command := []string{
"tailscale", "ping",
"--timeout=1s",
"--c=20",
"--until-direct=true",
ip.String(),
}
fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip) fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
result, err := executeCommand( result, err := executeCommand(

View file

@ -7,5 +7,5 @@
"db_type": "sqlite3", "db_type": "sqlite3",
"db_path": "/tmp/integration_test_db.sqlite3", "db_path": "/tmp/integration_test_db.sqlite3",
"acl_policy_path": "", "acl_policy_path": "",
"log_level": "trace" "log_level": "debug"
} }

View file

@ -2,6 +2,7 @@ package headscale
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"sort" "sort"
"strconv" "strconv"
@ -31,8 +32,9 @@ type Machine struct {
AuthKeyID uint AuthKeyID uint
AuthKey *PreAuthKey AuthKey *PreAuthKey
LastSeen *time.Time LastSeen *time.Time
Expiry *time.Time LastSuccessfulUpdate *time.Time
Expiry *time.Time
HostInfo datatypes.JSON HostInfo datatypes.JSON
Endpoints datatypes.JSON Endpoints datatypes.JSON
@ -211,6 +213,15 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
return &m, nil return &m, nil
} }
// UpdateMachine takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database.
func (h *Headscale) UpdateMachine(m *Machine) error {
if result := h.db.Find(m).First(&m); result.Error != nil {
return result.Error
}
return nil
}
// DeleteMachine softs deletes a Machine from the database // DeleteMachine softs deletes a Machine from the database
func (h *Headscale) DeleteMachine(m *Machine) error { func (h *Headscale) DeleteMachine(m *Machine) error {
m.Registered = false m.Registered = false
@ -251,21 +262,110 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
func (h *Headscale) notifyChangesToPeers(m *Machine) { func (h *Headscale) notifyChangesToPeers(m *Machine) {
peers, _ := h.getPeers(*m) peers, _ := h.getPeers(*m)
for _, p := range *peers { for _, p := range *peers {
pUp, ok := h.clientsPolling.Load(uint64(p.ID)) log.Info().
if ok { Str("func", "notifyChangesToPeers").
Str("machine", m.Name).
Str("peer", p.Name).
Str("address", p.Addresses[0].String()).
Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
err := h.sendRequestOnUpdateChannel(p)
if err != nil {
log.Info(). log.Info().
Str("func", "notifyChangesToPeers"). Str("func", "notifyChangesToPeers").
Str("machine", m.Name). Str("machine", m.Name).
Str("peer", m.Name). Str("peer", p.Name).
Str("address", p.Addresses[0].String()).
Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
pUp.(chan []byte) <- []byte{}
} else {
log.Info().
Str("func", "notifyChangesToPeers").
Str("machine", m.Name).
Str("peer", m.Name).
Msgf("Peer %s does not appear to be polling", p.Name) Msgf("Peer %s does not appear to be polling", p.Name)
} }
log.Trace().
Str("func", "notifyChangesToPeers").
Str("machine", m.Name).
Str("peer", p.Name).
Str("address", p.Addresses[0].String()).
Msgf("Notified peer %s (%s)", p.Name, p.Addresses[0])
} }
} }
func (h *Headscale) getOrOpenUpdateChannel(m *Machine) <-chan struct{} {
var updateChan chan struct{}
if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
if unwrapped, ok := storedChan.(chan struct{}); ok {
updateChan = unwrapped
} else {
log.Error().
Str("handler", "openUpdateChannel").
Str("machine", m.Name).
Msg("Failed to convert update channel to struct{}")
}
} else {
log.Debug().
Str("handler", "openUpdateChannel").
Str("machine", m.Name).
Msg("Update channel not found, creating")
updateChan = make(chan struct{})
h.clientsUpdateChannels.Store(m.ID, updateChan)
}
return updateChan
}
func (h *Headscale) closeUpdateChannel(m *Machine) {
h.clientsUpdateChannelMutex.Lock()
defer h.clientsUpdateChannelMutex.Unlock()
if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
if unwrapped, ok := storedChan.(chan struct{}); ok {
close(unwrapped)
}
}
h.clientsUpdateChannels.Delete(m.ID)
}
func (h *Headscale) sendRequestOnUpdateChannel(m *tailcfg.Node) error {
h.clientsUpdateChannelMutex.Lock()
defer h.clientsUpdateChannelMutex.Unlock()
pUp, ok := h.clientsUpdateChannels.Load(uint64(m.ID))
if ok {
log.Info().
Str("func", "requestUpdate").
Str("machine", m.Name).
Msgf("Notifying peer %s", m.Name)
if update, ok := pUp.(chan struct{}); ok {
log.Trace().
Str("func", "requestUpdate").
Str("machine", m.Name).
Msgf("Update channel is %#v", update)
update <- struct{}{}
log.Trace().
Str("func", "requestUpdate").
Str("machine", m.Name).
Msgf("Notified machine %s", m.Name)
}
} else {
log.Info().
Str("func", "requestUpdate").
Str("machine", m.Name).
Msgf("Machine %s does not appear to be polling", m.Name)
return errors.New("machine does not seem to be polling")
}
return nil
}
func (h *Headscale) isOutdated(m *Machine) bool {
err := h.UpdateMachine(m)
if err != nil {
return true
}
lastChange := h.getLastStateChange(m.Namespace.Name)
log.Trace().
Str("func", "keepAlive").
Str("machine", m.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate).
Time("last_state_change", lastChange).
Msgf("Checking if %s is missing updates", m.Name)
return m.LastSuccessfulUpdate.Before(lastChange)
}

404
poll.go
View file

@ -1,38 +1,225 @@
package headscale package headscale
import ( import (
"encoding/json"
"errors"
"io" "io"
"net/http"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/datatypes"
"gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/wgkey" "tailscale.com/types/wgkey"
) )
// PollNetMapHandler takes care of /machine/:id/map
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(c *gin.Context) {
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(c.Request.Body)
mKeyStr := c.Param("id")
mKey, err := wgkey.ParseHex(mKeyStr)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot parse client key")
c.String(http.StatusBadRequest, "")
return
}
req := tailcfg.MapRequest{}
err = decode(body, &req, &mKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot decode message")
c.String(http.StatusBadRequest, "")
return
}
var m Machine
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
c.String(http.StatusUnauthorized, "")
return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Found machine in database")
hostinfo, _ := json.Marshal(req.Hostinfo)
m.Name = req.Hostinfo.Hostname
m.HostInfo = datatypes.JSON(hostinfo)
m.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
now := time.Now().UTC()
// From Tailscale client:
//
// ReadOnly is whether the client just wants to fetch the MapResponse,
// without updating their Endpoints. The Endpoints field will be ignored and
// LastSeen will not be updated and peers will not be notified of changes.
//
// The intended use is for clients to discover the DERP map at start-up
// before their first real endpoint update.
if !req.ReadOnly {
endpoints, _ := json.Marshal(req.Endpoints)
m.Endpoints = datatypes.JSON(endpoints)
m.LastSeen = &now
}
h.db.Save(&m)
data, err := h.getMapResponse(mKey, req, m)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Err(err).
Msg("Failed to get Map response")
c.String(http.StatusInternalServerError, ":(")
return
}
// We update our peers if the client is not sending ReadOnly in the MapRequest
// so we don't distribute its initial request (it comes with
// empty endpoints to peers)
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream).
Msg("Client map request processed")
if req.ReadOnly {
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client is starting up. Probably interested in a DERP map")
c.Data(200, "application/json; charset=utf-8", *data)
return
}
// There has been an update to _any_ of the nodes that the other nodes would
// need to know about
h.setLastStateChangeToNow(m.Namespace.Name)
// The request is not ReadOnly, so we need to set up channels for updating
// peers via longpoll
// Only create update channel if it has not been created
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Loading or creating update channel")
updateChan := h.getOrOpenUpdateChannel(&m)
pollDataChan := make(chan []byte)
// defer close(pollData)
keepAliveChan := make(chan []byte)
cancelKeepAlive := make(chan struct{})
defer close(cancelKeepAlive)
if req.OmitPeers && !req.Stream {
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client sent endpoint update and is ok with a response without peer list")
c.Data(200, "application/json; charset=utf-8", *data)
// It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so.
go h.notifyChangesToPeers(&m)
return
} else if req.OmitPeers && req.Stream {
log.Warn().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Ignoring request, don't know how to handle it")
c.String(http.StatusBadRequest, "")
return
}
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client is ready to access the tailnet")
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Sending initial map")
go func() { pollDataChan <- *data }()
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Notifying peers")
go h.notifyChangesToPeers(&m)
h.PollNetMapStream(c, m, req, mKey, pollDataChan, keepAliveChan, updateChan, cancelKeepAlive)
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Finished stream, closing PollNetMap session")
}
// PollNetMapStream takes care of /machine/:id/map
// stream logic, ensuring we communicate updates and data
// to the connected clients.
func (h *Headscale) PollNetMapStream( func (h *Headscale) PollNetMapStream(
c *gin.Context, c *gin.Context,
m Machine, m Machine,
req tailcfg.MapRequest, req tailcfg.MapRequest,
mKey wgkey.Key, mKey wgkey.Key,
pollData chan []byte, pollDataChan chan []byte,
update chan []byte, keepAliveChan chan []byte,
cancelKeepAlive chan []byte, updateChan <-chan struct{},
cancelKeepAlive chan struct{},
) { ) {
go h.scheduledPollWorker(cancelKeepAlive, keepAliveChan, mKey, req, m)
go h.keepAlive(cancelKeepAlive, pollData, mKey, req, m)
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", m.Name).
Msg("Waiting for data to stream...") Msg("Waiting for data to stream...")
select {
case data := <-pollData: log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
select {
case data := <-pollDataChan:
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", m.Name).
Str("channel", "pollData").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Sending data received via pollData channel") Msg("Sending data received via pollData channel")
_, err := w.Write(data) _, err := w.Write(data)
@ -40,44 +227,148 @@ func (h *Headscale) PollNetMapStream(
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", m.Name).
Str("channel", "pollData").
Err(err). Err(err).
Msg("Cannot write data") Msg("Cannot write data")
} }
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", m.Name).
Str("channel", "pollData").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Data from pollData channel written successfully") Msg("Data from pollData channel written successfully")
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachine(&m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "pollData").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
m.LastSeen = &now
m.LastSuccessfulUpdate = &now
h.db.Save(&m)
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Machine updated successfully after sending pollData")
return true
case data := <-keepAliveChan:
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Sending keep alive message")
_, err := w.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot write keep alive message")
}
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Keep alive sent successfully")
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachine(&m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC() now := time.Now().UTC()
m.LastSeen = &now m.LastSeen = &now
h.db.Save(&m) h.db.Save(&m)
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", m.Name).
Str("channel", "keepAlive").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Machine updated successfully after sending pollData") Msg("Machine updated successfully after sending keep alive")
return true return true
case <-update: case <-updateChan:
log.Debug(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", m.Name).
Str("channel", "update").
Msg("Received a request for update") Msg("Received a request for update")
data, err := h.getMapResponse(mKey, req, m) if h.isOutdated(&m) {
if err != nil { log.Debug().
log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", m.Name).
Err(err). Time("last_successful_update", *m.LastSuccessfulUpdate).
Msg("Could not get the map update") Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
} Msgf("There has been updates since the last successful update to %s", m.Name)
_, err = w.Write(*data) data, err := h.getMapResponse(mKey, req, m)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "update").
Err(err).
Msg("Could not get the map update")
}
_, err = w.Write(*data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "update").
Err(err).
Msg("Could not write the map response")
}
log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", m.Name).
Err(err). Str("channel", "update").
Msg("Could not write the map response") Msg("Updated Map has been sent")
// Keep track of the last successful update,
// we sometimes end in a state were the update
// is not picked up by a client and we use this
// to determine if we should "force" an update.
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachine(&m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "update").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
m.LastSuccessfulUpdate = &now
h.db.Save(&m)
} else {
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate).
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
Msgf("%s is up to date", m.Name)
} }
return true return true
@ -86,13 +377,78 @@ func (h *Headscale) PollNetMapStream(
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", m.Name).
Msg("The client has closed the connection") Msg("The client has closed the connection")
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err := h.UpdateMachine(&m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "Done").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC() now := time.Now().UTC()
m.LastSeen = &now m.LastSeen = &now
h.db.Save(&m) h.db.Save(&m)
cancelKeepAlive <- []byte{}
h.clientsPolling.Delete(m.ID) cancelKeepAlive <- struct{}{}
close(update)
h.closeUpdateChannel(&m)
close(pollDataChan)
close(keepAliveChan)
return false return false
} }
}) })
} }
func (h *Headscale) scheduledPollWorker(
cancelChan <-chan struct{},
keepAliveChan chan<- []byte,
mKey wgkey.Key,
req tailcfg.MapRequest,
m Machine,
) {
keepAliveTicker := time.NewTicker(60 * time.Second)
updateCheckerTicker := time.NewTicker(30 * time.Second)
for {
select {
case <-cancelChan:
return
case <-keepAliveTicker.C:
data, err := h.getMapKeepAliveResponse(mKey, req, m)
if err != nil {
log.Error().
Str("func", "keepAlive").
Err(err).
Msg("Error generating the keep alive msg")
return
}
log.Debug().
Str("func", "keepAlive").
Str("machine", m.Name).
Msg("Sending keepalive")
keepAliveChan <- *data
case <-updateCheckerTicker.C:
// Send an update request regardless of outdated or not, if data is sent
// to the node is determined in the updateChan consumer block
n, _ := m.toNode()
err := h.sendRequestOnUpdateChannel(n)
if err != nil {
log.Error().
Str("func", "keepAlive").
Str("machine", m.Name).
Err(err).
Msgf("Failed to send update request to %s", m.Name)
}
}
}
}