Merge pull request #29 from cure/add-ephemeral-node-support
Add support for ephemeral nodes via a special type of pre-auth key.
This commit is contained in:
commit
6a3b171e99
17 changed files with 151 additions and 38 deletions
|
@ -22,7 +22,7 @@ Headscale implements this coordination server.
|
||||||
- [x] ~~Multiuser~~ Namespace support
|
- [x] ~~Multiuser~~ Namespace support
|
||||||
- [x] Basic routing (advertise & accept)
|
- [x] Basic routing (advertise & accept)
|
||||||
- [ ] Share nodes between ~~users~~ namespaces
|
- [ ] Share nodes between ~~users~~ namespaces
|
||||||
- [x] Node registration via pre-auth keys
|
- [x] Node registration via pre-auth keys (including reusable keys and ephemeral node support)
|
||||||
- [X] JSON-formatted output
|
- [X] JSON-formatted output
|
||||||
- [ ] ACLs
|
- [ ] ACLs
|
||||||
- [ ] DNS
|
- [ ] DNS
|
||||||
|
@ -97,6 +97,7 @@ Alternatively, you can use Auth Keys to register your machines:
|
||||||
tailscale up -login-server YOUR_HEADSCALE_URL --authkey YOURAUTHKEY
|
tailscale up -login-server YOUR_HEADSCALE_URL --authkey YOURAUTHKEY
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If you create an authkey with the `--ephemeral` flag, that key will create ephemeral nodes. This implies that `--reusable` is true.
|
||||||
|
|
||||||
Please bear in mind that all the commands from headscale support adding `-o json` or `-o json-line` to get a nicely JSON-formatted output.
|
Please bear in mind that all the commands from headscale support adding `-o json` or `-o json-line` to get a nicely JSON-formatted output.
|
||||||
|
|
||||||
|
@ -124,6 +125,12 @@ Headscale's configuration file is named `config.json` or `config.yaml`. Headscal
|
||||||
|
|
||||||
`derp_map_path` is the path to the [DERP](https://pkg.go.dev/tailscale.com/derp) map file. If the path is relative, it will be interpreted as relative to the directory the configuration file was read from.
|
`derp_map_path` is the path to the [DERP](https://pkg.go.dev/tailscale.com/derp) map file. If the path is relative, it will be interpreted as relative to the directory the configuration file was read from.
|
||||||
|
|
||||||
|
```
|
||||||
|
"ephemeral_node_inactivity_timeout": "30m",
|
||||||
|
```
|
||||||
|
|
||||||
|
`ephemeral_node_inactivity_timeout` is the timeout after which inactive ephemeral node records will be deleted from the database. The default is 30 minutes. This value must be higher than 65 seconds (the keepalive timeout for the HTTP long poll is 60 seconds, plus a few seconds to avoid race conditions).
|
||||||
|
|
||||||
```
|
```
|
||||||
"db_host": "localhost",
|
"db_host": "localhost",
|
||||||
"db_port": 5432,
|
"db_port": 5432,
|
||||||
|
|
17
api.go
17
api.go
|
@ -93,7 +93,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||||
// We do have the updated key!
|
// We do have the updated key!
|
||||||
if m.NodeKey == wgcfg.Key(req.NodeKey).HexString() {
|
if m.NodeKey == wgcfg.Key(req.NodeKey).HexString() {
|
||||||
if m.Registered {
|
if m.Registered {
|
||||||
log.Println("Client is registered and we have the current key. All clear to /map")
|
log.Printf("[%s] Client is registered and we have the current key. All clear to /map\n", m.Name)
|
||||||
resp.AuthURL = ""
|
resp.AuthURL = ""
|
||||||
resp.User = *m.Namespace.toUser()
|
resp.User = *m.Namespace.toUser()
|
||||||
resp.MachineAuthorized = true
|
resp.MachineAuthorized = true
|
||||||
|
@ -174,7 +174,7 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
var m Machine
|
var m Machine
|
||||||
if db.First(&m, "machine_key = ?", mKey.HexString()).RecordNotFound() {
|
if db.First(&m, "machine_key = ?", mKey.HexString()).RecordNotFound() {
|
||||||
log.Printf("Cannot fingitd machine: %s", err)
|
log.Printf("Ignoring request, cannot find machine with key %s", mKey.HexString())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -243,29 +243,34 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
||||||
log.Printf("[%s] Sending data (%d bytes)", m.Name, len(data))
|
log.Printf("[%s] Sending data (%d bytes)", m.Name, len(data))
|
||||||
_, err := w.Write(data)
|
_, err := w.Write(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("[%s] 🤮 Cannot write data: %s", m.Name, err)
|
log.Printf("[%s] 🤮 Cannot write data: %s", m.Name, err)
|
||||||
}
|
}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
m.LastSeen = &now
|
||||||
|
db.Save(&m)
|
||||||
return true
|
return true
|
||||||
|
|
||||||
case <-update:
|
case <-update:
|
||||||
log.Printf("[%s] Received a request for update", m.Name)
|
log.Printf("[%s] Received a request for update", m.Name)
|
||||||
data, err := h.getMapResponse(mKey, req, m)
|
data, err := h.getMapResponse(mKey, req, m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("[%s] 🤮 Cannot get the poll response: %s", m.Name, err)
|
log.Printf("[%s] 🤮 Cannot get the poll response: %s", m.Name, err)
|
||||||
}
|
}
|
||||||
_, err = w.Write(*data)
|
_, err = w.Write(*data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("[%s] 🤮 Cannot write the poll response: %s", m.Name, err)
|
log.Printf("[%s] 🤮 Cannot write the poll response: %s", m.Name, err)
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
|
|
||||||
case <-c.Request.Context().Done():
|
case <-c.Request.Context().Done():
|
||||||
log.Printf("[%s] 😥 The client has closed the connection", m.Name)
|
log.Printf("[%s] 😥 The client has closed the connection", m.Name)
|
||||||
|
now := time.Now().UTC()
|
||||||
|
m.LastSeen = &now
|
||||||
|
db.Save(&m)
|
||||||
h.pollMu.Lock()
|
h.pollMu.Lock()
|
||||||
cancelKeepAlive <- []byte{}
|
cancelKeepAlive <- []byte{}
|
||||||
delete(h.clientsPolling, m.ID)
|
delete(h.clientsPolling, m.ID)
|
||||||
h.pollMu.Unlock()
|
h.pollMu.Unlock()
|
||||||
|
|
||||||
return false
|
return false
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
53
app.go
53
app.go
|
@ -8,6 +8,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"golang.org/x/crypto/acme/autocert"
|
"golang.org/x/crypto/acme/autocert"
|
||||||
|
@ -21,6 +22,7 @@ type Config struct {
|
||||||
Addr string
|
Addr string
|
||||||
PrivateKeyPath string
|
PrivateKeyPath string
|
||||||
DerpMap *tailcfg.DERPMap
|
DerpMap *tailcfg.DERPMap
|
||||||
|
EphemeralNodeInactivityTimeout time.Duration
|
||||||
|
|
||||||
DBtype string
|
DBtype string
|
||||||
DBpath string
|
DBpath string
|
||||||
|
@ -95,6 +97,51 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
|
||||||
http.Redirect(w, req, target, http.StatusFound)
|
http.Redirect(w, req, target, http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExpireEphemeralNodes deletes ephemeral machine records that have not been
|
||||||
|
// seen for longer than h.cfg.EphemeralNodeInactivityTimeout
|
||||||
|
func (h *Headscale) ExpireEphemeralNodes(milliSeconds int64) {
|
||||||
|
if milliSeconds == 0 {
|
||||||
|
// For testing
|
||||||
|
h.expireEphemeralNodesWorker()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
||||||
|
for range ticker.C {
|
||||||
|
h.expireEphemeralNodesWorker()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) expireEphemeralNodesWorker() {
|
||||||
|
db, err := h.db()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot open DB: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
namespaces, err := h.ListNamespaces()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error listing namespaces: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, ns := range *namespaces {
|
||||||
|
machines, err := h.ListMachinesInNamespace(ns.Name)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error listing machines in namespace %s: %s", ns.Name, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, m := range *machines {
|
||||||
|
if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral && time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
|
||||||
|
log.Printf("[%s] Ephemeral client removed from database\n", m.Name)
|
||||||
|
err = db.Unscoped().Delete(m).Error
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[%s] 🤮 Cannot delete ephemeral machine from the database: %s", m.Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Serve launches a GIN server with the Headscale API
|
// Serve launches a GIN server with the Headscale API
|
||||||
func (h *Headscale) Serve() error {
|
func (h *Headscale) Serve() error {
|
||||||
r := gin.Default()
|
r := gin.Default()
|
||||||
|
@ -105,7 +152,7 @@ func (h *Headscale) Serve() error {
|
||||||
var err error
|
var err error
|
||||||
if h.cfg.TLSLetsEncryptHostname != "" {
|
if h.cfg.TLSLetsEncryptHostname != "" {
|
||||||
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
||||||
fmt.Println("WARNING: listening with TLS but ServerURL does not start with https://")
|
log.Println("WARNING: listening with TLS but ServerURL does not start with https://")
|
||||||
}
|
}
|
||||||
|
|
||||||
m := autocert.Manager{
|
m := autocert.Manager{
|
||||||
|
@ -136,12 +183,12 @@ func (h *Headscale) Serve() error {
|
||||||
}
|
}
|
||||||
} else if h.cfg.TLSCertPath == "" {
|
} else if h.cfg.TLSCertPath == "" {
|
||||||
if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
|
if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
|
||||||
fmt.Println("WARNING: listening without TLS but ServerURL does not start with http://")
|
log.Println("WARNING: listening without TLS but ServerURL does not start with http://")
|
||||||
}
|
}
|
||||||
err = r.Run(h.cfg.Addr)
|
err = r.Run(h.cfg.Addr)
|
||||||
} else {
|
} else {
|
||||||
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
||||||
fmt.Println("WARNING: listening with TLS but ServerURL does not start with https://")
|
log.Println("WARNING: listening with TLS but ServerURL does not start with https://")
|
||||||
}
|
}
|
||||||
err = r.RunTLS(h.cfg.Addr, h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
|
err = r.RunTLS(h.cfg.Addr, h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,9 +65,9 @@ var ListNodesCmd = &cobra.Command{
|
||||||
log.Fatalf("Error getting nodes: %s", err)
|
log.Fatalf("Error getting nodes: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("name\t\tlast seen\n")
|
fmt.Printf("name\t\tlast seen\t\tephemeral\n")
|
||||||
for _, m := range *machines {
|
for _, m := range *machines {
|
||||||
fmt.Printf("%s\t%s\n", m.Name, m.LastSeen.Format("2006-01-02 15:04:05"))
|
fmt.Printf("%s\t%s\t%t\n", m.Name, m.LastSeen.Format("2006-01-02 15:04:05"), m.AuthKey.Ephemeral)
|
||||||
}
|
}
|
||||||
|
|
||||||
},
|
},
|
||||||
|
|
|
@ -45,10 +45,11 @@ var ListPreAuthKeys = &cobra.Command{
|
||||||
expiration = k.Expiration.Format("2006-01-02 15:04:05")
|
expiration = k.Expiration.Format("2006-01-02 15:04:05")
|
||||||
}
|
}
|
||||||
fmt.Printf(
|
fmt.Printf(
|
||||||
"key: %s, namespace: %s, reusable: %v, expiration: %s, created_at: %s\n",
|
"key: %s, namespace: %s, reusable: %v, ephemeral: %v, expiration: %s, created_at: %s\n",
|
||||||
k.Key,
|
k.Key,
|
||||||
k.Namespace.Name,
|
k.Namespace.Name,
|
||||||
k.Reusable,
|
k.Reusable,
|
||||||
|
k.Ephemeral,
|
||||||
expiration,
|
expiration,
|
||||||
k.CreatedAt.Format("2006-01-02 15:04:05"),
|
k.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||||
)
|
)
|
||||||
|
@ -71,6 +72,7 @@ var CreatePreAuthKeyCmd = &cobra.Command{
|
||||||
log.Fatalf("Error initializing: %s", err)
|
log.Fatalf("Error initializing: %s", err)
|
||||||
}
|
}
|
||||||
reusable, _ := cmd.Flags().GetBool("reusable")
|
reusable, _ := cmd.Flags().GetBool("reusable")
|
||||||
|
ephemeral, _ := cmd.Flags().GetBool("ephemeral")
|
||||||
|
|
||||||
e, _ := cmd.Flags().GetString("expiration")
|
e, _ := cmd.Flags().GetString("expiration")
|
||||||
var expiration *time.Time
|
var expiration *time.Time
|
||||||
|
@ -83,7 +85,7 @@ var CreatePreAuthKeyCmd = &cobra.Command{
|
||||||
expiration = &exp
|
expiration = &exp
|
||||||
}
|
}
|
||||||
|
|
||||||
k, err := h.CreatePreAuthKey(n, reusable, expiration)
|
k, err := h.CreatePreAuthKey(n, reusable, ephemeral, expiration)
|
||||||
if strings.HasPrefix(o, "json") {
|
if strings.HasPrefix(o, "json") {
|
||||||
JsonOutput(k, err, o)
|
JsonOutput(k, err, o)
|
||||||
return
|
return
|
||||||
|
|
|
@ -17,6 +17,7 @@ var ServeCmd = &cobra.Command{
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Error initializing: %s", err)
|
log.Fatalf("Error initializing: %s", err)
|
||||||
}
|
}
|
||||||
|
go h.ExpireEphemeralNodes(5000)
|
||||||
err = h.Serve()
|
err = h.Serve()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Error initializing: %s", err)
|
log.Fatalf("Error initializing: %s", err)
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/juanfont/headscale"
|
"github.com/juanfont/headscale"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
|
@ -37,12 +38,22 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||||
log.Printf("Could not load DERP servers map file: %s", err)
|
log.Printf("Could not load DERP servers map file: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Minimum inactivity time out is keepalive timeout (60s) plus a few seconds
|
||||||
|
// to avoid races
|
||||||
|
minInactivityTimeout, _ := time.ParseDuration("65s")
|
||||||
|
if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout {
|
||||||
|
err = fmt.Errorf("ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s\n", viper.GetString("ephemeral_node_inactivity_timeout"), minInactivityTimeout)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
cfg := headscale.Config{
|
cfg := headscale.Config{
|
||||||
ServerURL: viper.GetString("server_url"),
|
ServerURL: viper.GetString("server_url"),
|
||||||
Addr: viper.GetString("listen_addr"),
|
Addr: viper.GetString("listen_addr"),
|
||||||
PrivateKeyPath: absPath(viper.GetString("private_key_path")),
|
PrivateKeyPath: absPath(viper.GetString("private_key_path")),
|
||||||
DerpMap: derpMap,
|
DerpMap: derpMap,
|
||||||
|
|
||||||
|
EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"),
|
||||||
|
|
||||||
DBtype: viper.GetString("db_type"),
|
DBtype: viper.GetString("db_type"),
|
||||||
DBpath: absPath(viper.GetString("db_path")),
|
DBpath: absPath(viper.GetString("db_path")),
|
||||||
DBhost: viper.GetString("db_host"),
|
DBhost: viper.GetString("db_host"),
|
||||||
|
|
|
@ -127,6 +127,7 @@ func main() {
|
||||||
cli.PreauthkeysCmd.AddCommand(cli.CreatePreAuthKeyCmd)
|
cli.PreauthkeysCmd.AddCommand(cli.CreatePreAuthKeyCmd)
|
||||||
|
|
||||||
cli.CreatePreAuthKeyCmd.PersistentFlags().Bool("reusable", false, "Make the preauthkey reusable")
|
cli.CreatePreAuthKeyCmd.PersistentFlags().Bool("reusable", false, "Make the preauthkey reusable")
|
||||||
|
cli.CreatePreAuthKeyCmd.PersistentFlags().Bool("ephemeral", false, "Preauthkey for ephemeral nodes")
|
||||||
cli.CreatePreAuthKeyCmd.Flags().StringP("expiration", "e", "", "Human-readable expiration of the key (30m, 24h, 365d...)")
|
cli.CreatePreAuthKeyCmd.Flags().StringP("expiration", "e", "", "Human-readable expiration of the key (30m, 24h, 365d...)")
|
||||||
|
|
||||||
headscaleCmd.PersistentFlags().StringP("output", "o", "", "Output format. Empty for human-readable, 'json' or 'json-line'")
|
headscaleCmd.PersistentFlags().StringP("output", "o", "", "Output format. Empty for human-readable, 'json' or 'json-line'")
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
"listen_addr": "0.0.0.0:8000",
|
"listen_addr": "0.0.0.0:8000",
|
||||||
"private_key_path": "private.key",
|
"private_key_path": "private.key",
|
||||||
"derp_map_path": "derp.yaml",
|
"derp_map_path": "derp.yaml",
|
||||||
|
"ephemeral_node_inactivity_timeout": "30m",
|
||||||
"db_type": "postgres",
|
"db_type": "postgres",
|
||||||
"db_host": "localhost",
|
"db_host": "localhost",
|
||||||
"db_port": 5432,
|
"db_port": 5432,
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
"listen_addr": "0.0.0.0:8000",
|
"listen_addr": "0.0.0.0:8000",
|
||||||
"private_key_path": "private.key",
|
"private_key_path": "private.key",
|
||||||
"derp_map_path": "derp.yaml",
|
"derp_map_path": "derp.yaml",
|
||||||
|
"ephemeral_node_inactivity_timeout": "30m",
|
||||||
"db_type": "sqlite3",
|
"db_type": "sqlite3",
|
||||||
"db_path": "db.sqlite",
|
"db_path": "db.sqlite",
|
||||||
"tls_letsencrypt_hostname": "",
|
"tls_letsencrypt_hostname": "",
|
||||||
|
|
|
@ -131,7 +131,7 @@ func (m Machine) toNode() (*tailcfg.Node, error) {
|
||||||
|
|
||||||
n := tailcfg.Node{
|
n := tailcfg.Node{
|
||||||
ID: tailcfg.NodeID(m.ID), // this is the actual ID
|
ID: tailcfg.NodeID(m.ID), // this is the actual ID
|
||||||
StableID: tailcfg.StableNodeID(strconv.FormatUint(m.ID, 10)), // in headscale, unlike tailcontrol server, IDs are permantent
|
StableID: tailcfg.StableNodeID(strconv.FormatUint(m.ID, 10)), // in headscale, unlike tailcontrol server, IDs are permanent
|
||||||
Name: hostinfo.Hostname,
|
Name: hostinfo.Hostname,
|
||||||
User: tailcfg.UserID(m.NamespaceID),
|
User: tailcfg.UserID(m.NamespaceID),
|
||||||
Key: tailcfg.NodeKey(nKey),
|
Key: tailcfg.NodeKey(nKey),
|
||||||
|
|
|
@ -4,13 +4,11 @@ import (
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = check.Suite(&Suite{})
|
|
||||||
|
|
||||||
func (s *Suite) TestGetMachine(c *check.C) {
|
func (s *Suite) TestGetMachine(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test")
|
n, err := h.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, nil)
|
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
db, err := h.db()
|
db, err := h.db()
|
||||||
|
|
|
@ -119,7 +119,7 @@ func (h *Headscale) ListMachinesInNamespace(name string) (*[]Machine, error) {
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
machines := []Machine{}
|
machines := []Machine{}
|
||||||
if err := db.Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
|
if err := db.Preload("AuthKey").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &machines, nil
|
return &machines, nil
|
||||||
|
|
|
@ -4,8 +4,6 @@ import (
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = check.Suite(&Suite{})
|
|
||||||
|
|
||||||
func (s *Suite) TestCreateAndDestroyNamespace(c *check.C) {
|
func (s *Suite) TestCreateAndDestroyNamespace(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test")
|
n, err := h.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
@ -29,7 +27,7 @@ func (s *Suite) TestDestroyNamespaceErrors(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test")
|
n, err := h.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, nil)
|
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
db, err := h.db()
|
db, err := h.db()
|
||||||
|
|
|
@ -18,13 +18,14 @@ type PreAuthKey struct {
|
||||||
NamespaceID uint
|
NamespaceID uint
|
||||||
Namespace Namespace
|
Namespace Namespace
|
||||||
Reusable bool
|
Reusable bool
|
||||||
|
Ephemeral bool `gorm:"default:false"`
|
||||||
|
|
||||||
CreatedAt *time.Time
|
CreatedAt *time.Time
|
||||||
Expiration *time.Time
|
Expiration *time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it
|
// CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it
|
||||||
func (h *Headscale) CreatePreAuthKey(namespaceName string, reusable bool, expiration *time.Time) (*PreAuthKey, error) {
|
func (h *Headscale) CreatePreAuthKey(namespaceName string, reusable bool, ephemeral bool, expiration *time.Time) (*PreAuthKey, error) {
|
||||||
n, err := h.GetNamespace(namespaceName)
|
n, err := h.GetNamespace(namespaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -48,6 +49,7 @@ func (h *Headscale) CreatePreAuthKey(namespaceName string, reusable bool, expira
|
||||||
NamespaceID: n.ID,
|
NamespaceID: n.ID,
|
||||||
Namespace: *n,
|
Namespace: *n,
|
||||||
Reusable: reusable,
|
Reusable: reusable,
|
||||||
|
Ephemeral: ephemeral,
|
||||||
CreatedAt: &now,
|
CreatedAt: &now,
|
||||||
Expiration: expiration,
|
Expiration: expiration,
|
||||||
}
|
}
|
||||||
|
@ -94,7 +96,7 @@ func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
|
||||||
return nil, errorAuthKeyExpired
|
return nil, errorAuthKeyExpired
|
||||||
}
|
}
|
||||||
|
|
||||||
if pak.Reusable { // we don't need to check if has been used before
|
if pak.Reusable || pak.Ephemeral { // we don't need to check if has been used before
|
||||||
return &pak, nil
|
return &pak, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,14 +7,14 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||||
_, err := h.CreatePreAuthKey("bogus", true, nil)
|
_, err := h.CreatePreAuthKey("bogus", true, false, nil)
|
||||||
|
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
n, err := h.CreateNamespace("test")
|
n, err := h.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
k, err := h.CreatePreAuthKey(n.Name, true, nil)
|
k, err := h.CreatePreAuthKey(n.Name, true, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
// Did we get a valid key?
|
// Did we get a valid key?
|
||||||
|
@ -40,7 +40,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, true, &now)
|
pak, err := h.CreatePreAuthKey(n.Name, true, false, &now)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
p, err := h.checkKeyValidity(pak.Key)
|
p, err := h.checkKeyValidity(pak.Key)
|
||||||
|
@ -58,7 +58,7 @@ func (*Suite) TestValidateKeyOk(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test3")
|
n, err := h.CreateNamespace("test3")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, true, nil)
|
pak, err := h.CreatePreAuthKey(n.Name, true, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
p, err := h.checkKeyValidity(pak.Key)
|
p, err := h.checkKeyValidity(pak.Key)
|
||||||
|
@ -70,7 +70,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test4")
|
n, err := h.CreateNamespace("test4")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, nil)
|
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
db, err := h.db()
|
db, err := h.db()
|
||||||
|
@ -100,7 +100,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test5")
|
n, err := h.CreateNamespace("test5")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, true, nil)
|
pak, err := h.CreatePreAuthKey(n.Name, true, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
db, err := h.db()
|
db, err := h.db()
|
||||||
|
@ -130,10 +130,51 @@ func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test6")
|
n, err := h.CreateNamespace("test6")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, nil)
|
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
p, err := h.checkKeyValidity(pak.Key)
|
p, err := h.checkKeyValidity(pak.Key)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(p.ID, check.Equals, pak.ID)
|
c.Assert(p.ID, check.Equals, pak.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (*Suite) TestEphemeralKey(c *check.C) {
|
||||||
|
n, err := h.CreateNamespace("test7")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := h.CreatePreAuthKey(n.Name, false, true, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
db, err := h.db()
|
||||||
|
if err != nil {
|
||||||
|
c.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
now := time.Now()
|
||||||
|
m := Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "foo",
|
||||||
|
NodeKey: "bar",
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Name: "testest",
|
||||||
|
NamespaceID: n.ID,
|
||||||
|
Registered: true,
|
||||||
|
RegisterMethod: "authKey",
|
||||||
|
LastSeen: &now,
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
}
|
||||||
|
db.Save(&m)
|
||||||
|
|
||||||
|
_, err = h.checkKeyValidity(pak.Key)
|
||||||
|
// Ephemeral keys are by definition reusable
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = h.GetMachine("test7", "testest")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
h.ExpireEphemeralNodes(0)
|
||||||
|
|
||||||
|
// The machine record should have been deleted
|
||||||
|
_, err = h.GetMachine("test7", "testest")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
}
|
||||||
|
|
|
@ -9,13 +9,11 @@ import (
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = check.Suite(&Suite{})
|
|
||||||
|
|
||||||
func (s *Suite) TestGetRoutes(c *check.C) {
|
func (s *Suite) TestGetRoutes(c *check.C) {
|
||||||
n, err := h.CreateNamespace("test")
|
n, err := h.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
pak, err := h.CreatePreAuthKey(n.Name, false, nil)
|
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
db, err := h.db()
|
db, err := h.db()
|
||||||
|
|
Loading…
Reference in a new issue