Fix rest of var name in main code

This commit is contained in:
Kristoffer Dalby 2021-11-15 16:15:50 +00:00
parent 471c0b4993
commit 333be80f9c
No known key found for this signature in database
GPG key ID: 09F62DC067465735
6 changed files with 137 additions and 130 deletions

6
db.go
View file

@ -100,18 +100,18 @@ func (h *Headscale) getValue(key string) (string, error) {
// setValue sets value for the given key in KV. // setValue sets value for the given key in KV.
func (h *Headscale) setValue(key string, value string) error { func (h *Headscale) setValue(key string, value string) error {
kv := KV{ keyValue := KV{
Key: key, Key: key,
Value: value, Value: value,
} }
if _, err := h.getValue(key); err == nil { if _, err := h.getValue(key); err == nil {
h.db.Model(&kv).Where("key = ?", key).Update("value", value) h.db.Model(&keyValue).Where("key = ?", key).Update("value", value)
return nil return nil
} }
h.db.Create(kv) h.db.Create(keyValue)
return nil return nil
} }

View file

@ -526,7 +526,7 @@ func (machine Machine) toNode(
hostname = machine.Name hostname = machine.Name
} }
n := tailcfg.Node{ node := tailcfg.Node{
ID: tailcfg.NodeID(machine.ID), // this is the actual ID ID: tailcfg.NodeID(machine.ID), // this is the actual ID
StableID: tailcfg.StableNodeID( StableID: tailcfg.StableNodeID(
strconv.FormatUint(machine.ID, BASE_10), strconv.FormatUint(machine.ID, BASE_10),
@ -551,7 +551,7 @@ func (machine Machine) toNode(
Capabilities: []string{tailcfg.CapabilityFileSharing}, Capabilities: []string{tailcfg.CapabilityFileSharing},
} }
return &n, nil return &node, nil
} }
func (machine *Machine) toProto() *v1.Machine { func (machine *Machine) toProto() *v1.Machine {

View file

@ -76,15 +76,15 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
return return
} }
b := make([]byte, RANDOM_BYTE_SIZE) randomBlob := make([]byte, RANDOM_BYTE_SIZE)
if _, err := rand.Read(b); err != nil { if _, err := rand.Read(randomBlob); err != nil {
log.Error().Msg("could not read 16 bytes from rand") log.Error().Msg("could not read 16 bytes from rand")
ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand") ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
return return
} }
stateStr := hex.EncodeToString(b)[:32] stateStr := hex.EncodeToString(randomBlob)[:32]
// place the machine key into the state cache, so it can be retrieved later // place the machine key into the state cache, so it can be retrieved later
h.oidcStateCache.Set(stateStr, mKeyStr, OIDC_STATE_CACHE_EXPIRATION) h.oidcStateCache.Set(stateStr, mKeyStr, OIDC_STATE_CACHE_EXPIRATION)

219
poll.go
View file

@ -29,20 +29,20 @@ const (
// only after their first request (marked with the ReadOnly field). // only after their first request (marked with the ReadOnly field).
// //
// At this moment the updates are sent in a quite horrendous way, but they kinda work. // At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(c *gin.Context) { func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", c.Param("id")). Str("id", ctx.Param("id")).
Msg("PollNetMapHandler called") Msg("PollNetMapHandler called")
body, _ := io.ReadAll(c.Request.Body) body, _ := io.ReadAll(ctx.Request.Body)
mKeyStr := c.Param("id") mKeyStr := ctx.Param("id")
mKey, err := wgkey.ParseHex(mKeyStr) mKey, err := wgkey.ParseHex(mKeyStr)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Err(err). Err(err).
Msg("Cannot parse client key") Msg("Cannot parse client key")
c.String(http.StatusBadRequest, "") ctx.String(http.StatusBadRequest, "")
return return
} }
@ -53,36 +53,36 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Err(err). Err(err).
Msg("Cannot decode message") Msg("Cannot decode message")
c.String(http.StatusBadRequest, "") ctx.String(http.StatusBadRequest, "")
return return
} }
m, err := h.GetMachineByMachineKey(mKey.HexString()) machine, err := h.GetMachineByMachineKey(mKey.HexString())
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn(). log.Warn().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString()) Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
c.String(http.StatusUnauthorized, "") ctx.String(http.StatusUnauthorized, "")
return return
} }
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", mKey.HexString()) Msgf("Failed to fetch machine from the database with Machine key: %s", mKey.HexString())
c.String(http.StatusInternalServerError, "") ctx.String(http.StatusInternalServerError, "")
} }
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", c.Param("id")). Str("id", ctx.Param("id")).
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Found machine in database") Msg("Found machine in database")
hostinfo, _ := json.Marshal(req.Hostinfo) hostinfo, _ := json.Marshal(req.Hostinfo)
m.Name = req.Hostinfo.Hostname machine.Name = req.Hostinfo.Hostname
m.HostInfo = datatypes.JSON(hostinfo) machine.HostInfo = datatypes.JSON(hostinfo)
m.DiscoKey = wgkey.Key(req.DiscoKey).HexString() machine.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
now := time.Now().UTC() now := time.Now().UTC()
// From Tailscale client: // From Tailscale client:
@ -95,20 +95,20 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
// before their first real endpoint update. // before their first real endpoint update.
if !req.ReadOnly { if !req.ReadOnly {
endpoints, _ := json.Marshal(req.Endpoints) endpoints, _ := json.Marshal(req.Endpoints)
m.Endpoints = datatypes.JSON(endpoints) machine.Endpoints = datatypes.JSON(endpoints)
m.LastSeen = &now machine.LastSeen = &now
} }
h.db.Save(&m) h.db.Save(&machine)
data, err := h.getMapResponse(mKey, req, m) data, err := h.getMapResponse(mKey, req, machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", c.Param("id")). Str("id", ctx.Param("id")).
Str("machine", m.Name). Str("machine", machine.Name).
Err(err). Err(err).
Msg("Failed to get Map response") Msg("Failed to get Map response")
c.String(http.StatusInternalServerError, ":(") ctx.String(http.StatusInternalServerError, ":(")
return return
} }
@ -120,8 +120,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696 // Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug(). log.Debug().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", c.Param("id")). Str("id", ctx.Param("id")).
Str("machine", m.Name). Str("machine", machine.Name).
Bool("readOnly", req.ReadOnly). Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers). Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream). Bool("stream", req.Stream).
@ -130,16 +130,16 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
if req.ReadOnly { if req.ReadOnly {
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Client is starting up. Probably interested in a DERP map") Msg("Client is starting up. Probably interested in a DERP map")
c.Data(http.StatusOK, "application/json; charset=utf-8", data) ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
return return
} }
// There has been an update to _any_ of the nodes that the other nodes would // There has been an update to _any_ of the nodes that the other nodes would
// need to know about // need to know about
h.setLastStateChangeToNow(m.Namespace.Name) h.setLastStateChangeToNow(machine.Namespace.Name)
// The request is not ReadOnly, so we need to set up channels for updating // The request is not ReadOnly, so we need to set up channels for updating
// peers via longpoll // peers via longpoll
@ -147,8 +147,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
// Only create update channel if it has not been created // Only create update channel if it has not been created
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", c.Param("id")). Str("id", ctx.Param("id")).
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Loading or creating update channel") Msg("Loading or creating update channel")
updateChan := make(chan struct{}) updateChan := make(chan struct{})
@ -162,13 +162,13 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
if req.OmitPeers && !req.Stream { if req.OmitPeers && !req.Stream {
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Client sent endpoint update and is ok with a response without peer list") Msg("Client sent endpoint update and is ok with a response without peer list")
c.Data(http.StatusOK, "application/json; charset=utf-8", data) ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
// It sounds like we should update the nodes when we have received a endpoint update // It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so. // even tho the comments in the tailscale code dont explicitly say so.
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "endpoint-update"). updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "endpoint-update").
Inc() Inc()
go func() { updateChan <- struct{}{} }() go func() { updateChan <- struct{}{} }()
@ -176,34 +176,34 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
} else if req.OmitPeers && req.Stream { } else if req.OmitPeers && req.Stream {
log.Warn(). log.Warn().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Ignoring request, don't know how to handle it") Msg("Ignoring request, don't know how to handle it")
c.String(http.StatusBadRequest, "") ctx.String(http.StatusBadRequest, "")
return return
} }
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Client is ready to access the tailnet") Msg("Client is ready to access the tailnet")
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Sending initial map") Msg("Sending initial map")
go func() { pollDataChan <- data }() go func() { pollDataChan <- data }()
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Notifying peers") Msg("Notifying peers")
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "full-update"). updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "full-update").
Inc() Inc()
go func() { updateChan <- struct{}{} }() go func() { updateChan <- struct{}{} }()
h.PollNetMapStream( h.PollNetMapStream(
c, ctx,
m, machine,
req, req,
mKey, mKey,
pollDataChan, pollDataChan,
@ -213,8 +213,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
) )
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", c.Param("id")). Str("id", ctx.Param("id")).
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Finished stream, closing PollNetMap session") Msg("Finished stream, closing PollNetMap session")
} }
@ -222,33 +222,40 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
// stream logic, ensuring we communicate updates and data // stream logic, ensuring we communicate updates and data
// to the connected clients. // to the connected clients.
func (h *Headscale) PollNetMapStream( func (h *Headscale) PollNetMapStream(
c *gin.Context, ctx *gin.Context,
m *Machine, machine *Machine,
req tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
mKey wgkey.Key, machineKey wgkey.Key,
pollDataChan chan []byte, pollDataChan chan []byte,
keepAliveChan chan []byte, keepAliveChan chan []byte,
updateChan chan struct{}, updateChan chan struct{},
cancelKeepAlive chan struct{}, cancelKeepAlive chan struct{},
) { ) {
go h.scheduledPollWorker(cancelKeepAlive, updateChan, keepAliveChan, mKey, req, m) go h.scheduledPollWorker(
cancelKeepAlive,
updateChan,
keepAliveChan,
machineKey,
mapRequest,
machine,
)
c.Stream(func(writer io.Writer) bool { ctx.Stream(func(writer io.Writer) bool {
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Waiting for data to stream...") Msg("Waiting for data to stream...")
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan) Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
select { select {
case data := <-pollDataChan: case data := <-pollDataChan:
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "pollData"). Str("channel", "pollData").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Sending data received via pollData channel") Msg("Sending data received via pollData channel")
@ -256,7 +263,7 @@ func (h *Headscale) PollNetMapStream(
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "pollData"). Str("channel", "pollData").
Err(err). Err(err).
Msg("Cannot write data") Msg("Cannot write data")
@ -265,33 +272,33 @@ func (h *Headscale) PollNetMapStream(
} }
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "pollData"). Str("channel", "pollData").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Data from pollData channel written successfully") Msg("Data from pollData channel written successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions // TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from // when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten. // command line, but then overwritten.
err = h.UpdateMachine(m) err = h.UpdateMachine(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "pollData"). Str("channel", "pollData").
Err(err). Err(err).
Msg("Cannot update machine from database") Msg("Cannot update machine from database")
} }
now := time.Now().UTC() now := time.Now().UTC()
m.LastSeen = &now machine.LastSeen = &now
lastStateUpdate.WithLabelValues(m.Namespace.Name, m.Name). lastStateUpdate.WithLabelValues(machine.Namespace.Name, machine.Name).
Set(float64(now.Unix())) Set(float64(now.Unix()))
m.LastSuccessfulUpdate = &now machine.LastSuccessfulUpdate = &now
h.db.Save(&m) h.db.Save(&machine)
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "pollData"). Str("channel", "pollData").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Machine entry in database updated successfully after sending pollData") Msg("Machine entry in database updated successfully after sending pollData")
@ -301,7 +308,7 @@ func (h *Headscale) PollNetMapStream(
case data := <-keepAliveChan: case data := <-keepAliveChan:
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Sending keep alive message") Msg("Sending keep alive message")
@ -309,7 +316,7 @@ func (h *Headscale) PollNetMapStream(
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Err(err). Err(err).
Msg("Cannot write keep alive message") Msg("Cannot write keep alive message")
@ -318,28 +325,28 @@ func (h *Headscale) PollNetMapStream(
} }
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Keep alive sent successfully") Msg("Keep alive sent successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions // TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from // when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten. // command line, but then overwritten.
err = h.UpdateMachine(m) err = h.UpdateMachine(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Err(err). Err(err).
Msg("Cannot update machine from database") Msg("Cannot update machine from database")
} }
now := time.Now().UTC() now := time.Now().UTC()
m.LastSeen = &now machine.LastSeen = &now
h.db.Save(&m) h.db.Save(&machine)
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Machine updated successfully after sending keep alive") Msg("Machine updated successfully after sending keep alive")
@ -349,23 +356,23 @@ func (h *Headscale) PollNetMapStream(
case <-updateChan: case <-updateChan:
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "update"). Str("channel", "update").
Msg("Received a request for update") Msg("Received a request for update")
updateRequestsReceivedOnChannel.WithLabelValues(m.Name, m.Namespace.Name). updateRequestsReceivedOnChannel.WithLabelValues(machine.Name, machine.Namespace.Name).
Inc() Inc()
if h.isOutdated(m) { if h.isOutdated(machine) {
log.Debug(). log.Debug().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate). Time("last_successful_update", *machine.LastSuccessfulUpdate).
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)). Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
Msgf("There has been updates since the last successful update to %s", m.Name) Msgf("There has been updates since the last successful update to %s", machine.Name)
data, err := h.getMapResponse(mKey, req, m) data, err := h.getMapResponse(machineKey, mapRequest, machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "update"). Str("channel", "update").
Err(err). Err(err).
Msg("Could not get the map update") Msg("Could not get the map update")
@ -374,21 +381,21 @@ func (h *Headscale) PollNetMapStream(
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "update"). Str("channel", "update").
Err(err). Err(err).
Msg("Could not write the map response") Msg("Could not write the map response")
updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "failed"). updateRequestsSentToNode.WithLabelValues(machine.Name, machine.Namespace.Name, "failed").
Inc() Inc()
return false return false
} }
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "update"). Str("channel", "update").
Msg("Updated Map has been sent") Msg("Updated Map has been sent")
updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "success"). updateRequestsSentToNode.WithLabelValues(machine.Name, machine.Namespace.Name, "success").
Inc() Inc()
// Keep track of the last successful update, // Keep track of the last successful update,
@ -398,64 +405,64 @@ func (h *Headscale) PollNetMapStream(
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions // TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from // when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten. // command line, but then overwritten.
err = h.UpdateMachine(m) err = h.UpdateMachine(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "update"). Str("channel", "update").
Err(err). Err(err).
Msg("Cannot update machine from database") Msg("Cannot update machine from database")
} }
now := time.Now().UTC() now := time.Now().UTC()
lastStateUpdate.WithLabelValues(m.Namespace.Name, m.Name). lastStateUpdate.WithLabelValues(machine.Namespace.Name, machine.Name).
Set(float64(now.Unix())) Set(float64(now.Unix()))
m.LastSuccessfulUpdate = &now machine.LastSuccessfulUpdate = &now
h.db.Save(&m) h.db.Save(&machine)
} else { } else {
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate). Time("last_successful_update", *machine.LastSuccessfulUpdate).
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)). Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
Msgf("%s is up to date", m.Name) Msgf("%s is up to date", machine.Name)
} }
return true return true
case <-c.Request.Context().Done(): case <-ctx.Request.Context().Done():
log.Info(). log.Info().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("The client has closed the connection") Msg("The client has closed the connection")
// TODO: Abstract away all the database calls, this can cause race conditions // TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from // when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten. // command line, but then overwritten.
err := h.UpdateMachine(m) err := h.UpdateMachine(machine)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "Done"). Str("channel", "Done").
Err(err). Err(err).
Msg("Cannot update machine from database") Msg("Cannot update machine from database")
} }
now := time.Now().UTC() now := time.Now().UTC()
m.LastSeen = &now machine.LastSeen = &now
h.db.Save(&m) h.db.Save(&machine)
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "Done"). Str("channel", "Done").
Msg("Cancelling keepAlive channel") Msg("Cancelling keepAlive channel")
cancelKeepAlive <- struct{}{} cancelKeepAlive <- struct{}{}
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "Done"). Str("channel", "Done").
Msg("Closing update channel") Msg("Closing update channel")
// h.closeUpdateChannel(m) // h.closeUpdateChannel(m)
@ -463,14 +470,14 @@ func (h *Headscale) PollNetMapStream(
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "Done"). Str("channel", "Done").
Msg("Closing pollData channel") Msg("Closing pollData channel")
close(pollDataChan) close(pollDataChan)
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", m.Name). Str("machine", machine.Name).
Str("channel", "Done"). Str("channel", "Done").
Msg("Closing keepAliveChan channel") Msg("Closing keepAliveChan channel")
close(keepAliveChan) close(keepAliveChan)
@ -484,9 +491,9 @@ func (h *Headscale) scheduledPollWorker(
cancelChan <-chan struct{}, cancelChan <-chan struct{},
updateChan chan<- struct{}, updateChan chan<- struct{},
keepAliveChan chan<- []byte, keepAliveChan chan<- []byte,
mKey wgkey.Key, machineKey wgkey.Key,
req tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
m *Machine, machine *Machine,
) { ) {
keepAliveTicker := time.NewTicker(KEEP_ALIVE_INTERVAL) keepAliveTicker := time.NewTicker(KEEP_ALIVE_INTERVAL)
updateCheckerTicker := time.NewTicker(UPDATE_CHECK_INTERVAL) updateCheckerTicker := time.NewTicker(UPDATE_CHECK_INTERVAL)
@ -497,7 +504,7 @@ func (h *Headscale) scheduledPollWorker(
return return
case <-keepAliveTicker.C: case <-keepAliveTicker.C:
data, err := h.getMapKeepAliveResponse(mKey, req) data, err := h.getMapKeepAliveResponse(machineKey, mapRequest)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "keepAlive"). Str("func", "keepAlive").
@ -509,16 +516,16 @@ func (h *Headscale) scheduledPollWorker(
log.Debug(). log.Debug().
Str("func", "keepAlive"). Str("func", "keepAlive").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Sending keepalive") Msg("Sending keepalive")
keepAliveChan <- data keepAliveChan <- data
case <-updateCheckerTicker.C: case <-updateCheckerTicker.C:
log.Debug(). log.Debug().
Str("func", "scheduledPollWorker"). Str("func", "scheduledPollWorker").
Str("machine", m.Name). Str("machine", machine.Name).
Msg("Sending update request") Msg("Sending update request")
updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "scheduled-update"). updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "scheduled-update").
Inc() Inc()
updateChan <- struct{}{} updateChan <- struct{}{}
} }

