diff --git a/api.go b/api.go index f875259..914a938 100644 --- a/api.go +++ b/api.go @@ -33,6 +33,8 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) { return } + // spew.Dump(c.Params) + c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` @@ -71,6 +73,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { c.String(http.StatusInternalServerError, "Very sad!") return } + db, err := h.db() if err != nil { log.Printf("Cannot open DB: %s", err) @@ -93,6 +96,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { log.Println("Client is registered and we have the current key. All clear to /map") resp.AuthURL = "" resp.User = *m.Namespace.toUser() + resp.MachineAuthorized = true respBody, err := encode(resp, &mKey, h.privateKey) if err != nil { log.Printf("Cannot encode message: %s", err) @@ -135,6 +139,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { } log.Println("We dont know anything about the new key. WTF") + // spew.Dump(req) } // PollNetMapHandler takes care of /machine/:id/map @@ -359,21 +364,60 @@ func (h *Headscale) getMapKeepAliveResponse(mKey wgcfg.Key, req tailcfg.MapReque } func (h *Headscale) handleNewServer(c *gin.Context, db *gorm.DB, idKey wgcfg.Key, req tailcfg.RegisterRequest) { - mNew := Machine{ + m := Machine{ MachineKey: idKey.HexString(), NodeKey: wgcfg.Key(req.NodeKey).HexString(), Expiry: &req.Expiry, Name: req.Hostinfo.Hostname, } - if err := db.Create(&mNew).Error; err != nil { + if err := db.Create(&m).Error; err != nil { log.Printf("Could not create row: %s", err) return } - resp := tailcfg.RegisterResponse{ - AuthURL: fmt.Sprintf("%s/register?key=%s", - h.cfg.ServerURL, idKey.HexString()), + + resp := tailcfg.RegisterResponse{} + + if req.Auth.AuthKey != "" { + pak, err := h.checkKeyValidity(req.Auth.AuthKey) + if err != nil { + resp.MachineAuthorized = false + respBody, err := encode(resp, &idKey, h.privateKey) + if err != nil { + log.Printf("Cannot encode message: %s", err) + c.String(http.StatusInternalServerError, "") + return + } + c.Data(200, "application/json; charset=utf-8", respBody) + return + } + ip, err := h.getAvailableIP() + if err != nil { + log.Println(err) + return + } + + m.IPAddress = ip.String() + m.NamespaceID = pak.NamespaceID + m.AuthKeyID = uint(pak.ID) + m.RegisterMethod = "authKey" + m.Registered = true + db.Save(&m) + + resp.MachineAuthorized = true + resp.User = *pak.Namespace.toUser() + respBody, err := encode(resp, &idKey, h.privateKey) + if err != nil { + log.Printf("Cannot encode message: %s", err) + c.String(http.StatusInternalServerError, "Extremely sad!") + return + } + c.Data(200, "application/json; charset=utf-8", respBody) + return } + resp.AuthURL = fmt.Sprintf("%s/register?key=%s", + h.cfg.ServerURL, idKey.HexString()) + respBody, err := encode(resp, &idKey, h.privateKey) if err != nil { log.Printf("Cannot encode message: %s", err) diff --git a/cli.go b/cli.go index 0cb5333..f75bb47 100644 --- a/cli.go +++ b/cli.go @@ -43,6 +43,7 @@ func (h *Headscale) RegisterMachine(key string, namespace string) error { m.IPAddress = ip.String() m.NamespaceID = ns.ID m.Registered = true + m.RegisterMethod = "cli" db.Save(&m) fmt.Println("Machine registered 🎉") return nil diff --git a/machine.go b/machine.go index d72e660..b67c2a6 100644 --- a/machine.go +++ b/machine.go @@ -25,9 +25,13 @@ type Machine struct { NamespaceID uint Namespace Namespace - Registered bool // temp - LastSeen *time.Time - Expiry *time.Time + Registered bool // temp + RegisterMethod string + AuthKeyID uint + AuthKey *PreAuthKey + + LastSeen *time.Time + Expiry *time.Time HostInfo postgres.Jsonb Endpoints postgres.Jsonb diff --git a/preauth_keys.go b/preauth_keys.go index de89b04..7488a2e 100644 --- a/preauth_keys.go +++ b/preauth_keys.go @@ -7,6 +7,10 @@ import ( "time" ) +const errorAuthKeyNotFound = Error("AuthKey not found") +const errorAuthKeyExpired = Error("AuthKey expired") +const errorAuthKeyNotReusableAlreadyUsed = Error("AuthKey not reusable already used") + // PreAuthKey describes a pre-authorization key usable in a particular namespace type PreAuthKey struct { ID uint64 `gorm:"primary_key"` @@ -72,6 +76,41 @@ func (h *Headscale) GetPreAuthKeys(namespaceName string) (*[]PreAuthKey, error) return &keys, nil } +// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node +// If returns no error and a PreAuthKey, it can be used +func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { + db, err := h.db() + if err != nil { + return nil, err + } + defer db.Close() + + pak := PreAuthKey{} + if db.Preload("Namespace").First(&pak, "key = ?", k).RecordNotFound() { + return nil, errorAuthKeyNotFound + } + + if pak.Expiration != nil && pak.Expiration.Before(time.Now()) { + return nil, errorAuthKeyExpired + } + + if pak.Reusable { // we don't need to check if has been used before + return &pak, nil + } + + machines := []Machine{} + if err := db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { + return nil, err + } + + if len(machines) != 0 { + return nil, errorAuthKeyNotReusableAlreadyUsed + } + + // missing here validation on current usage + return &pak, nil +} + func (h *Headscale) generateKey() (string, error) { size := 24 bytes := make([]byte, size) diff --git a/preauth_keys_test.go b/preauth_keys_test.go index 5ac3bcf..cf13bb9 100644 --- a/preauth_keys_test.go +++ b/preauth_keys_test.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "os" "testing" + "time" _ "github.com/jinzhu/gorm/dialects/sqlite" // sql driver @@ -48,6 +49,7 @@ func (s *Suite) TearDownSuite(c *check.C) { func (*Suite) TestCreatePreAuthKey(c *check.C) { _, err := h.CreatePreAuthKey("bogus", true, nil) + c.Assert(err, check.NotNil) n, err := h.CreateNamespace("test") @@ -73,3 +75,106 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { // Make sure the Namespace association is populated c.Assert((*keys)[0].Namespace.Name, check.Equals, n.Name) } + +func (*Suite) TestExpiredPreAuthKey(c *check.C) { + n, err := h.CreateNamespace("test2") + c.Assert(err, check.IsNil) + + now := time.Now() + pak, err := h.CreatePreAuthKey(n.Name, true, &now) + c.Assert(err, check.IsNil) + + p, err := h.checkKeyValidity(pak.Key) + c.Assert(err, check.Equals, errorAuthKeyExpired) + c.Assert(p, check.IsNil) +} + +func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) { + p, err := h.checkKeyValidity("potatoKey") + c.Assert(err, check.Equals, errorAuthKeyNotFound) + c.Assert(p, check.IsNil) +} + +func (*Suite) TestValidateKeyOk(c *check.C) { + n, err := h.CreateNamespace("test3") + c.Assert(err, check.IsNil) + + pak, err := h.CreatePreAuthKey(n.Name, true, nil) + c.Assert(err, check.IsNil) + + p, err := h.checkKeyValidity(pak.Key) + c.Assert(err, check.IsNil) + c.Assert(p.ID, check.Equals, pak.ID) +} + +func (*Suite) TestAlreadyUsedKey(c *check.C) { + n, err := h.CreateNamespace("test4") + c.Assert(err, check.IsNil) + + pak, err := h.CreatePreAuthKey(n.Name, false, nil) + c.Assert(err, check.IsNil) + + db, err := h.db() + if err != nil { + c.Fatal(err) + } + defer db.Close() + m := Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Name: "testest", + NamespaceID: n.ID, + Registered: true, + RegisterMethod: "authKey", + AuthKeyID: uint(pak.ID), + } + db.Save(&m) + + p, err := h.checkKeyValidity(pak.Key) + c.Assert(err, check.Equals, errorAuthKeyNotReusableAlreadyUsed) + c.Assert(p, check.IsNil) +} + +func (*Suite) TestReusableBeingUsedKey(c *check.C) { + n, err := h.CreateNamespace("test5") + c.Assert(err, check.IsNil) + + pak, err := h.CreatePreAuthKey(n.Name, true, nil) + c.Assert(err, check.IsNil) + + db, err := h.db() + if err != nil { + c.Fatal(err) + } + defer db.Close() + m := Machine{ + ID: 1, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Name: "testest", + NamespaceID: n.ID, + Registered: true, + RegisterMethod: "authKey", + AuthKeyID: uint(pak.ID), + } + db.Save(&m) + + p, err := h.checkKeyValidity(pak.Key) + c.Assert(err, check.IsNil) + c.Assert(p.ID, check.Equals, pak.ID) +} + +func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { + n, err := h.CreateNamespace("test6") + c.Assert(err, check.IsNil) + + pak, err := h.CreatePreAuthKey(n.Name, false, nil) + c.Assert(err, check.IsNil) + + p, err := h.checkKeyValidity(pak.Key) + c.Assert(err, check.IsNil) + c.Assert(p.ID, check.Equals, pak.ID) +} diff --git a/utils.go b/utils.go index e787b1d..2785745 100644 --- a/utils.go +++ b/utils.go @@ -21,6 +21,11 @@ import ( "tailscale.com/wgengine/wgcfg" ) +// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors +type Error string + +func (e Error) Error() string { return string(e) } + func decode(msg []byte, v interface{}, pubKey *wgcfg.Key, privKey *wgcfg.PrivateKey) error { return decodeMsg(msg, v, pubKey, privKey) }