Remove redundant caches
This commit removes the two extra caches (oidc, requested time) and uses the new central registration cache instead. The requested time is unified into the main machine object and the oidc key is just added to the same cache, as a string with the state as a key instead of machine key.
This commit is contained in:
parent
e64bee778f
commit
5e92ddad43
6 changed files with 27 additions and 84 deletions
25
api.go
25
api.go
|
@ -140,17 +140,25 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
|
||||||
// We create the machine and then keep it around until a callback
|
// We create the machine and then keep it around until a callback
|
||||||
// happens
|
// happens
|
||||||
newMachine := Machine{
|
newMachine := Machine{
|
||||||
Expiry: &time.Time{},
|
|
||||||
MachineKey: machineKeyStr,
|
MachineKey: machineKeyStr,
|
||||||
Name: req.Hostinfo.Hostname,
|
Name: req.Hostinfo.Hostname,
|
||||||
NodeKey: NodePublicKeyStripPrefix(req.NodeKey),
|
NodeKey: NodePublicKeyStripPrefix(req.NodeKey),
|
||||||
LastSeen: &now,
|
LastSeen: &now,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !req.Expiry.IsZero() {
|
||||||
|
log.Trace().
|
||||||
|
Caller().
|
||||||
|
Str("machine", req.Hostinfo.Hostname).
|
||||||
|
Time("expiry", req.Expiry).
|
||||||
|
Msg("Non-zero expiry time requested")
|
||||||
|
newMachine.Expiry = &req.Expiry
|
||||||
|
}
|
||||||
|
|
||||||
h.registrationCache.Set(
|
h.registrationCache.Set(
|
||||||
machineKeyStr,
|
machineKeyStr,
|
||||||
newMachine,
|
newMachine,
|
||||||
requestedExpiryCacheExpiration,
|
registerCacheExpiration,
|
||||||
)
|
)
|
||||||
|
|
||||||
h.handleMachineRegistrationNew(ctx, machineKey, req)
|
h.handleMachineRegistrationNew(ctx, machineKey, req)
|
||||||
|
@ -490,19 +498,6 @@ func (h *Headscale) handleMachineRegistrationNew(
|
||||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey))
|
strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !registerRequest.Expiry.IsZero() {
|
|
||||||
log.Trace().
|
|
||||||
Caller().
|
|
||||||
Str("machine", registerRequest.Hostinfo.Hostname).
|
|
||||||
Time("expiry", registerRequest.Expiry).
|
|
||||||
Msg("Non-zero expiry time requested, adding to cache")
|
|
||||||
h.requestedExpiryCache.Set(
|
|
||||||
machineKey.String(),
|
|
||||||
registerRequest.Expiry,
|
|
||||||
requestedExpiryCacheExpiration,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
respBody, err := encode(resp, &machineKey, h.privateKey)
|
respBody, err := encode(resp, &machineKey, h.privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
|
17
app.go
17
app.go
|
@ -55,8 +55,8 @@ const (
|
||||||
HTTPReadTimeout = 30 * time.Second
|
HTTPReadTimeout = 30 * time.Second
|
||||||
privateKeyFileMode = 0o600
|
privateKeyFileMode = 0o600
|
||||||
|
|
||||||
requestedExpiryCacheExpiration = time.Minute * 5
|
registerCacheExpiration = time.Minute * 15
|
||||||
requestedExpiryCacheCleanupInterval = time.Minute * 10
|
registerCacheCleanup = time.Minute * 20
|
||||||
|
|
||||||
errUnsupportedDatabase = Error("unsupported DB")
|
errUnsupportedDatabase = Error("unsupported DB")
|
||||||
errUnsupportedLetsEncryptChallengeType = Error(
|
errUnsupportedLetsEncryptChallengeType = Error(
|
||||||
|
@ -150,9 +150,6 @@ type Headscale struct {
|
||||||
|
|
||||||
oidcProvider *oidc.Provider
|
oidcProvider *oidc.Provider
|
||||||
oauth2Config *oauth2.Config
|
oauth2Config *oauth2.Config
|
||||||
oidcStateCache *cache.Cache
|
|
||||||
|
|
||||||
requestedExpiryCache *cache.Cache
|
|
||||||
|
|
||||||
registrationCache *cache.Cache
|
registrationCache *cache.Cache
|
||||||
|
|
||||||
|
@ -204,15 +201,10 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
||||||
return nil, errUnsupportedDatabase
|
return nil, errUnsupportedDatabase
|
||||||
}
|
}
|
||||||
|
|
||||||
requestedExpiryCache := cache.New(
|
|
||||||
requestedExpiryCacheExpiration,
|
|
||||||
requestedExpiryCacheCleanupInterval,
|
|
||||||
)
|
|
||||||
|
|
||||||
registrationCache := cache.New(
|
registrationCache := cache.New(
|
||||||
// TODO(kradalby): Add unified cache expiry config options
|
// TODO(kradalby): Add unified cache expiry config options
|
||||||
requestedExpiryCacheExpiration,
|
registerCacheExpiration,
|
||||||
requestedExpiryCacheCleanupInterval,
|
registerCacheCleanup,
|
||||||
)
|
)
|
||||||
|
|
||||||
app := Headscale{
|
app := Headscale{
|
||||||
|
@ -221,7 +213,6 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
||||||
dbString: dbString,
|
dbString: dbString,
|
||||||
privateKey: privKey,
|
privateKey: privKey,
|
||||||
aclRules: tailcfg.FilterAllowAll, // default allowall
|
aclRules: tailcfg.FilterAllowAll, // default allowall
|
||||||
requestedExpiryCache: requestedExpiryCache,
|
|
||||||
registrationCache: registrationCache,
|
registrationCache: registrationCache,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/patrickmn/go-cache"
|
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
"inet.af/netaddr"
|
"inet.af/netaddr"
|
||||||
)
|
)
|
||||||
|
@ -50,10 +49,6 @@ func (s *Suite) ResetDB(c *check.C) {
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
dbType: "sqlite3",
|
dbType: "sqlite3",
|
||||||
dbString: tmpDir + "/headscale_test.db",
|
dbString: tmpDir + "/headscale_test.db",
|
||||||
requestedExpiryCache: cache.New(
|
|
||||||
requestedExpiryCacheExpiration,
|
|
||||||
requestedExpiryCacheCleanupInterval,
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
err = app.initDB()
|
err = app.initDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
17
grpcv1.go
17
grpcv1.go
|
@ -160,25 +160,10 @@ func (api headscaleV1APIServer) RegisterMachine(
|
||||||
Str("machine_key", request.GetKey()).
|
Str("machine_key", request.GetKey()).
|
||||||
Msg("Registering machine")
|
Msg("Registering machine")
|
||||||
|
|
||||||
// TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set
|
|
||||||
// This means that if a user is to slow with register a machine, it will possibly not
|
|
||||||
// have the correct expiry.
|
|
||||||
requestedTime := time.Time{}
|
|
||||||
if requestedTimeIf, found := api.h.requestedExpiryCache.Get(request.GetKey()); found {
|
|
||||||
log.Trace().
|
|
||||||
Caller().
|
|
||||||
Str("machine", request.Key).
|
|
||||||
Msg("Expiry time found in cache, assigning to node")
|
|
||||||
if reqTime, ok := requestedTimeIf.(time.Time); ok {
|
|
||||||
requestedTime = reqTime
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
machine, err := api.h.RegisterMachineFromAuthCallback(
|
machine, err := api.h.RegisterMachineFromAuthCallback(
|
||||||
request.GetKey(),
|
request.GetKey(),
|
||||||
request.GetNamespace(),
|
request.GetNamespace(),
|
||||||
RegisterMethodCLI,
|
RegisterMethodCLI,
|
||||||
&requestedTime,
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -418,7 +403,7 @@ func (api headscaleV1APIServer) DebugCreateMachine(
|
||||||
api.h.registrationCache.Set(
|
api.h.registrationCache.Set(
|
||||||
request.GetKey(),
|
request.GetKey(),
|
||||||
newMachine,
|
newMachine,
|
||||||
requestedExpiryCacheExpiration,
|
registerCacheExpiration,
|
||||||
)
|
)
|
||||||
|
|
||||||
return &v1.DebugCreateMachineResponse{Machine: newMachine.toProto()}, nil
|
return &v1.DebugCreateMachineResponse{Machine: newMachine.toProto()}, nil
|
||||||
|
|
|
@ -683,7 +683,6 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
|
||||||
machineKeyStr string,
|
machineKeyStr string,
|
||||||
namespaceName string,
|
namespaceName string,
|
||||||
registrationMethod string,
|
registrationMethod string,
|
||||||
expiry *time.Time,
|
|
||||||
) (*Machine, error) {
|
) (*Machine, error) {
|
||||||
if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok {
|
if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok {
|
||||||
if registrationMachine, ok := machineInterface.(Machine); ok {
|
if registrationMachine, ok := machineInterface.(Machine); ok {
|
||||||
|
@ -697,7 +696,6 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
|
||||||
|
|
||||||
registrationMachine.NamespaceID = namespace.ID
|
registrationMachine.NamespaceID = namespace.ID
|
||||||
registrationMachine.RegisterMethod = registrationMethod
|
registrationMachine.RegisterMethod = registrationMethod
|
||||||
registrationMachine.Expiry = expiry
|
|
||||||
|
|
||||||
machine, err := h.RegisterMachine(
|
machine, err := h.RegisterMachine(
|
||||||
registrationMachine,
|
registrationMachine,
|
||||||
|
|
27
oidc.go
27
oidc.go
|
@ -10,19 +10,15 @@ import (
|
||||||
"html/template"
|
"html/template"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/patrickmn/go-cache"
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
oidcStateCacheExpiration = time.Minute * 5
|
|
||||||
oidcStateCacheCleanupInterval = time.Minute * 10
|
|
||||||
randomByteSize = 16
|
randomByteSize = 16
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -60,14 +56,6 @@ func (h *Headscale) initOIDC() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// init the state cache if it hasn't been already
|
|
||||||
if h.oidcStateCache == nil {
|
|
||||||
h.oidcStateCache = cache.New(
|
|
||||||
oidcStateCacheExpiration,
|
|
||||||
oidcStateCacheCleanupInterval,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,7 +88,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
|
||||||
stateStr := hex.EncodeToString(randomBlob)[:32]
|
stateStr := hex.EncodeToString(randomBlob)[:32]
|
||||||
|
|
||||||
// place the machine key into the state cache, so it can be retrieved later
|
// place the machine key into the state cache, so it can be retrieved later
|
||||||
h.oidcStateCache.Set(stateStr, machineKeyStr, oidcStateCacheExpiration)
|
h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration)
|
||||||
|
|
||||||
authURL := h.oauth2Config.AuthCodeURL(stateStr)
|
authURL := h.oauth2Config.AuthCodeURL(stateStr)
|
||||||
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
|
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
|
||||||
|
@ -196,7 +184,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// retrieve machinekey from state cache
|
// retrieve machinekey from state cache
|
||||||
machineKeyIf, machineKeyFound := h.oidcStateCache.Get(state)
|
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
|
||||||
|
|
||||||
if !machineKeyFound {
|
if !machineKeyFound {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
@ -228,14 +216,6 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set
|
|
||||||
requestedTime := time.Time{}
|
|
||||||
if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.String()); found {
|
|
||||||
if reqTime, ok := requestedTimeIf.(time.Time); ok {
|
|
||||||
requestedTime = reqTime
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// retrieve machine information
|
// retrieve machine information
|
||||||
machine, err := h.GetMachineByMachineKey(machineKey)
|
machine, err := h.GetMachineByMachineKey(machineKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -254,7 +234,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||||
Str("machine", machine.Name).
|
Str("machine", machine.Name).
|
||||||
Msg("machine already registered, reauthenticating")
|
Msg("machine already registered, reauthenticating")
|
||||||
|
|
||||||
h.RefreshMachine(machine, requestedTime)
|
h.RefreshMachine(machine, *machine.Expiry)
|
||||||
|
|
||||||
var content bytes.Buffer
|
var content bytes.Buffer
|
||||||
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
|
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
|
||||||
|
@ -329,7 +309,6 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||||
machineKeyStr,
|
machineKeyStr,
|
||||||
namespace.Name,
|
namespace.Name,
|
||||||
RegisterMethodOIDC,
|
RegisterMethodOIDC,
|
||||||
&requestedTime,
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
|
|
Loading…
Reference in a new issue