View file

@ -39,7 +39,7 @@ func (h *Headscale) CreatePreAuthKey(
ephemeral bool, ephemeral bool,
expiration *time.Time, expiration *time.Time,
) (*PreAuthKey, error) { ) (*PreAuthKey, error) {
n, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -50,29 +50,29 @@ func (h *Headscale) CreatePreAuthKey(
return nil, err return nil, err
} }
k := PreAuthKey{ key := PreAuthKey{
Key: kstr, Key: kstr,
NamespaceID: n.ID, NamespaceID: namespace.ID,
Namespace: *n, Namespace: *namespace,
Reusable: reusable, Reusable: reusable,
Ephemeral: ephemeral, Ephemeral: ephemeral,
CreatedAt: &now, CreatedAt: &now,
Expiration: expiration, Expiration: expiration,
} }
h.db.Save(&k) h.db.Save(&key)
return &k, nil return &key, nil
} }
// ListPreAuthKeys returns the list of PreAuthKeys for a namespace. // ListPreAuthKeys returns the list of PreAuthKeys for a namespace.
func (h *Headscale) ListPreAuthKeys(namespaceName string) ([]PreAuthKey, error) { func (h *Headscale) ListPreAuthKeys(namespaceName string) ([]PreAuthKey, error) {
n, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
keys := []PreAuthKey{} keys := []PreAuthKey{}
if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: n.ID}).Find(&keys).Error; err != nil { if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: namespace.ID}).Find(&keys).Error; err != nil {
return nil, err return nil, err
} }

