metrics, tuning in tests, db cleanups, fix concurrency issue (#1895)
This commit is contained in:
parent
7d8178406d
commit
ba614a5e6c
28 changed files with 328 additions and 201 deletions
|
@ -225,7 +225,7 @@ func (h *Headscale) deleteExpireEphemeralNodes(milliSeconds int64) {
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
var removed []types.NodeID
|
var removed []types.NodeID
|
||||||
var changed []types.NodeID
|
var changed []types.NodeID
|
||||||
if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
|
if err := h.db.Write(func(tx *gorm.DB) error {
|
||||||
removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
|
removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -263,7 +263,7 @@ func (h *Headscale) expireExpiredMachines(intervalMs int64) {
|
||||||
var changed bool
|
var changed bool
|
||||||
|
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
|
if err := h.db.Write(func(tx *gorm.DB) error {
|
||||||
lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck)
|
lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -452,6 +452,7 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error {
|
||||||
|
|
||||||
func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
|
router.Use(prometheusMiddleware)
|
||||||
router.PathPrefix("/debug/pprof/").Handler(http.DefaultServeMux)
|
router.PathPrefix("/debug/pprof/").Handler(http.DefaultServeMux)
|
||||||
|
|
||||||
router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler).Methods(http.MethodPost)
|
router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler).Methods(http.MethodPost)
|
||||||
|
@ -508,7 +509,7 @@ func (h *Headscale) Serve() error {
|
||||||
|
|
||||||
// Fetch an initial DERP Map before we start serving
|
// Fetch an initial DERP Map before we start serving
|
||||||
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
|
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
|
||||||
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier.ConnectedMap())
|
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier)
|
||||||
|
|
||||||
if h.cfg.DERP.ServerEnabled {
|
if h.cfg.DERP.ServerEnabled {
|
||||||
// When embedded DERP is enabled we always need a STUN server
|
// When embedded DERP is enabled we always need a STUN server
|
||||||
|
|
|
@ -273,8 +273,6 @@ func (h *Headscale) handleAuthKey(
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||||
nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
|
|
||||||
Inc()
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -294,13 +292,6 @@ func (h *Headscale) handleAuthKey(
|
||||||
Str("node", registerRequest.Hostinfo.Hostname).
|
Str("node", registerRequest.Hostinfo.Hostname).
|
||||||
Msg("Failed authentication via AuthKey")
|
Msg("Failed authentication via AuthKey")
|
||||||
|
|
||||||
if pak != nil {
|
|
||||||
nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
|
|
||||||
Inc()
|
|
||||||
} else {
|
|
||||||
nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", "unknown").Inc()
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -404,15 +395,13 @@ func (h *Headscale) handleAuthKey(
|
||||||
Caller().
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("could not register node")
|
Msg("could not register node")
|
||||||
nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
|
|
||||||
Inc()
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.db.DB.Transaction(func(tx *gorm.DB) error {
|
h.db.Write(func(tx *gorm.DB) error {
|
||||||
return db.UsePreAuthKey(tx, pak)
|
return db.UsePreAuthKey(tx, pak)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -420,8 +409,6 @@ func (h *Headscale) handleAuthKey(
|
||||||
Caller().
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Failed to use pre-auth key")
|
Msg("Failed to use pre-auth key")
|
||||||
nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
|
|
||||||
Inc()
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -440,14 +427,10 @@ func (h *Headscale) handleAuthKey(
|
||||||
Str("node", registerRequest.Hostinfo.Hostname).
|
Str("node", registerRequest.Hostinfo.Hostname).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
|
|
||||||
Inc()
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "success", pak.User.Name).
|
|
||||||
Inc()
|
|
||||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusOK)
|
writer.WriteHeader(http.StatusOK)
|
||||||
_, err = writer.Write(respBody)
|
_, err = writer.Write(respBody)
|
||||||
|
@ -563,7 +546,7 @@ func (h *Headscale) handleNodeLogOut(
|
||||||
}
|
}
|
||||||
|
|
||||||
if node.IsEphemeral() {
|
if node.IsEphemeral() {
|
||||||
changedNodes, err := h.db.DeleteNode(&node, h.nodeNotifier.ConnectedMap())
|
changedNodes, err := h.db.DeleteNode(&node, h.nodeNotifier.LikelyConnectedMap())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Err(err).
|
Err(err).
|
||||||
|
@ -616,14 +599,10 @@ func (h *Headscale) handleNodeWithValidRegistration(
|
||||||
Caller().
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
nodeRegistrations.WithLabelValues("update", "web", "error", node.User.Name).
|
|
||||||
Inc()
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
nodeRegistrations.WithLabelValues("update", "web", "success", node.User.Name).
|
|
||||||
Inc()
|
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusOK)
|
writer.WriteHeader(http.StatusOK)
|
||||||
|
@ -654,7 +633,7 @@ func (h *Headscale) handleNodeKeyRefresh(
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
||||||
|
|
||||||
err := h.db.DB.Transaction(func(tx *gorm.DB) error {
|
err := h.db.Write(func(tx *gorm.DB) error {
|
||||||
return db.NodeSetNodeKey(tx, &node, registerRequest.NodeKey)
|
return db.NodeSetNodeKey(tx, &node, registerRequest.NodeKey)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -737,14 +716,10 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
|
||||||
Caller().
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot encode message")
|
Msg("Cannot encode message")
|
||||||
nodeRegistrations.WithLabelValues("reauth", "web", "error", node.User.Name).
|
|
||||||
Inc()
|
|
||||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
nodeRegistrations.WithLabelValues("reauth", "web", "success", node.User.Name).
|
|
||||||
Inc()
|
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusOK)
|
writer.WriteHeader(http.StatusOK)
|
||||||
|
|
|
@ -33,7 +33,6 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
||||||
Caller().
|
Caller().
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot parse RegisterRequest")
|
Msg("Cannot parse RegisterRequest")
|
||||||
nodeRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
|
|
||||||
http.Error(writer, "Internal error", http.StatusInternalServerError)
|
http.Error(writer, "Internal error", http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
|
"github.com/puzpuzpuz/xsync/v3"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
@ -260,9 +261,9 @@ func NodeSetExpiry(tx *gorm.DB,
|
||||||
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
|
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected types.NodeConnectedMap) ([]types.NodeID, error) {
|
func (hsdb *HSDatabase) DeleteNode(node *types.Node, isLikelyConnected *xsync.MapOf[types.NodeID, bool]) ([]types.NodeID, error) {
|
||||||
return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
||||||
return DeleteNode(tx, node, isConnected)
|
return DeleteNode(tx, node, isLikelyConnected)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -270,9 +271,9 @@ func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected types.NodeConne
|
||||||
// Caller is responsible for notifying all of change.
|
// Caller is responsible for notifying all of change.
|
||||||
func DeleteNode(tx *gorm.DB,
|
func DeleteNode(tx *gorm.DB,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
isConnected types.NodeConnectedMap,
|
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
|
||||||
) ([]types.NodeID, error) {
|
) ([]types.NodeID, error) {
|
||||||
changed, err := deleteNodeRoutes(tx, node, isConnected)
|
changed, err := deleteNodeRoutes(tx, node, isLikelyConnected)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return changed, err
|
return changed, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"github.com/juanfont/headscale/hscontrol/policy"
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"github.com/puzpuzpuz/xsync/v3"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
|
@ -120,7 +121,7 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
|
||||||
}
|
}
|
||||||
db.DB.Save(&node)
|
db.DB.Save(&node)
|
||||||
|
|
||||||
_, err = db.DeleteNode(&node, types.NodeConnectedMap{})
|
_, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = db.getNode(user.Name, "testnode3")
|
_, err = db.getNode(user.Name, "testnode3")
|
||||||
|
|
|
@ -147,7 +147,7 @@ func (*Suite) TestEphemeralKeyReusable(c *check.C) {
|
||||||
_, err = db.getNode("test7", "testest")
|
_, err = db.getNode("test7", "testest")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
db.DB.Transaction(func(tx *gorm.DB) error {
|
db.Write(func(tx *gorm.DB) error {
|
||||||
DeleteExpiredEphemeralNodes(tx, time.Second*20)
|
DeleteExpiredEphemeralNodes(tx, time.Second*20)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -181,7 +181,7 @@ func (*Suite) TestEphemeralKeyNotReusable(c *check.C) {
|
||||||
_, err = db.getNode("test7", "testest")
|
_, err = db.getNode("test7", "testest")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
db.DB.Transaction(func(tx *gorm.DB) error {
|
db.Write(func(tx *gorm.DB) error {
|
||||||
DeleteExpiredEphemeralNodes(tx, time.Second*20)
|
DeleteExpiredEphemeralNodes(tx, time.Second*20)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/policy"
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/puzpuzpuz/xsync/v3"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/util/set"
|
"tailscale.com/util/set"
|
||||||
|
@ -126,7 +127,7 @@ func EnableRoute(tx *gorm.DB, id uint64) (*types.StateUpdate, error) {
|
||||||
|
|
||||||
func DisableRoute(tx *gorm.DB,
|
func DisableRoute(tx *gorm.DB,
|
||||||
id uint64,
|
id uint64,
|
||||||
isConnected types.NodeConnectedMap,
|
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
|
||||||
) ([]types.NodeID, error) {
|
) ([]types.NodeID, error) {
|
||||||
route, err := GetRoute(tx, id)
|
route, err := GetRoute(tx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -147,7 +148,7 @@ func DisableRoute(tx *gorm.DB,
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
update, err = failoverRouteTx(tx, isConnected, route)
|
update, err = failoverRouteTx(tx, isLikelyConnected, route)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -182,17 +183,17 @@ func DisableRoute(tx *gorm.DB,
|
||||||
|
|
||||||
func (hsdb *HSDatabase) DeleteRoute(
|
func (hsdb *HSDatabase) DeleteRoute(
|
||||||
id uint64,
|
id uint64,
|
||||||
isConnected types.NodeConnectedMap,
|
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
|
||||||
) ([]types.NodeID, error) {
|
) ([]types.NodeID, error) {
|
||||||
return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
||||||
return DeleteRoute(tx, id, isConnected)
|
return DeleteRoute(tx, id, isLikelyConnected)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteRoute(
|
func DeleteRoute(
|
||||||
tx *gorm.DB,
|
tx *gorm.DB,
|
||||||
id uint64,
|
id uint64,
|
||||||
isConnected types.NodeConnectedMap,
|
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
|
||||||
) ([]types.NodeID, error) {
|
) ([]types.NodeID, error) {
|
||||||
route, err := GetRoute(tx, id)
|
route, err := GetRoute(tx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -207,7 +208,7 @@ func DeleteRoute(
|
||||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||||
var update []types.NodeID
|
var update []types.NodeID
|
||||||
if !route.IsExitRoute() {
|
if !route.IsExitRoute() {
|
||||||
update, err = failoverRouteTx(tx, isConnected, route)
|
update, err = failoverRouteTx(tx, isLikelyConnected, route)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -252,7 +253,7 @@ func DeleteRoute(
|
||||||
return update, nil
|
return update, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected types.NodeConnectedMap) ([]types.NodeID, error) {
|
func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isLikelyConnected *xsync.MapOf[types.NodeID, bool]) ([]types.NodeID, error) {
|
||||||
routes, err := GetNodeRoutes(tx, node)
|
routes, err := GetNodeRoutes(tx, node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting node routes: %w", err)
|
return nil, fmt.Errorf("getting node routes: %w", err)
|
||||||
|
@ -266,7 +267,7 @@ func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected types.NodeConne
|
||||||
|
|
||||||
// TODO(kradalby): This is a bit too aggressive, we could probably
|
// TODO(kradalby): This is a bit too aggressive, we could probably
|
||||||
// figure out which routes needs to be failed over rather than all.
|
// figure out which routes needs to be failed over rather than all.
|
||||||
chn, err := failoverRouteTx(tx, isConnected, &routes[i])
|
chn, err := failoverRouteTx(tx, isLikelyConnected, &routes[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return changed, fmt.Errorf("failing over route after delete: %w", err)
|
return changed, fmt.Errorf("failing over route after delete: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -409,7 +410,7 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
|
||||||
// If needed, the failover will be attempted.
|
// If needed, the failover will be attempted.
|
||||||
func FailoverNodeRoutesIfNeccessary(
|
func FailoverNodeRoutesIfNeccessary(
|
||||||
tx *gorm.DB,
|
tx *gorm.DB,
|
||||||
isConnected types.NodeConnectedMap,
|
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
) (*types.StateUpdate, error) {
|
) (*types.StateUpdate, error) {
|
||||||
nodeRoutes, err := GetNodeRoutes(tx, node)
|
nodeRoutes, err := GetNodeRoutes(tx, node)
|
||||||
|
@ -430,12 +431,12 @@ nodeRouteLoop:
|
||||||
if route.IsPrimary {
|
if route.IsPrimary {
|
||||||
// if we have a primary route, and the node is connected
|
// if we have a primary route, and the node is connected
|
||||||
// nothing needs to be done.
|
// nothing needs to be done.
|
||||||
if conn, ok := isConnected[route.Node.ID]; conn && ok {
|
if val, ok := isLikelyConnected.Load(route.Node.ID); ok && val {
|
||||||
continue nodeRouteLoop
|
continue nodeRouteLoop
|
||||||
}
|
}
|
||||||
|
|
||||||
// if not, we need to failover the route
|
// if not, we need to failover the route
|
||||||
failover := failoverRoute(isConnected, &route, routes)
|
failover := failoverRoute(isLikelyConnected, &route, routes)
|
||||||
if failover != nil {
|
if failover != nil {
|
||||||
err := failover.save(tx)
|
err := failover.save(tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -477,7 +478,7 @@ nodeRouteLoop:
|
||||||
// If the given route was not primary, it returns early.
|
// If the given route was not primary, it returns early.
|
||||||
func failoverRouteTx(
|
func failoverRouteTx(
|
||||||
tx *gorm.DB,
|
tx *gorm.DB,
|
||||||
isConnected types.NodeConnectedMap,
|
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
|
||||||
r *types.Route,
|
r *types.Route,
|
||||||
) ([]types.NodeID, error) {
|
) ([]types.NodeID, error) {
|
||||||
if r == nil {
|
if r == nil {
|
||||||
|
@ -500,7 +501,7 @@ func failoverRouteTx(
|
||||||
return nil, fmt.Errorf("getting routes by prefix: %w", err)
|
return nil, fmt.Errorf("getting routes by prefix: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fo := failoverRoute(isConnected, r, routes)
|
fo := failoverRoute(isLikelyConnected, r, routes)
|
||||||
if fo == nil {
|
if fo == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -538,7 +539,7 @@ func (f *failover) save(tx *gorm.DB) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func failoverRoute(
|
func failoverRoute(
|
||||||
isConnected types.NodeConnectedMap,
|
isLikelyConnected *xsync.MapOf[types.NodeID, bool],
|
||||||
routeToReplace *types.Route,
|
routeToReplace *types.Route,
|
||||||
altRoutes types.Routes,
|
altRoutes types.Routes,
|
||||||
|
|
||||||
|
@ -570,11 +571,13 @@ func failoverRoute(
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if isConnected != nil && isConnected[route.Node.ID] {
|
if isLikelyConnected != nil {
|
||||||
|
if val, ok := isLikelyConnected.Load(route.Node.ID); ok && val {
|
||||||
newPrimary = &altRoutes[idx]
|
newPrimary = &altRoutes[idx]
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// If a new route was not found/available,
|
// If a new route was not found/available,
|
||||||
// return without an error.
|
// return without an error.
|
||||||
|
|
|
@ -10,11 +10,22 @@ import (
|
||||||
"github.com/google/go-cmp/cmp/cmpopts"
|
"github.com/google/go-cmp/cmp/cmpopts"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
"github.com/puzpuzpuz/xsync/v3"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var smap = func(m map[types.NodeID]bool) *xsync.MapOf[types.NodeID, bool] {
|
||||||
|
s := xsync.NewMapOf[types.NodeID, bool]()
|
||||||
|
|
||||||
|
for k, v := range m {
|
||||||
|
s.Store(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetRoutes(c *check.C) {
|
func (s *Suite) TestGetRoutes(c *check.C) {
|
||||||
user, err := db.CreateUser("test")
|
user, err := db.CreateUser("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
@ -331,7 +342,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
|
||||||
name string
|
name string
|
||||||
nodes types.Nodes
|
nodes types.Nodes
|
||||||
routes types.Routes
|
routes types.Routes
|
||||||
isConnected []types.NodeConnectedMap
|
isConnected []map[types.NodeID]bool
|
||||||
want []*types.StateUpdate
|
want []*types.StateUpdate
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
|
@ -346,7 +357,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
|
||||||
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||||
},
|
},
|
||||||
isConnected: []types.NodeConnectedMap{
|
isConnected: []map[types.NodeID]bool{
|
||||||
// n1 goes down
|
// n1 goes down
|
||||||
{
|
{
|
||||||
1: false,
|
1: false,
|
||||||
|
@ -384,7 +395,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
|
||||||
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||||
},
|
},
|
||||||
isConnected: []types.NodeConnectedMap{
|
isConnected: []map[types.NodeID]bool{
|
||||||
// n1 up recon = noop
|
// n1 up recon = noop
|
||||||
{
|
{
|
||||||
1: true,
|
1: true,
|
||||||
|
@ -428,7 +439,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
|
||||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||||
r(3, 3, ipp("10.0.0.0/24"), true, false),
|
r(3, 3, ipp("10.0.0.0/24"), true, false),
|
||||||
},
|
},
|
||||||
isConnected: []types.NodeConnectedMap{
|
isConnected: []map[types.NodeID]bool{
|
||||||
// n1 goes down
|
// n1 goes down
|
||||||
{
|
{
|
||||||
1: false,
|
1: false,
|
||||||
|
@ -486,7 +497,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
|
||||||
r(2, 2, ipp("10.0.0.0/24"), false, false),
|
r(2, 2, ipp("10.0.0.0/24"), false, false),
|
||||||
r(3, 3, ipp("10.0.0.0/24"), true, false),
|
r(3, 3, ipp("10.0.0.0/24"), true, false),
|
||||||
},
|
},
|
||||||
isConnected: []types.NodeConnectedMap{
|
isConnected: []map[types.NodeID]bool{
|
||||||
// n1 goes down
|
// n1 goes down
|
||||||
{
|
{
|
||||||
1: false,
|
1: false,
|
||||||
|
@ -516,7 +527,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
|
||||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||||
r(3, 3, ipp("10.1.0.0/24"), true, false),
|
r(3, 3, ipp("10.1.0.0/24"), true, false),
|
||||||
},
|
},
|
||||||
isConnected: []types.NodeConnectedMap{
|
isConnected: []map[types.NodeID]bool{
|
||||||
// n1 goes down
|
// n1 goes down
|
||||||
{
|
{
|
||||||
1: false,
|
1: false,
|
||||||
|
@ -539,7 +550,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
|
||||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||||
r(3, 3, ipp("10.1.0.0/24"), false, false),
|
r(3, 3, ipp("10.1.0.0/24"), false, false),
|
||||||
},
|
},
|
||||||
isConnected: []types.NodeConnectedMap{
|
isConnected: []map[types.NodeID]bool{
|
||||||
// n1 goes down
|
// n1 goes down
|
||||||
{
|
{
|
||||||
1: false,
|
1: false,
|
||||||
|
@ -562,7 +573,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
|
||||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||||
r(3, 3, ipp("10.1.0.0/24"), true, false),
|
r(3, 3, ipp("10.1.0.0/24"), true, false),
|
||||||
},
|
},
|
||||||
isConnected: []types.NodeConnectedMap{
|
isConnected: []map[types.NodeID]bool{
|
||||||
// n1 goes down
|
// n1 goes down
|
||||||
{
|
{
|
||||||
1: false,
|
1: false,
|
||||||
|
@ -585,7 +596,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
|
||||||
r(2, 2, ipp("10.0.0.0/24"), true, true),
|
r(2, 2, ipp("10.0.0.0/24"), true, true),
|
||||||
r(3, 3, ipp("10.1.0.0/24"), true, false),
|
r(3, 3, ipp("10.1.0.0/24"), true, false),
|
||||||
},
|
},
|
||||||
isConnected: []types.NodeConnectedMap{
|
isConnected: []map[types.NodeID]bool{
|
||||||
// n1 goes down
|
// n1 goes down
|
||||||
{
|
{
|
||||||
1: true,
|
1: true,
|
||||||
|
@ -618,7 +629,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
|
||||||
want := tt.want[step]
|
want := tt.want[step]
|
||||||
|
|
||||||
got, err := Write(db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
got, err := Write(db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||||
return FailoverNodeRoutesIfNeccessary(tx, isConnected, node)
|
return FailoverNodeRoutesIfNeccessary(tx, smap(isConnected), node)
|
||||||
})
|
})
|
||||||
|
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
|
@ -640,7 +651,7 @@ func TestFailoverRouteTx(t *testing.T) {
|
||||||
name string
|
name string
|
||||||
failingRoute types.Route
|
failingRoute types.Route
|
||||||
routes types.Routes
|
routes types.Routes
|
||||||
isConnected types.NodeConnectedMap
|
isConnected map[types.NodeID]bool
|
||||||
want []types.NodeID
|
want []types.NodeID
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
|
@ -743,7 +754,7 @@ func TestFailoverRouteTx(t *testing.T) {
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
isConnected: types.NodeConnectedMap{
|
isConnected: map[types.NodeID]bool{
|
||||||
1: false,
|
1: false,
|
||||||
2: true,
|
2: true,
|
||||||
},
|
},
|
||||||
|
@ -841,7 +852,7 @@ func TestFailoverRouteTx(t *testing.T) {
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
isConnected: types.NodeConnectedMap{
|
isConnected: map[types.NodeID]bool{
|
||||||
1: true,
|
1: true,
|
||||||
2: true,
|
2: true,
|
||||||
3: true,
|
3: true,
|
||||||
|
@ -889,7 +900,7 @@ func TestFailoverRouteTx(t *testing.T) {
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
isConnected: types.NodeConnectedMap{
|
isConnected: map[types.NodeID]bool{
|
||||||
1: true,
|
1: true,
|
||||||
4: false,
|
4: false,
|
||||||
},
|
},
|
||||||
|
@ -945,7 +956,7 @@ func TestFailoverRouteTx(t *testing.T) {
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
isConnected: types.NodeConnectedMap{
|
isConnected: map[types.NodeID]bool{
|
||||||
1: false,
|
1: false,
|
||||||
2: true,
|
2: true,
|
||||||
4: false,
|
4: false,
|
||||||
|
@ -1010,7 +1021,7 @@ func TestFailoverRouteTx(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
got, err := Write(db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
got, err := Write(db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
||||||
return failoverRouteTx(tx, tt.isConnected, &tt.failingRoute)
|
return failoverRouteTx(tx, smap(tt.isConnected), &tt.failingRoute)
|
||||||
})
|
})
|
||||||
|
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
|
@ -1048,7 +1059,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
name string
|
name string
|
||||||
failingRoute types.Route
|
failingRoute types.Route
|
||||||
routes types.Routes
|
routes types.Routes
|
||||||
isConnected types.NodeConnectedMap
|
isConnected map[types.NodeID]bool
|
||||||
want *failover
|
want *failover
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
|
@ -1085,7 +1096,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||||
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
r(2, 2, ipp("10.0.0.0/24"), true, false),
|
||||||
},
|
},
|
||||||
isConnected: types.NodeConnectedMap{
|
isConnected: map[types.NodeID]bool{
|
||||||
1: false,
|
1: false,
|
||||||
2: true,
|
2: true,
|
||||||
},
|
},
|
||||||
|
@ -1111,7 +1122,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
r(2, 2, ipp("10.0.0.0/24"), true, true),
|
r(2, 2, ipp("10.0.0.0/24"), true, true),
|
||||||
r(3, 3, ipp("10.0.0.0/24"), true, false),
|
r(3, 3, ipp("10.0.0.0/24"), true, false),
|
||||||
},
|
},
|
||||||
isConnected: types.NodeConnectedMap{
|
isConnected: map[types.NodeID]bool{
|
||||||
1: true,
|
1: true,
|
||||||
2: true,
|
2: true,
|
||||||
3: true,
|
3: true,
|
||||||
|
@ -1128,7 +1139,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
r(1, 1, ipp("10.0.0.0/24"), true, true),
|
||||||
r(2, 4, ipp("10.0.0.0/24"), true, false),
|
r(2, 4, ipp("10.0.0.0/24"), true, false),
|
||||||
},
|
},
|
||||||
isConnected: types.NodeConnectedMap{
|
isConnected: map[types.NodeID]bool{
|
||||||
1: true,
|
1: true,
|
||||||
4: false,
|
4: false,
|
||||||
},
|
},
|
||||||
|
@ -1142,7 +1153,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
r(2, 4, ipp("10.0.0.0/24"), true, false),
|
r(2, 4, ipp("10.0.0.0/24"), true, false),
|
||||||
r(3, 2, ipp("10.0.0.0/24"), true, false),
|
r(3, 2, ipp("10.0.0.0/24"), true, false),
|
||||||
},
|
},
|
||||||
isConnected: types.NodeConnectedMap{
|
isConnected: map[types.NodeID]bool{
|
||||||
1: false,
|
1: false,
|
||||||
2: true,
|
2: true,
|
||||||
4: false,
|
4: false,
|
||||||
|
@ -1172,7 +1183,7 @@ func TestFailoverRoute(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
gotf := failoverRoute(tt.isConnected, &tt.failingRoute, tt.routes)
|
gotf := failoverRoute(smap(tt.isConnected), &tt.failingRoute, tt.routes)
|
||||||
|
|
||||||
if tt.want == nil && gotf != nil {
|
if tt.want == nil && gotf != nil {
|
||||||
t.Fatalf("expected nil, got %+v", gotf)
|
t.Fatalf("expected nil, got %+v", gotf)
|
||||||
|
|
|
@ -145,7 +145,7 @@ func (api headscaleV1APIServer) ExpirePreAuthKey(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.ExpirePreAuthKeyRequest,
|
request *v1.ExpirePreAuthKeyRequest,
|
||||||
) (*v1.ExpirePreAuthKeyResponse, error) {
|
) (*v1.ExpirePreAuthKeyResponse, error) {
|
||||||
err := api.h.db.DB.Transaction(func(tx *gorm.DB) error {
|
err := api.h.db.Write(func(tx *gorm.DB) error {
|
||||||
preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key)
|
preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -301,7 +301,7 @@ func (api headscaleV1APIServer) DeleteNode(
|
||||||
|
|
||||||
changedNodes, err := api.h.db.DeleteNode(
|
changedNodes, err := api.h.db.DeleteNode(
|
||||||
node,
|
node,
|
||||||
api.h.nodeNotifier.ConnectedMap(),
|
api.h.nodeNotifier.LikelyConnectedMap(),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -343,7 +343,7 @@ func (api headscaleV1APIServer) ExpireNode(
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
|
ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
|
||||||
api.h.nodeNotifier.NotifyByMachineKey(
|
api.h.nodeNotifier.NotifyByNodeID(
|
||||||
ctx,
|
ctx,
|
||||||
types.StateUpdate{
|
types.StateUpdate{
|
||||||
Type: types.StateSelfUpdate,
|
Type: types.StateSelfUpdate,
|
||||||
|
@ -401,7 +401,7 @@ func (api headscaleV1APIServer) ListNodes(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.ListNodesRequest,
|
request *v1.ListNodesRequest,
|
||||||
) (*v1.ListNodesResponse, error) {
|
) (*v1.ListNodesResponse, error) {
|
||||||
isConnected := api.h.nodeNotifier.ConnectedMap()
|
isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
|
||||||
if request.GetUser() != "" {
|
if request.GetUser() != "" {
|
||||||
nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||||
return db.ListNodesByUser(rx, request.GetUser())
|
return db.ListNodesByUser(rx, request.GetUser())
|
||||||
|
@ -416,7 +416,9 @@ func (api headscaleV1APIServer) ListNodes(
|
||||||
|
|
||||||
// Populate the online field based on
|
// Populate the online field based on
|
||||||
// currently connected nodes.
|
// currently connected nodes.
|
||||||
resp.Online = isConnected[node.ID]
|
if val, ok := isLikelyConnected.Load(node.ID); ok && val {
|
||||||
|
resp.Online = true
|
||||||
|
}
|
||||||
|
|
||||||
response[index] = resp
|
response[index] = resp
|
||||||
}
|
}
|
||||||
|
@ -439,7 +441,9 @@ func (api headscaleV1APIServer) ListNodes(
|
||||||
|
|
||||||
// Populate the online field based on
|
// Populate the online field based on
|
||||||
// currently connected nodes.
|
// currently connected nodes.
|
||||||
resp.Online = isConnected[node.ID]
|
if val, ok := isLikelyConnected.Load(node.ID); ok && val {
|
||||||
|
resp.Online = true
|
||||||
|
}
|
||||||
|
|
||||||
validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
|
validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
|
||||||
node,
|
node,
|
||||||
|
@ -528,7 +532,7 @@ func (api headscaleV1APIServer) DisableRoute(
|
||||||
request *v1.DisableRouteRequest,
|
request *v1.DisableRouteRequest,
|
||||||
) (*v1.DisableRouteResponse, error) {
|
) (*v1.DisableRouteResponse, error) {
|
||||||
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
||||||
return db.DisableRoute(tx, request.GetRouteId(), api.h.nodeNotifier.ConnectedMap())
|
return db.DisableRoute(tx, request.GetRouteId(), api.h.nodeNotifier.LikelyConnectedMap())
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -568,7 +572,7 @@ func (api headscaleV1APIServer) DeleteRoute(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.DeleteRouteRequest,
|
request *v1.DeleteRouteRequest,
|
||||||
) (*v1.DeleteRouteResponse, error) {
|
) (*v1.DeleteRouteResponse, error) {
|
||||||
isConnected := api.h.nodeNotifier.ConnectedMap()
|
isConnected := api.h.nodeNotifier.LikelyConnectedMap()
|
||||||
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
|
||||||
return db.DeleteRoute(tx, request.GetRouteId(), isConnected)
|
return db.DeleteRoute(tx, request.GetRouteId(), isConnected)
|
||||||
})
|
})
|
||||||
|
|
|
@ -17,6 +17,7 @@ import (
|
||||||
|
|
||||||
mapset "github.com/deckarep/golang-set/v2"
|
mapset "github.com/deckarep/golang-set/v2"
|
||||||
"github.com/juanfont/headscale/hscontrol/db"
|
"github.com/juanfont/headscale/hscontrol/db"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||||
"github.com/juanfont/headscale/hscontrol/policy"
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
@ -54,7 +55,7 @@ type Mapper struct {
|
||||||
db *db.HSDatabase
|
db *db.HSDatabase
|
||||||
cfg *types.Config
|
cfg *types.Config
|
||||||
derpMap *tailcfg.DERPMap
|
derpMap *tailcfg.DERPMap
|
||||||
isLikelyConnected types.NodeConnectedMap
|
notif *notifier.Notifier
|
||||||
|
|
||||||
uid string
|
uid string
|
||||||
created time.Time
|
created time.Time
|
||||||
|
@ -70,7 +71,7 @@ func NewMapper(
|
||||||
db *db.HSDatabase,
|
db *db.HSDatabase,
|
||||||
cfg *types.Config,
|
cfg *types.Config,
|
||||||
derpMap *tailcfg.DERPMap,
|
derpMap *tailcfg.DERPMap,
|
||||||
isLikelyConnected types.NodeConnectedMap,
|
notif *notifier.Notifier,
|
||||||
) *Mapper {
|
) *Mapper {
|
||||||
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
||||||
|
|
||||||
|
@ -78,7 +79,7 @@ func NewMapper(
|
||||||
db: db,
|
db: db,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
derpMap: derpMap,
|
derpMap: derpMap,
|
||||||
isLikelyConnected: isLikelyConnected,
|
notif: notif,
|
||||||
|
|
||||||
uid: uid,
|
uid: uid,
|
||||||
created: time.Now(),
|
created: time.Now(),
|
||||||
|
@ -517,7 +518,7 @@ func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
online := m.isLikelyConnected[peer.ID]
|
online := m.notif.IsLikelyConnected(peer.ID)
|
||||||
peer.IsOnline = &online
|
peer.IsOnline = &online
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
package hscontrol
|
package hscontrol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
)
|
)
|
||||||
|
@ -8,18 +12,94 @@ import (
|
||||||
const prometheusNamespace = "headscale"
|
const prometheusNamespace = "headscale"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// This is a high cardinality metric (user x node), we might want to make this
|
mapResponseSent = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||||
// configurable/opt-in in the future.
|
|
||||||
nodeRegistrations = promauto.NewCounterVec(prometheus.CounterOpts{
|
|
||||||
Namespace: prometheusNamespace,
|
Namespace: prometheusNamespace,
|
||||||
Name: "node_registrations_total",
|
Name: "mapresponse_sent_total",
|
||||||
Help: "The total amount of registered node attempts",
|
Help: "total count of mapresponses sent to clients",
|
||||||
}, []string{"action", "auth", "status", "user"})
|
}, []string{"status", "type"})
|
||||||
|
mapResponseUpdateReceived = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||||
updateRequestsSentToNode = promauto.NewCounterVec(prometheus.CounterOpts{
|
|
||||||
Namespace: prometheusNamespace,
|
Namespace: prometheusNamespace,
|
||||||
Name: "update_request_sent_to_node_total",
|
Name: "mapresponse_updates_received_total",
|
||||||
Help: "The number of calls/messages issued on a specific nodes update channel",
|
Help: "total count of mapresponse updates received on update channel",
|
||||||
}, []string{"user", "node", "status"})
|
}, []string{"type"})
|
||||||
// TODO(kradalby): This is very debugging, we might want to remove it.
|
mapResponseWriteUpdatesInStream = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Namespace: prometheusNamespace,
|
||||||
|
Name: "mapresponse_write_updates_in_stream_total",
|
||||||
|
Help: "total count of writes that occured in a stream session, pre-68 nodes",
|
||||||
|
}, []string{"status"})
|
||||||
|
mapResponseEndpointUpdates = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Namespace: prometheusNamespace,
|
||||||
|
Name: "mapresponse_endpoint_updates_total",
|
||||||
|
Help: "total count of endpoint updates received",
|
||||||
|
}, []string{"status"})
|
||||||
|
mapResponseReadOnly = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Namespace: prometheusNamespace,
|
||||||
|
Name: "mapresponse_readonly_requests_total",
|
||||||
|
Help: "total count of readonly requests received",
|
||||||
|
}, []string{"status"})
|
||||||
|
mapResponseSessions = promauto.NewGauge(prometheus.GaugeOpts{
|
||||||
|
Namespace: prometheusNamespace,
|
||||||
|
Name: "mapresponse_current_sessions_total",
|
||||||
|
Help: "total count open map response sessions",
|
||||||
|
})
|
||||||
|
mapResponseRejected = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Namespace: prometheusNamespace,
|
||||||
|
Name: "mapresponse_rejected_new_sessions_total",
|
||||||
|
Help: "total count of new mapsessions rejected",
|
||||||
|
}, []string{"reason"})
|
||||||
|
httpDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
|
||||||
|
Namespace: prometheusNamespace,
|
||||||
|
Name: "http_duration_seconds",
|
||||||
|
Help: "Duration of HTTP requests.",
|
||||||
|
}, []string{"path"})
|
||||||
|
httpCounter = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Namespace: prometheusNamespace,
|
||||||
|
Name: "http_requests_total",
|
||||||
|
Help: "Total number of http requests processed",
|
||||||
|
}, []string{"code", "method", "path"},
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// prometheusMiddleware implements mux.MiddlewareFunc.
|
||||||
|
func prometheusMiddleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
route := mux.CurrentRoute(r)
|
||||||
|
path, _ := route.GetPathTemplate()
|
||||||
|
|
||||||
|
// Ignore streaming and noise sessions
|
||||||
|
// it has its own router further down.
|
||||||
|
if path == "/ts2021" || path == "/machine/map" || path == "/derp" || path == "/derp/probe" || path == "/bootstrap-dns" {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rw := &respWriterProm{ResponseWriter: w}
|
||||||
|
|
||||||
|
timer := prometheus.NewTimer(httpDuration.WithLabelValues(path))
|
||||||
|
next.ServeHTTP(rw, r)
|
||||||
|
timer.ObserveDuration()
|
||||||
|
httpCounter.WithLabelValues(strconv.Itoa(rw.status), r.Method, path).Inc()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type respWriterProm struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
status int
|
||||||
|
written int64
|
||||||
|
wroteHeader bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *respWriterProm) WriteHeader(code int) {
|
||||||
|
r.status = code
|
||||||
|
r.wroteHeader = true
|
||||||
|
r.ResponseWriter.WriteHeader(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *respWriterProm) Write(b []byte) (int, error) {
|
||||||
|
if !r.wroteHeader {
|
||||||
|
r.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
n, err := r.ResponseWriter.Write(b)
|
||||||
|
r.written += int64(n)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
|
@ -95,6 +95,7 @@ func (h *Headscale) NoiseUpgradeHandler(
|
||||||
// The HTTP2 server that exposes this router is created for
|
// The HTTP2 server that exposes this router is created for
|
||||||
// a single hijacked connection from /ts2021, using netutil.NewOneConnListener
|
// a single hijacked connection from /ts2021, using netutil.NewOneConnListener
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
|
router.Use(prometheusMiddleware)
|
||||||
|
|
||||||
router.HandleFunc("/machine/register", noiseServer.NoiseRegistrationHandler).
|
router.HandleFunc("/machine/register", noiseServer.NoiseRegistrationHandler).
|
||||||
Methods(http.MethodPost)
|
Methods(http.MethodPost)
|
||||||
|
@ -267,10 +268,12 @@ func (ns *noiseServer) NoisePollNetMapHandler(
|
||||||
defer ns.headscale.mapSessionMu.Unlock()
|
defer ns.headscale.mapSessionMu.Unlock()
|
||||||
|
|
||||||
sess.infof("node has an open stream(%p), rejecting new stream", sess)
|
sess.infof("node has an open stream(%p), rejecting new stream", sess)
|
||||||
|
mapResponseRejected.WithLabelValues("exists").Inc()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ns.headscale.mapSessions[node.ID] = sess
|
ns.headscale.mapSessions[node.ID] = sess
|
||||||
|
mapResponseSessions.Inc()
|
||||||
ns.headscale.mapSessionMu.Unlock()
|
ns.headscale.mapSessionMu.Unlock()
|
||||||
sess.tracef("releasing lock to check stream")
|
sess.tracef("releasing lock to check stream")
|
||||||
}
|
}
|
||||||
|
@ -283,6 +286,7 @@ func (ns *noiseServer) NoisePollNetMapHandler(
|
||||||
defer ns.headscale.mapSessionMu.Unlock()
|
defer ns.headscale.mapSessionMu.Unlock()
|
||||||
|
|
||||||
delete(ns.headscale.mapSessions, node.ID)
|
delete(ns.headscale.mapSessions, node.ID)
|
||||||
|
mapResponseSessions.Dec()
|
||||||
|
|
||||||
sess.tracef("releasing lock to remove stream")
|
sess.tracef("releasing lock to remove stream")
|
||||||
}
|
}
|
||||||
|
|
27
hscontrol/notifier/metrics.go
Normal file
27
hscontrol/notifier/metrics.go
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
package notifier
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
|
)
|
||||||
|
|
||||||
|
const prometheusNamespace = "headscale"
|
||||||
|
|
||||||
|
var (
|
||||||
|
notifierWaitForLock = promauto.NewHistogramVec(prometheus.HistogramOpts{
|
||||||
|
Namespace: prometheusNamespace,
|
||||||
|
Name: "notifier_wait_for_lock_seconds",
|
||||||
|
Help: "histogram of time spent waiting for the notifier lock",
|
||||||
|
Buckets: []float64{0.001, 0.01, 0.1, 0.3, 0.5, 1, 3, 5, 10},
|
||||||
|
}, []string{"action"})
|
||||||
|
notifierUpdateSent = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Namespace: prometheusNamespace,
|
||||||
|
Name: "notifier_update_sent_total",
|
||||||
|
Help: "total count of update sent on nodes channel",
|
||||||
|
}, []string{"status", "type"})
|
||||||
|
notifierNodeUpdateChans = promauto.NewGauge(prometheus.GaugeOpts{
|
||||||
|
Namespace: prometheusNamespace,
|
||||||
|
Name: "notifier_open_channels_total",
|
||||||
|
Help: "total count open channels in notifier",
|
||||||
|
})
|
||||||
|
)
|
|
@ -6,21 +6,23 @@ import (
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/puzpuzpuz/xsync/v3"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Notifier struct {
|
type Notifier struct {
|
||||||
l sync.RWMutex
|
l sync.RWMutex
|
||||||
nodes map[types.NodeID]chan<- types.StateUpdate
|
nodes map[types.NodeID]chan<- types.StateUpdate
|
||||||
connected types.NodeConnectedMap
|
connected *xsync.MapOf[types.NodeID, bool]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNotifier() *Notifier {
|
func NewNotifier() *Notifier {
|
||||||
return &Notifier{
|
return &Notifier{
|
||||||
nodes: make(map[types.NodeID]chan<- types.StateUpdate),
|
nodes: make(map[types.NodeID]chan<- types.StateUpdate),
|
||||||
connected: make(types.NodeConnectedMap),
|
connected: xsync.NewMapOf[types.NodeID, bool](),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -31,16 +33,19 @@ func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
|
||||||
Uint64("node.id", nodeID.Uint64()).
|
Uint64("node.id", nodeID.Uint64()).
|
||||||
Msg("releasing lock to add node")
|
Msg("releasing lock to add node")
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
n.l.Lock()
|
n.l.Lock()
|
||||||
defer n.l.Unlock()
|
defer n.l.Unlock()
|
||||||
|
notifierWaitForLock.WithLabelValues("add").Observe(time.Since(start).Seconds())
|
||||||
|
|
||||||
n.nodes[nodeID] = c
|
n.nodes[nodeID] = c
|
||||||
n.connected[nodeID] = true
|
n.connected.Store(nodeID, true)
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Uint64("node.id", nodeID.Uint64()).
|
Uint64("node.id", nodeID.Uint64()).
|
||||||
Int("open_chans", len(n.nodes)).
|
Int("open_chans", len(n.nodes)).
|
||||||
Msg("Added new channel")
|
Msg("Added new channel")
|
||||||
|
notifierNodeUpdateChans.Inc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Notifier) RemoveNode(nodeID types.NodeID) {
|
func (n *Notifier) RemoveNode(nodeID types.NodeID) {
|
||||||
|
@ -50,20 +55,23 @@ func (n *Notifier) RemoveNode(nodeID types.NodeID) {
|
||||||
Uint64("node.id", nodeID.Uint64()).
|
Uint64("node.id", nodeID.Uint64()).
|
||||||
Msg("releasing lock to remove node")
|
Msg("releasing lock to remove node")
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
n.l.Lock()
|
n.l.Lock()
|
||||||
defer n.l.Unlock()
|
defer n.l.Unlock()
|
||||||
|
notifierWaitForLock.WithLabelValues("remove").Observe(time.Since(start).Seconds())
|
||||||
|
|
||||||
if len(n.nodes) == 0 {
|
if len(n.nodes) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(n.nodes, nodeID)
|
delete(n.nodes, nodeID)
|
||||||
n.connected[nodeID] = false
|
n.connected.Store(nodeID, false)
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Uint64("node.id", nodeID.Uint64()).
|
Uint64("node.id", nodeID.Uint64()).
|
||||||
Int("open_chans", len(n.nodes)).
|
Int("open_chans", len(n.nodes)).
|
||||||
Msg("Removed channel")
|
Msg("Removed channel")
|
||||||
|
notifierNodeUpdateChans.Dec()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsConnected reports if a node is connected to headscale and has a
|
// IsConnected reports if a node is connected to headscale and has a
|
||||||
|
@ -72,17 +80,22 @@ func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
|
||||||
n.l.RLock()
|
n.l.RLock()
|
||||||
defer n.l.RUnlock()
|
defer n.l.RUnlock()
|
||||||
|
|
||||||
return n.connected[nodeID]
|
if val, ok := n.connected.Load(nodeID); ok {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsLikelyConnected reports if a node is connected to headscale and has a
|
// IsLikelyConnected reports if a node is connected to headscale and has a
|
||||||
// poll session open, but doesnt lock, so might be wrong.
|
// poll session open, but doesnt lock, so might be wrong.
|
||||||
func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
|
func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
|
||||||
return n.connected[nodeID]
|
if val, ok := n.connected.Load(nodeID); ok {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): This returns a pointer and can be dangerous.
|
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
|
||||||
func (n *Notifier) ConnectedMap() types.NodeConnectedMap {
|
|
||||||
return n.connected
|
return n.connected
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,45 +108,16 @@ func (n *Notifier) NotifyWithIgnore(
|
||||||
update types.StateUpdate,
|
update types.StateUpdate,
|
||||||
ignoreNodeIDs ...types.NodeID,
|
ignoreNodeIDs ...types.NodeID,
|
||||||
) {
|
) {
|
||||||
log.Trace().Caller().Str("type", update.Type.String()).Msg("acquiring lock to notify")
|
for nodeID := range n.nodes {
|
||||||
defer log.Trace().
|
|
||||||
Caller().
|
|
||||||
Str("type", update.Type.String()).
|
|
||||||
Msg("releasing lock, finished notifying")
|
|
||||||
|
|
||||||
n.l.RLock()
|
|
||||||
defer n.l.RUnlock()
|
|
||||||
|
|
||||||
if update.Type == types.StatePeerChangedPatch {
|
|
||||||
log.Trace().Interface("update", update).Interface("online", n.connected).Msg("PATCH UPDATE SENT")
|
|
||||||
}
|
|
||||||
|
|
||||||
for nodeID, c := range n.nodes {
|
|
||||||
if slices.Contains(ignoreNodeIDs, nodeID) {
|
if slices.Contains(ignoreNodeIDs, nodeID) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
n.NotifyByNodeID(ctx, update, nodeID)
|
||||||
case <-ctx.Done():
|
|
||||||
log.Error().
|
|
||||||
Err(ctx.Err()).
|
|
||||||
Uint64("node.id", nodeID.Uint64()).
|
|
||||||
Any("origin", ctx.Value("origin")).
|
|
||||||
Any("origin-hostname", ctx.Value("hostname")).
|
|
||||||
Msgf("update not sent, context cancelled")
|
|
||||||
|
|
||||||
return
|
|
||||||
case c <- update:
|
|
||||||
log.Trace().
|
|
||||||
Uint64("node.id", nodeID.Uint64()).
|
|
||||||
Any("origin", ctx.Value("origin")).
|
|
||||||
Any("origin-hostname", ctx.Value("hostname")).
|
|
||||||
Msgf("update successfully sent on chan")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Notifier) NotifyByMachineKey(
|
func (n *Notifier) NotifyByNodeID(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
update types.StateUpdate,
|
update types.StateUpdate,
|
||||||
nodeID types.NodeID,
|
nodeID types.NodeID,
|
||||||
|
@ -144,8 +128,10 @@ func (n *Notifier) NotifyByMachineKey(
|
||||||
Str("type", update.Type.String()).
|
Str("type", update.Type.String()).
|
||||||
Msg("releasing lock, finished notifying")
|
Msg("releasing lock, finished notifying")
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
n.l.RLock()
|
n.l.RLock()
|
||||||
defer n.l.RUnlock()
|
defer n.l.RUnlock()
|
||||||
|
notifierWaitForLock.WithLabelValues("notify").Observe(time.Since(start).Seconds())
|
||||||
|
|
||||||
if c, ok := n.nodes[nodeID]; ok {
|
if c, ok := n.nodes[nodeID]; ok {
|
||||||
select {
|
select {
|
||||||
|
@ -156,6 +142,7 @@ func (n *Notifier) NotifyByMachineKey(
|
||||||
Any("origin", ctx.Value("origin")).
|
Any("origin", ctx.Value("origin")).
|
||||||
Any("origin-hostname", ctx.Value("hostname")).
|
Any("origin-hostname", ctx.Value("hostname")).
|
||||||
Msgf("update not sent, context cancelled")
|
Msgf("update not sent, context cancelled")
|
||||||
|
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String()).Inc()
|
||||||
|
|
||||||
return
|
return
|
||||||
case c <- update:
|
case c <- update:
|
||||||
|
@ -164,6 +151,7 @@ func (n *Notifier) NotifyByMachineKey(
|
||||||
Any("origin", ctx.Value("origin")).
|
Any("origin", ctx.Value("origin")).
|
||||||
Any("origin-hostname", ctx.Value("hostname")).
|
Any("origin-hostname", ctx.Value("hostname")).
|
||||||
Msgf("update successfully sent on chan")
|
Msgf("update successfully sent on chan")
|
||||||
|
notifierUpdateSent.WithLabelValues("ok", update.Type.String()).Inc()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -182,9 +170,10 @@ func (n *Notifier) String() string {
|
||||||
b.WriteString("\n")
|
b.WriteString("\n")
|
||||||
b.WriteString("connected:\n")
|
b.WriteString("connected:\n")
|
||||||
|
|
||||||
for k, v := range n.connected {
|
n.connected.Range(func(k types.NodeID, v bool) bool {
|
||||||
fmt.Fprintf(&b, "\t%d: %t\n", k, v)
|
fmt.Fprintf(&b, "\t%d: %t\n", k, v)
|
||||||
}
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
return b.String()
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
|
@ -602,7 +602,7 @@ func (h *Headscale) registerNodeForOIDCCallback(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.DB.Transaction(func(tx *gorm.DB) error {
|
if err := h.db.Write(func(tx *gorm.DB) error {
|
||||||
if _, err := db.RegisterNodeFromAuthCallback(
|
if _, err := db.RegisterNodeFromAuthCallback(
|
||||||
// TODO(kradalby): find a better way to use the cache across modules
|
// TODO(kradalby): find a better way to use the cache across modules
|
||||||
tx,
|
tx,
|
||||||
|
|
|
@ -64,7 +64,7 @@ func (h *Headscale) newMapSession(
|
||||||
w http.ResponseWriter,
|
w http.ResponseWriter,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
) *mapSession {
|
) *mapSession {
|
||||||
warnf, tracef, infof, errf := logPollFunc(req, node)
|
warnf, infof, tracef, errf := logPollFunc(req, node)
|
||||||
|
|
||||||
// Use a buffered channel in case a node is not fully ready
|
// Use a buffered channel in case a node is not fully ready
|
||||||
// to receive a message to make sure we dont block the entire
|
// to receive a message to make sure we dont block the entire
|
||||||
|
@ -196,8 +196,10 @@ func (m *mapSession) serve() {
|
||||||
// return
|
// return
|
||||||
err := m.handleSaveNode()
|
err := m.handleSaveNode()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
mapResponseWriteUpdatesInStream.WithLabelValues("error").Inc()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
mapResponseWriteUpdatesInStream.WithLabelValues("ok").Inc()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up the client stream
|
// Set up the client stream
|
||||||
|
@ -284,6 +286,7 @@ func (m *mapSession) serve() {
|
||||||
patches = filteredPatches
|
patches = filteredPatches
|
||||||
}
|
}
|
||||||
|
|
||||||
|
updateType := "full"
|
||||||
// When deciding what update to send, the following is considered,
|
// When deciding what update to send, the following is considered,
|
||||||
// Full is a superset of all updates, when a full update is requested,
|
// Full is a superset of all updates, when a full update is requested,
|
||||||
// send only that and move on, all other updates will be present in
|
// send only that and move on, all other updates will be present in
|
||||||
|
@ -303,12 +306,15 @@ func (m *mapSession) serve() {
|
||||||
} else if changed != nil {
|
} else if changed != nil {
|
||||||
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
|
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
|
||||||
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, patches, m.h.ACLPolicy, lastMessage)
|
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, patches, m.h.ACLPolicy, lastMessage)
|
||||||
|
updateType = "change"
|
||||||
} else if patches != nil {
|
} else if patches != nil {
|
||||||
m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
|
m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
|
||||||
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, patches, m.h.ACLPolicy)
|
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, patches, m.h.ACLPolicy)
|
||||||
|
updateType = "patch"
|
||||||
} else if derp {
|
} else if derp {
|
||||||
m.tracef("Sending DERPUpdate MapResponse")
|
m.tracef("Sending DERPUpdate MapResponse")
|
||||||
data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.DERPMap)
|
data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.DERPMap)
|
||||||
|
updateType = "derp"
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -324,19 +330,22 @@ func (m *mapSession) serve() {
|
||||||
startWrite := time.Now()
|
startWrite := time.Now()
|
||||||
_, err = m.w.Write(data)
|
_, err = m.w.Write(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
mapResponseSent.WithLabelValues("error", updateType).Inc()
|
||||||
m.errf(err, "Could not write the map response, for mapSession: %p", m)
|
m.errf(err, "Could not write the map response, for mapSession: %p", m)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = rc.Flush()
|
err = rc.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
mapResponseSent.WithLabelValues("error", updateType).Inc()
|
||||||
m.errf(err, "flushing the map response to client, for mapSession: %p", m)
|
m.errf(err, "flushing the map response to client, for mapSession: %p", m)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
|
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
|
||||||
|
|
||||||
m.infof("update sent")
|
mapResponseSent.WithLabelValues("ok", updateType).Inc()
|
||||||
|
m.tracef("update sent")
|
||||||
}
|
}
|
||||||
|
|
||||||
// reset
|
// reset
|
||||||
|
@ -364,7 +373,8 @@ func (m *mapSession) serve() {
|
||||||
|
|
||||||
// Consume all updates sent to node
|
// Consume all updates sent to node
|
||||||
case update := <-m.ch:
|
case update := <-m.ch:
|
||||||
m.tracef("received stream update: %d %s", update.Type, update.Message)
|
m.tracef("received stream update: %s %s", update.Type.String(), update.Message)
|
||||||
|
mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc()
|
||||||
|
|
||||||
switch update.Type {
|
switch update.Type {
|
||||||
case types.StateFullUpdate:
|
case types.StateFullUpdate:
|
||||||
|
@ -404,27 +414,30 @@ func (m *mapSession) serve() {
|
||||||
data, err := m.mapper.KeepAliveResponse(m.req, m.node)
|
data, err := m.mapper.KeepAliveResponse(m.req, m.node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.errf(err, "Error generating the keep alive msg")
|
m.errf(err, "Error generating the keep alive msg")
|
||||||
|
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err = m.w.Write(data)
|
_, err = m.w.Write(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.errf(err, "Cannot write keep alive message")
|
m.errf(err, "Cannot write keep alive message")
|
||||||
|
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = rc.Flush()
|
err = rc.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.errf(err, "flushing keep alive to client, for mapSession: %p", m)
|
m.errf(err, "flushing keep alive to client, for mapSession: %p", m)
|
||||||
|
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mapSession) pollFailoverRoutes(where string, node *types.Node) {
|
func (m *mapSession) pollFailoverRoutes(where string, node *types.Node) {
|
||||||
update, err := db.Write(m.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
update, err := db.Write(m.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
|
||||||
return db.FailoverNodeRoutesIfNeccessary(tx, m.h.nodeNotifier.ConnectedMap(), node)
|
return db.FailoverNodeRoutesIfNeccessary(tx, m.h.nodeNotifier.LikelyConnectedMap(), node)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.errf(err, fmt.Sprintf("failed to ensure failover routes, %s", where))
|
m.errf(err, fmt.Sprintf("failed to ensure failover routes, %s", where))
|
||||||
|
@ -454,7 +467,7 @@ func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
|
||||||
node.LastSeen = &now
|
node.LastSeen = &now
|
||||||
change.LastSeen = &now
|
change.LastSeen = &now
|
||||||
|
|
||||||
err := h.db.DB.Transaction(func(tx *gorm.DB) error {
|
err := h.db.Write(func(tx *gorm.DB) error {
|
||||||
return db.SetLastSeen(tx, node.ID, *node.LastSeen)
|
return db.SetLastSeen(tx, node.ID, *node.LastSeen)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -501,6 +514,7 @@ func (m *mapSession) handleEndpointUpdate() {
|
||||||
// If there is no changes and nothing to save,
|
// If there is no changes and nothing to save,
|
||||||
// return early.
|
// return early.
|
||||||
if peerChangeEmpty(change) && !sendUpdate {
|
if peerChangeEmpty(change) && !sendUpdate {
|
||||||
|
mapResponseEndpointUpdates.WithLabelValues("noop").Inc()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -518,6 +532,7 @@ func (m *mapSession) handleEndpointUpdate() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.errf(err, "Error processing node routes")
|
m.errf(err, "Error processing node routes")
|
||||||
http.Error(m.w, "", http.StatusInternalServerError)
|
http.Error(m.w, "", http.StatusInternalServerError)
|
||||||
|
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -527,6 +542,7 @@ func (m *mapSession) handleEndpointUpdate() {
|
||||||
err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node)
|
err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.errf(err, "Error running auto approved routes")
|
m.errf(err, "Error running auto approved routes")
|
||||||
|
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -534,19 +550,19 @@ func (m *mapSession) handleEndpointUpdate() {
|
||||||
// has an updated packetfilter allowing the new route
|
// has an updated packetfilter allowing the new route
|
||||||
// if it is defined in the ACL.
|
// if it is defined in the ACL.
|
||||||
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname)
|
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname)
|
||||||
m.h.nodeNotifier.NotifyByMachineKey(
|
m.h.nodeNotifier.NotifyByNodeID(
|
||||||
ctx,
|
ctx,
|
||||||
types.StateUpdate{
|
types.StateUpdate{
|
||||||
Type: types.StateSelfUpdate,
|
Type: types.StateSelfUpdate,
|
||||||
ChangeNodes: []types.NodeID{m.node.ID},
|
ChangeNodes: []types.NodeID{m.node.ID},
|
||||||
},
|
},
|
||||||
m.node.ID)
|
m.node.ID)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.h.db.DB.Save(m.node).Error; err != nil {
|
if err := m.h.db.DB.Save(m.node).Error; err != nil {
|
||||||
m.errf(err, "Failed to persist/update node in the database")
|
m.errf(err, "Failed to persist/update node in the database")
|
||||||
http.Error(m.w, "", http.StatusInternalServerError)
|
http.Error(m.w, "", http.StatusInternalServerError)
|
||||||
|
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -562,6 +578,7 @@ func (m *mapSession) handleEndpointUpdate() {
|
||||||
m.node.ID)
|
m.node.ID)
|
||||||
|
|
||||||
m.w.WriteHeader(http.StatusOK)
|
m.w.WriteHeader(http.StatusOK)
|
||||||
|
mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -639,7 +656,7 @@ func (m *mapSession) handleReadOnlyRequest() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.errf(err, "Failed to create MapResponse")
|
m.errf(err, "Failed to create MapResponse")
|
||||||
http.Error(m.w, "", http.StatusInternalServerError)
|
http.Error(m.w, "", http.StatusInternalServerError)
|
||||||
|
mapResponseReadOnly.WithLabelValues("error").Inc()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -648,9 +665,12 @@ func (m *mapSession) handleReadOnlyRequest() {
|
||||||
_, err = m.w.Write(mapResp)
|
_, err = m.w.Write(mapResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.errf(err, "Failed to write response")
|
m.errf(err, "Failed to write response")
|
||||||
|
mapResponseReadOnly.WithLabelValues("error").Inc()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
m.w.WriteHeader(http.StatusOK)
|
m.w.WriteHeader(http.StatusOK)
|
||||||
|
mapResponseReadOnly.WithLabelValues("ok").Inc()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,7 +28,8 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
type NodeID uint64
|
type NodeID uint64
|
||||||
type NodeConnectedMap map[NodeID]bool
|
|
||||||
|
// type NodeConnectedMap *xsync.MapOf[NodeID, bool]
|
||||||
|
|
||||||
func (id NodeID) StableID() tailcfg.StableNodeID {
|
func (id NodeID) StableID() tailcfg.StableNodeID {
|
||||||
return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10))
|
return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10))
|
||||||
|
|
|
@ -51,7 +51,7 @@ func aclScenario(
|
||||||
clientsPerUser int,
|
clientsPerUser int,
|
||||||
) *Scenario {
|
) *Scenario {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
|
@ -264,7 +264,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||||
|
|
||||||
for name, testCase := range tests {
|
for name, testCase := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
spec := testCase.users
|
spec := testCase.users
|
||||||
|
|
|
@ -42,7 +42,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
baseScenario, err := NewScenario()
|
baseScenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
scenario := AuthOIDCScenario{
|
scenario := AuthOIDCScenario{
|
||||||
|
@ -100,7 +100,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
||||||
|
|
||||||
shortAccessTTL := 5 * time.Minute
|
shortAccessTTL := 5 * time.Minute
|
||||||
|
|
||||||
baseScenario, err := NewScenario()
|
baseScenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
baseScenario.pool.MaxWait = 5 * time.Minute
|
baseScenario.pool.MaxWait = 5 * time.Minute
|
||||||
|
|
|
@ -26,7 +26,7 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
baseScenario, err := NewScenario()
|
baseScenario, err := NewScenario(dockertestMaxWait())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create scenario: %s", err)
|
t.Fatalf("failed to create scenario: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -67,7 +67,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
baseScenario, err := NewScenario()
|
baseScenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
scenario := AuthWebFlowScenario{
|
scenario := AuthWebFlowScenario{
|
||||||
|
|
|
@ -32,7 +32,7 @@ func TestUserCommand(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -112,7 +112,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||||
user := "preauthkeyspace"
|
user := "preauthkeyspace"
|
||||||
count := 3
|
count := 3
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -254,7 +254,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
|
||||||
|
|
||||||
user := "pre-auth-key-without-exp-user"
|
user := "pre-auth-key-without-exp-user"
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -317,7 +317,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
|
||||||
|
|
||||||
user := "pre-auth-key-reus-ephm-user"
|
user := "pre-auth-key-reus-ephm-user"
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -394,7 +394,7 @@ func TestApiKeyCommand(t *testing.T) {
|
||||||
|
|
||||||
count := 5
|
count := 5
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -562,7 +562,7 @@ func TestNodeTagCommand(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -695,7 +695,7 @@ func TestNodeAdvertiseTagNoACLCommand(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -745,7 +745,7 @@ func TestNodeAdvertiseTagWithACLCommand(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -808,7 +808,7 @@ func TestNodeCommand(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -1049,7 +1049,7 @@ func TestNodeExpireCommand(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -1176,7 +1176,7 @@ func TestNodeRenameCommand(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -1343,7 +1343,7 @@ func TestNodeMoveCommand(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,7 @@ func TestDERPServerScenario(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
// t.Parallel()
|
// t.Parallel()
|
||||||
|
|
||||||
baseScenario, err := NewScenario()
|
baseScenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
scenario := EmbeddedDERPServerScenario{
|
scenario := EmbeddedDERPServerScenario{
|
||||||
|
|
|
@ -23,7 +23,7 @@ func TestPingAllByIP(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ func TestPingAllByIPPublicDERP(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -105,7 +105,7 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -216,7 +216,7 @@ func TestEphemeral(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -299,7 +299,7 @@ func TestPingAllByHostname(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -348,7 +348,7 @@ func TestTaildrop(t *testing.T) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -509,7 +509,7 @@ func TestResolveMagicDNS(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -577,7 +577,7 @@ func TestExpireNode(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -703,7 +703,7 @@ func TestNodeOnlineStatus(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -818,7 +818,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
|
||||||
IntegrationSkip(t)
|
IntegrationSkip(t)
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -201,6 +202,14 @@ func WithEmbeddedDERPServerOnly() Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithTuning allows changing the tuning settings easily.
|
||||||
|
func WithTuning(batchTimeout time.Duration, mapSessionChanSize int) Option {
|
||||||
|
return func(hsic *HeadscaleInContainer) {
|
||||||
|
hsic.env["HEADSCALE_TUNING_BATCH_CHANGE_DELAY"] = batchTimeout.String()
|
||||||
|
hsic.env["HEADSCALE_TUNING_NODE_MAPSESSION_BUFFERED_CHAN_SIZE"] = strconv.Itoa(mapSessionChanSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// New returns a new HeadscaleInContainer instance.
|
// New returns a new HeadscaleInContainer instance.
|
||||||
func New(
|
func New(
|
||||||
pool *dockertest.Pool,
|
pool *dockertest.Pool,
|
||||||
|
|
|
@ -28,7 +28,7 @@ func TestEnablingRoutes(t *testing.T) {
|
||||||
|
|
||||||
user := "enable-routing"
|
user := "enable-routing"
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErrf(t, "failed to create scenario: %s", err)
|
assertNoErrf(t, "failed to create scenario: %s", err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -250,7 +250,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||||
|
|
||||||
user := "enable-routing"
|
user := "enable-routing"
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErrf(t, "failed to create scenario: %s", err)
|
assertNoErrf(t, "failed to create scenario: %s", err)
|
||||||
// defer scenario.Shutdown()
|
// defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -822,7 +822,7 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) {
|
||||||
|
|
||||||
user := "enable-disable-routing"
|
user := "enable-disable-routing"
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErrf(t, "failed to create scenario: %s", err)
|
assertNoErrf(t, "failed to create scenario: %s", err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -966,7 +966,7 @@ func TestSubnetRouteACL(t *testing.T) {
|
||||||
|
|
||||||
user := "subnet-route-acl"
|
user := "subnet-route-acl"
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErrf(t, "failed to create scenario: %s", err)
|
assertNoErrf(t, "failed to create scenario: %s", err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
@ -141,7 +142,7 @@ type Scenario struct {
|
||||||
|
|
||||||
// NewScenario creates a test Scenario which can be used to bootstraps a ControlServer with
|
// NewScenario creates a test Scenario which can be used to bootstraps a ControlServer with
|
||||||
// a set of Users and TailscaleClients.
|
// a set of Users and TailscaleClients.
|
||||||
func NewScenario() (*Scenario, error) {
|
func NewScenario(maxWait time.Duration) (*Scenario, error) {
|
||||||
hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength)
|
hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -152,7 +153,7 @@ func NewScenario() (*Scenario, error) {
|
||||||
return nil, fmt.Errorf("could not connect to docker: %w", err)
|
return nil, fmt.Errorf("could not connect to docker: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pool.MaxWait = dockertestMaxWait()
|
pool.MaxWait = maxWait
|
||||||
|
|
||||||
networkName := fmt.Sprintf("hs-%s", hash)
|
networkName := fmt.Sprintf("hs-%s", hash)
|
||||||
if overrideNetworkName := os.Getenv("HEADSCALE_TEST_NETWORK_NAME"); overrideNetworkName != "" {
|
if overrideNetworkName := os.Getenv("HEADSCALE_TEST_NETWORK_NAME"); overrideNetworkName != "" {
|
||||||
|
|
|
@ -33,7 +33,7 @@ func TestHeadscale(t *testing.T) {
|
||||||
|
|
||||||
user := "test-space"
|
user := "test-space"
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ func TestCreateTailscale(t *testing.T) {
|
||||||
|
|
||||||
user := "only-create-containers"
|
user := "only-create-containers"
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
@ -114,7 +114,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
|
||||||
|
|
||||||
count := 1
|
count := 1
|
||||||
|
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
defer scenario.Shutdown()
|
defer scenario.Shutdown()
|
||||||
|
|
||||||
|
|
|
@ -44,7 +44,7 @@ var retry = func(times int, sleepInterval time.Duration,
|
||||||
|
|
||||||
func sshScenario(t *testing.T, policy *policy.ACLPolicy, clientsPerUser int) *Scenario {
|
func sshScenario(t *testing.T, policy *policy.ACLPolicy, clientsPerUser int) *Scenario {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
scenario, err := NewScenario()
|
scenario, err := NewScenario(dockertestMaxWait())
|
||||||
assertNoErr(t, err)
|
assertNoErr(t, err)
|
||||||
|
|
||||||
spec := map[string]int{
|
spec := map[string]int{
|
||||||
|
|
Loading…
Reference in a new issue