diff --git a/cli.go b/cli.go index 6615d3d..0cb5333 100644 --- a/cli.go +++ b/cli.go @@ -1,13 +1,9 @@ package headscale import ( - "encoding/json" - "errors" "fmt" "log" - "github.com/jinzhu/gorm/dialects/postgres" - "inet.af/netaddr" "tailscale.com/wgengine/wgcfg" ) @@ -51,63 +47,3 @@ func (h *Headscale) RegisterMachine(key string, namespace string) error { fmt.Println("Machine registered 🎉") return nil } - -// ListNodeRoutes prints the subnet routes advertised by a node (identified by -// namespace and node name) -func (h *Headscale) ListNodeRoutes(namespace string, nodeName string) error { - m, err := h.GetMachine(namespace, nodeName) - if err != nil { - return err - } - - hi, err := m.GetHostInfo() - if err != nil { - return err - } - fmt.Println(hi.RoutableIPs) - return nil -} - -// EnableNodeRoute enables a subnet route advertised by a node (identified by -// namespace and node name) -func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error { - m, err := h.GetMachine(namespace, nodeName) - if err != nil { - return err - } - hi, err := m.GetHostInfo() - if err != nil { - return err - } - route, err := netaddr.ParseIPPrefix(routeStr) - if err != nil { - return err - } - - for _, rIP := range hi.RoutableIPs { - if rIP == route { - db, err := h.db() - if err != nil { - log.Printf("Cannot open DB: %s", err) - return err - } - - routes, _ := json.Marshal([]string{routeStr}) // TODO: only one for the time being, so overwriting the rest - m.EnabledRoutes = postgres.Jsonb{RawMessage: json.RawMessage(routes)} - db.Save(&m) - db.Close() - - peers, _ := h.getPeers(*m) - h.pollMu.Lock() - for _, p := range *peers { - if pUp, ok := h.clientsPolling[uint64(p.ID)]; ok { - pUp <- []byte{} - } - } - h.pollMu.Unlock() - return nil - } - } - - return errors.New("could not find routable range") -} diff --git a/cmd/headscale/cli/namespaces.go b/cmd/headscale/cli/namespaces.go new file mode 100644 index 0000000..8240187 --- /dev/null +++ b/cmd/headscale/cli/namespaces.go @@ -0,0 +1,56 @@ +package cli + +import ( + "fmt" + "log" + + "github.com/spf13/cobra" +) + +var NamespaceCmd = &cobra.Command{ + Use: "namespace", + Short: "Manage the namespaces of Headscale", +} + +var CreateNamespaceCmd = &cobra.Command{ + Use: "create NAME", + Short: "Creates a new namespace", + Args: func(cmd *cobra.Command, args []string) error { + if len(args) < 1 { + return fmt.Errorf("Missing parameters") + } + return nil + }, + Run: func(cmd *cobra.Command, args []string) { + h, err := getHeadscaleApp() + if err != nil { + log.Fatalf("Error initializing: %s", err) + } + _, err = h.CreateNamespace(args[0]) + if err != nil { + fmt.Println(err) + return + } + fmt.Printf("Ook.\n") + }, +} + +var ListNamespacesCmd = &cobra.Command{ + Use: "list", + Short: "List all the namespaces", + Run: func(cmd *cobra.Command, args []string) { + h, err := getHeadscaleApp() + if err != nil { + log.Fatalf("Error initializing: %s", err) + } + ns, err := h.ListNamespaces() + if err != nil { + fmt.Println(err) + return + } + fmt.Printf("ID\tName\n") + for _, n := range *ns { + fmt.Printf("%d\t%s\n", n.ID, n.Name) + } + }, +} diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go new file mode 100644 index 0000000..299211d --- /dev/null +++ b/cmd/headscale/cli/nodes.go @@ -0,0 +1,36 @@ +package cli + +import ( + "fmt" + "log" + + "github.com/spf13/cobra" +) + +var RegisterCmd = &cobra.Command{ + Use: "register machineID namespace", + Short: "Registers a machine to your network", + Args: func(cmd *cobra.Command, args []string) error { + if len(args) < 2 { + return fmt.Errorf("Missing parameters") + } + return nil + }, + Run: func(cmd *cobra.Command, args []string) { + h, err := getHeadscaleApp() + if err != nil { + log.Fatalf("Error initializing: %s", err) + } + err = h.RegisterMachine(args[0], args[1]) + if err != nil { + fmt.Printf("Error: %s", err) + return + } + fmt.Println("Ook.") + }, +} + +var NodeCmd = &cobra.Command{ + Use: "node", + Short: "Manage the nodes of Headscale", +} diff --git a/cmd/headscale/cli/preauthkeys.go b/cmd/headscale/cli/preauthkeys.go new file mode 100644 index 0000000..91225aa --- /dev/null +++ b/cmd/headscale/cli/preauthkeys.go @@ -0,0 +1,83 @@ +package cli + +import ( + "fmt" + "log" + "time" + + "github.com/hako/durafmt" + "github.com/spf13/cobra" +) + +var PreauthkeysCmd = &cobra.Command{ + Use: "preauthkey", + Short: "Handle the preauthkeys in Headscale", +} + +var ListPreAuthKeys = &cobra.Command{ + Use: "list NAMESPACE", + Short: "List the preauthkeys for this namespace", + Args: func(cmd *cobra.Command, args []string) error { + if len(args) < 1 { + return fmt.Errorf("Missing parameters") + } + return nil + }, + Run: func(cmd *cobra.Command, args []string) { + h, err := getHeadscaleApp() + if err != nil { + log.Fatalf("Error initializing: %s", err) + } + keys, err := h.GetPreAuthKeys(args[0]) + if err != nil { + fmt.Println(err) + return + } + for _, k := range *keys { + fmt.Printf( + "key: %s, namespace: %s, reusable: %v, expiration: %s, created_at: %s\n", + k.Key, + k.Namespace.Name, + k.Reusable, + k.Expiration.Format("2006-01-02 15:04:05"), + k.CreatedAt.Format("2006-01-02 15:04:05"), + ) + } + }, +} + +var CreatePreAuthKeyCmd = &cobra.Command{ + Use: "create NAMESPACE", + Short: "Creates a new preauthkey in the specified namespace", + Args: func(cmd *cobra.Command, args []string) error { + if len(args) < 1 { + return fmt.Errorf("Missing parameters") + } + return nil + }, + Run: func(cmd *cobra.Command, args []string) { + h, err := getHeadscaleApp() + if err != nil { + log.Fatalf("Error initializing: %s", err) + } + reusable, _ := cmd.Flags().GetBool("reusable") + + e, _ := cmd.Flags().GetString("expiration") + var expiration *time.Time + if e != "" { + duration, err := durafmt.ParseStringShort(e) + if err != nil { + log.Fatalf("Error parsing expiration: %s", err) + } + exp := time.Now().UTC().Add(duration.Duration()) + expiration = &exp + } + + _, err = h.CreatePreAuthKey(args[0], reusable, expiration) + if err != nil { + fmt.Println(err) + return + } + fmt.Printf("Ook.\n") + }, +} diff --git a/cmd/headscale/cli/routes.go b/cmd/headscale/cli/routes.go new file mode 100644 index 0000000..c14956f --- /dev/null +++ b/cmd/headscale/cli/routes.go @@ -0,0 +1,53 @@ +package cli + +import ( + "fmt" + "log" + + "github.com/spf13/cobra" +) + +var ListRoutesCmd = &cobra.Command{ + Use: "list-routes NAMESPACE NODE", + Short: "List the routes exposed by this node", + Args: func(cmd *cobra.Command, args []string) error { + if len(args) < 2 { + return fmt.Errorf("Missing parameters") + } + return nil + }, + Run: func(cmd *cobra.Command, args []string) { + h, err := getHeadscaleApp() + if err != nil { + log.Fatalf("Error initializing: %s", err) + } + routes, err := h.GetNodeRoutes(args[0], args[1]) + if err != nil { + fmt.Println(err) + return + } + fmt.Println(routes) + }, +} + +var EnableRouteCmd = &cobra.Command{ + Use: "enable-route", + Short: "Allows exposing a route declared by this node to the rest of the nodes", + Args: func(cmd *cobra.Command, args []string) error { + if len(args) < 3 { + return fmt.Errorf("Missing parameters") + } + return nil + }, + Run: func(cmd *cobra.Command, args []string) { + h, err := getHeadscaleApp() + if err != nil { + log.Fatalf("Error initializing: %s", err) + } + err = h.EnableNodeRoute(args[0], args[1], args[2]) + if err != nil { + fmt.Println(err) + return + } + }, +} diff --git a/cmd/headscale/cli/server.go b/cmd/headscale/cli/server.go new file mode 100644 index 0000000..bdcf367 --- /dev/null +++ b/cmd/headscale/cli/server.go @@ -0,0 +1,25 @@ +package cli + +import ( + "log" + + "github.com/spf13/cobra" +) + +var ServeCmd = &cobra.Command{ + Use: "serve", + Short: "Launches the headscale server", + Args: func(cmd *cobra.Command, args []string) error { + return nil + }, + Run: func(cmd *cobra.Command, args []string) { + h, err := getHeadscaleApp() + if err != nil { + log.Fatalf("Error initializing: %s", err) + } + err = h.Serve() + if err != nil { + log.Fatalf("Error initializing: %s", err) + } + }, +} diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go new file mode 100644 index 0000000..8b2b274 --- /dev/null +++ b/cmd/headscale/cli/utils.go @@ -0,0 +1,74 @@ +package cli + +import ( + "io" + "log" + "os" + "path/filepath" + "strings" + + "github.com/juanfont/headscale" + "github.com/spf13/viper" + "gopkg.in/yaml.v2" + "tailscale.com/tailcfg" +) + +func absPath(path string) string { + // If a relative path is provided, prefix it with the the directory where + // the config file was found. + if (path != "") && !strings.HasPrefix(path, "/") { + dir, _ := filepath.Split(viper.ConfigFileUsed()) + if dir != "" { + path = dir + "/" + path + } + } + return path +} + +func getHeadscaleApp() (*headscale.Headscale, error) { + derpMap, err := loadDerpMap(absPath(viper.GetString("derp_map_path"))) + if err != nil { + log.Printf("Could not load DERP servers map file: %s", err) + } + + cfg := headscale.Config{ + ServerURL: viper.GetString("server_url"), + Addr: viper.GetString("listen_addr"), + PrivateKeyPath: absPath(viper.GetString("private_key_path")), + DerpMap: derpMap, + + DBhost: viper.GetString("db_host"), + DBport: viper.GetInt("db_port"), + DBname: viper.GetString("db_name"), + DBuser: viper.GetString("db_user"), + DBpass: viper.GetString("db_pass"), + + TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"), + TLSLetsEncryptCacheDir: absPath(viper.GetString("tls_letsencrypt_cache_dir")), + TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"), + + TLSCertPath: absPath(viper.GetString("tls_cert_path")), + TLSKeyPath: absPath(viper.GetString("tls_key_path")), + } + + h, err := headscale.NewHeadscale(cfg) + if err != nil { + return nil, err + } + return h, nil +} + +func loadDerpMap(path string) (*tailcfg.DERPMap, error) { + derpFile, err := os.Open(path) + if err != nil { + return nil, err + } + defer derpFile.Close() + var derpMap tailcfg.DERPMap + b, err := io.ReadAll(derpFile) + if err != nil { + return nil, err + } + err = yaml.Unmarshal(b, &derpMap) + return &derpMap, err +} diff --git a/cmd/headscale/headscale.go b/cmd/headscale/headscale.go index 776ffb9..599914c 100644 --- a/cmd/headscale/headscale.go +++ b/cmd/headscale/headscale.go @@ -3,19 +3,13 @@ package main import ( "errors" "fmt" - "io" "log" "os" - "path/filepath" "strings" - "time" - "github.com/hako/durafmt" - "github.com/juanfont/headscale" + "github.com/juanfont/headscale/cmd/headscale/cli" "github.com/spf13/cobra" "github.com/spf13/viper" - "gopkg.in/yaml.v2" - "tailscale.com/tailcfg" ) var version = "dev" @@ -39,217 +33,6 @@ Juan Font Alonso - 2021 https://gitlab.com/juanfont/headscale`, } -var serveCmd = &cobra.Command{ - Use: "serve", - Short: "Launches the headscale server", - Args: func(cmd *cobra.Command, args []string) error { - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - h, err := getHeadscaleApp() - if err != nil { - log.Fatalf("Error initializing: %s", err) - } - err = h.Serve() - if err != nil { - log.Fatalf("Error initializing: %s", err) - } - }, -} - -var registerCmd = &cobra.Command{ - Use: "register machineID namespace", - Short: "Registers a machine to your network", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) < 2 { - return fmt.Errorf("Missing parameters") - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - h, err := getHeadscaleApp() - if err != nil { - log.Fatalf("Error initializing: %s", err) - } - err = h.RegisterMachine(args[0], args[1]) - if err != nil { - fmt.Printf("Error: %s", err) - return - } - fmt.Println("Ook.") - }, -} - -var namespaceCmd = &cobra.Command{ - Use: "namespace", - Short: "Manage the namespaces of Headscale", -} - -var createNamespaceCmd = &cobra.Command{ - Use: "create NAME", - Short: "Creates a new namespace", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) < 1 { - return fmt.Errorf("Missing parameters") - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - h, err := getHeadscaleApp() - if err != nil { - log.Fatalf("Error initializing: %s", err) - } - _, err = h.CreateNamespace(args[0]) - if err != nil { - fmt.Println(err) - return - } - fmt.Printf("Ook.\n") - }, -} - -var listNamespacesCmd = &cobra.Command{ - Use: "list", - Short: "List all the namespaces", - Run: func(cmd *cobra.Command, args []string) { - h, err := getHeadscaleApp() - if err != nil { - log.Fatalf("Error initializing: %s", err) - } - ns, err := h.ListNamespaces() - if err != nil { - fmt.Println(err) - return - } - fmt.Printf("ID\tName\n") - for _, n := range *ns { - fmt.Printf("%d\t%s\n", n.ID, n.Name) - } - }, -} - -var nodeCmd = &cobra.Command{ - Use: "node", - Short: "Manage the nodes of Headscale", -} - -var listRoutesCmd = &cobra.Command{ - Use: "list-routes NAMESPACE NODE", - Short: "List the routes exposed by this node", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) < 2 { - return fmt.Errorf("Missing parameters") - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - h, err := getHeadscaleApp() - if err != nil { - log.Fatalf("Error initializing: %s", err) - } - err = h.ListNodeRoutes(args[0], args[1]) - if err != nil { - fmt.Println(err) - return - } - }, -} - -var enableRouteCmd = &cobra.Command{ - Use: "enable-route", - Short: "Allows exposing a route declared by this node to the rest of the nodes", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) < 3 { - return fmt.Errorf("Missing parameters") - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - h, err := getHeadscaleApp() - if err != nil { - log.Fatalf("Error initializing: %s", err) - } - err = h.EnableNodeRoute(args[0], args[1], args[2]) - if err != nil { - fmt.Println(err) - return - } - }, -} - -var preauthkeysCmd = &cobra.Command{ - Use: "preauthkey", - Short: "Handle the preauthkeys in Headscale", -} - -var listPreAuthKeys = &cobra.Command{ - Use: "list NAMESPACE", - Short: "List the preauthkeys for this namespace", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) < 1 { - return fmt.Errorf("Missing parameters") - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - h, err := getHeadscaleApp() - if err != nil { - log.Fatalf("Error initializing: %s", err) - } - keys, err := h.GetPreAuthKeys(args[0]) - if err != nil { - fmt.Println(err) - return - } - for _, k := range *keys { - fmt.Printf( - "key: %s, namespace: %s, reusable: %v, expiration: %s, created_at: %s\n", - k.Key, - k.Namespace.Name, - k.Reusable, - k.Expiration.Format("2006-01-02 15:04:05"), - k.CreatedAt.Format("2006-01-02 15:04:05"), - ) - } - }, -} - -var createPreAuthKeyCmd = &cobra.Command{ - Use: "create NAMESPACE", - Short: "Creates a new preauthkey in the specified namespace", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) < 1 { - return fmt.Errorf("Missing parameters") - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - h, err := getHeadscaleApp() - if err != nil { - log.Fatalf("Error initializing: %s", err) - } - reusable, _ := cmd.Flags().GetBool("reusable") - - e, _ := cmd.Flags().GetString("expiration") - var expiration *time.Time - if e != "" { - duration, err := durafmt.ParseStringShort(e) - if err != nil { - log.Fatalf("Error parsing expiration: %s", err) - } - exp := time.Now().UTC().Add(duration.Duration()) - expiration = &exp - } - - _, err = h.CreatePreAuthKey(args[0], reusable, expiration) - if err != nil { - fmt.Println(err) - return - } - fmt.Printf("Ook.\n") - }, -} - func loadConfig(path string) error { viper.SetConfigName("config") if path == "" { @@ -301,85 +84,25 @@ func main() { } headscaleCmd.AddCommand(versionCmd) - headscaleCmd.AddCommand(serveCmd) - headscaleCmd.AddCommand(registerCmd) - headscaleCmd.AddCommand(preauthkeysCmd) - headscaleCmd.AddCommand(namespaceCmd) - headscaleCmd.AddCommand(nodeCmd) + headscaleCmd.AddCommand(cli.ServeCmd) + headscaleCmd.AddCommand(cli.RegisterCmd) + headscaleCmd.AddCommand(cli.PreauthkeysCmd) + headscaleCmd.AddCommand(cli.NamespaceCmd) + headscaleCmd.AddCommand(cli.NodeCmd) - namespaceCmd.AddCommand(createNamespaceCmd) - namespaceCmd.AddCommand(listNamespacesCmd) + cli.NamespaceCmd.AddCommand(cli.CreateNamespaceCmd) + cli.NamespaceCmd.AddCommand(cli.ListNamespacesCmd) - nodeCmd.AddCommand(listRoutesCmd) - nodeCmd.AddCommand(enableRouteCmd) + cli.NodeCmd.AddCommand(cli.ListRoutesCmd) + cli.NodeCmd.AddCommand(cli.EnableRouteCmd) - preauthkeysCmd.AddCommand(listPreAuthKeys) - preauthkeysCmd.AddCommand(createPreAuthKeyCmd) - createPreAuthKeyCmd.PersistentFlags().Bool("reusable", false, "Make the preauthkey reusable") - createPreAuthKeyCmd.Flags().StringP("expiration", "e", "", "Human-readable expiration of the key (30m, 24h, 365d...)") + cli.PreauthkeysCmd.AddCommand(cli.ListPreAuthKeys) + cli.PreauthkeysCmd.AddCommand(cli.CreatePreAuthKeyCmd) + cli.CreatePreAuthKeyCmd.PersistentFlags().Bool("reusable", false, "Make the preauthkey reusable") + cli.CreatePreAuthKeyCmd.Flags().StringP("expiration", "e", "", "Human-readable expiration of the key (30m, 24h, 365d...)") if err := headscaleCmd.Execute(); err != nil { fmt.Println(err) os.Exit(-1) } } - -func absPath(path string) string { - // If a relative path is provided, prefix it with the the directory where - // the config file was found. - if (path != "") && !strings.HasPrefix(path, "/") { - dir, _ := filepath.Split(viper.ConfigFileUsed()) - if dir != "" { - path = dir + "/" + path - } - } - return path -} - -func getHeadscaleApp() (*headscale.Headscale, error) { - derpMap, err := loadDerpMap(absPath(viper.GetString("derp_map_path"))) - if err != nil { - log.Printf("Could not load DERP servers map file: %s", err) - } - - cfg := headscale.Config{ - ServerURL: viper.GetString("server_url"), - Addr: viper.GetString("listen_addr"), - PrivateKeyPath: absPath(viper.GetString("private_key_path")), - DerpMap: derpMap, - - DBhost: viper.GetString("db_host"), - DBport: viper.GetInt("db_port"), - DBname: viper.GetString("db_name"), - DBuser: viper.GetString("db_user"), - DBpass: viper.GetString("db_pass"), - - TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"), - TLSLetsEncryptCacheDir: absPath(viper.GetString("tls_letsencrypt_cache_dir")), - TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"), - - TLSCertPath: absPath(viper.GetString("tls_cert_path")), - TLSKeyPath: absPath(viper.GetString("tls_key_path")), - } - - h, err := headscale.NewHeadscale(cfg) - if err != nil { - return nil, err - } - return h, nil -} - -func loadDerpMap(path string) (*tailcfg.DERPMap, error) { - derpFile, err := os.Open(path) - if err != nil { - return nil, err - } - defer derpFile.Close() - var derpMap tailcfg.DERPMap - b, err := io.ReadAll(derpFile) - if err != nil { - return nil, err - } - err = yaml.Unmarshal(b, &derpMap) - return &derpMap, err -} diff --git a/routes.go b/routes.go new file mode 100644 index 0000000..2894366 --- /dev/null +++ b/routes.go @@ -0,0 +1,69 @@ +package headscale + +import ( + "encoding/json" + "errors" + "log" + + "github.com/jinzhu/gorm/dialects/postgres" + "inet.af/netaddr" +) + +// GetNodeRoutes returns the subnet routes advertised by a node (identified by +// namespace and node name) +func (h *Headscale) GetNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) { + m, err := h.GetMachine(namespace, nodeName) + if err != nil { + return nil, err + } + + hi, err := m.GetHostInfo() + if err != nil { + return nil, err + } + return &hi.RoutableIPs, nil +} + +// EnableNodeRoute enables a subnet route advertised by a node (identified by +// namespace and node name) +func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error { + m, err := h.GetMachine(namespace, nodeName) + if err != nil { + return err + } + hi, err := m.GetHostInfo() + if err != nil { + return err + } + route, err := netaddr.ParseIPPrefix(routeStr) + if err != nil { + return err + } + + for _, rIP := range hi.RoutableIPs { + if rIP == route { + db, err := h.db() + if err != nil { + log.Printf("Cannot open DB: %s", err) + return err + } + + routes, _ := json.Marshal([]string{routeStr}) // TODO: only one for the time being, so overwriting the rest + m.EnabledRoutes = postgres.Jsonb{RawMessage: json.RawMessage(routes)} + db.Save(&m) + db.Close() + + peers, _ := h.getPeers(*m) + h.pollMu.Lock() + for _, p := range *peers { + if pUp, ok := h.clientsPolling[uint64(p.ID)]; ok { + pUp <- []byte{} + } + } + h.pollMu.Unlock() + return nil + } + } + + return errors.New("could not find routable range") +}