View file

@ -15,12 +15,12 @@ func (h *Headscale) GetAdvertisedNodeRoutes(
namespace string, namespace string,
nodeName string, nodeName string,
) (*[]netaddr.IPPrefix, error) { ) (*[]netaddr.IPPrefix, error) {
m, err := h.GetMachine(namespace, nodeName) machine, err := h.GetMachine(namespace, nodeName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
hostInfo, err := m.GetHostInfo() hostInfo, err := machine.GetHostInfo()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -35,12 +35,12 @@ func (h *Headscale) GetEnabledNodeRoutes(
namespace string, namespace string,
nodeName string, nodeName string,
) ([]netaddr.IPPrefix, error) { ) ([]netaddr.IPPrefix, error) {
m, err := h.GetMachine(namespace, nodeName) machine, err := h.GetMachine(namespace, nodeName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
data, err := m.EnabledRoutes.MarshalJSON() data, err := machine.EnabledRoutes.MarshalJSON()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -97,7 +97,7 @@ func (h *Headscale) EnableNodeRoute(
nodeName string, nodeName string,
routeStr string, routeStr string,
) error { ) error {
m, err := h.GetMachine(namespace, nodeName) machine, err := h.GetMachine(namespace, nodeName)
if err != nil { if err != nil {
return err return err
} }
@ -137,10 +137,10 @@ func (h *Headscale) EnableNodeRoute(
return err return err
} }
m.EnabledRoutes = datatypes.JSON(routes) machine.EnabledRoutes = datatypes.JSON(routes)
h.db.Save(&m) h.db.Save(&machine)
err = h.RequestMapUpdates(m.NamespaceID) err = h.RequestMapUpdates(machine.NamespaceID)
if err != nil { if err != nil {
return err return err
} }