diff --git a/cmd/headscale/cli/routes.go b/cmd/headscale/cli/routes.go index 7201086..4a935cf 100644 --- a/cmd/headscale/cli/routes.go +++ b/cmd/headscale/cli/routes.go @@ -1,26 +1,37 @@ package cli import ( + "context" "fmt" "log" - "strings" + "strconv" + "time" + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/pterm/pterm" "github.com/spf13/cobra" ) func init() { rootCmd.AddCommand(routesCmd) - routesCmd.PersistentFlags().StringP("namespace", "n", "", "Namespace") - err := routesCmd.MarkPersistentFlagRequired("namespace") + + listRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") + err := listRoutesCmd.MarkFlagRequired("identifier") + if err != nil { + log.Fatalf(err.Error()) + } + routesCmd.AddCommand(listRoutesCmd) + + enableRouteCmd.Flags().StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to enable") + enableRouteCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") + err = enableRouteCmd.MarkFlagRequired("identifier") if err != nil { log.Fatalf(err.Error()) } - enableRouteCmd.Flags().BoolP("all", "a", false, "Enable all routes advertised by the node") - - routesCmd.AddCommand(listRoutesCmd) routesCmd.AddCommand(enableRouteCmd) + + nodeCmd.AddCommand(routesCmd) } var routesCmd = &cobra.Command{ @@ -29,119 +40,128 @@ var routesCmd = &cobra.Command{ } var listRoutesCmd = &cobra.Command{ - Use: "list NODE", - Short: "List the routes exposed by this node", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) < 1 { - return fmt.Errorf("Missing parameters") - } - return nil - }, + Use: "list", + Short: "List routes advertised and enabled by a given node", Run: func(cmd *cobra.Command, args []string) { - n, err := cmd.Flags().GetString("namespace") - if err != nil { - log.Fatalf("Error getting namespace: %s", err) - } - o, _ := cmd.Flags().GetString("output") + output, _ := cmd.Flags().GetString("output") - h, err := getHeadscaleApp() + machineId, err := cmd.Flags().GetUint64("identifier") if err != nil { - log.Fatalf("Error initializing: %s", err) - } - - availableRoutes, err := h.GetAdvertisedNodeRoutes(n, args[0]) - if err != nil { - fmt.Println(err) + ErrorOutput(err, fmt.Sprintf("Error getting machine id from flag: %s", err), output) return } - if strings.HasPrefix(o, "json") { - // TODO: Add enable/disabled information to this interface - JsonOutput(availableRoutes, err, o) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + client, conn := getHeadscaleGRPCClient(ctx) + defer conn.Close() + + request := &v1.GetMachineRouteRequest{ + MachineId: machineId, + } + + response, err := client.GetMachineRoute(ctx, request) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Cannot get nodes: %s", err), output) return } - d := h.RoutesToPtables(n, args[0], *availableRoutes) + if output != "" { + SuccessOutput(response.Routes, "", output) + return + } + + d := routesToPtables(response.Routes) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) + return + } err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() if err != nil { - log.Fatal(err) + ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output) + return } }, } var enableRouteCmd = &cobra.Command{ - Use: "enable node-name route", - Short: "Allows exposing a route declared by this node to the rest of the nodes", - Args: func(cmd *cobra.Command, args []string) error { - all, err := cmd.Flags().GetBool("all") - if err != nil { - log.Fatalf("Error getting namespace: %s", err) - } - - if all { - if len(args) < 1 { - return fmt.Errorf("Missing parameters") - } - return nil - } else { - if len(args) < 2 { - return fmt.Errorf("Missing parameters") - } - return nil - } - }, + Use: "enable", + Short: "Set the enabled routes for a given node", + Long: `This command will take a list of routes that will _replace_ +the current set of routes on a given node. +If you would like to disable a route, simply run the command again, but +omit the route you do not want to enable. + `, Run: func(cmd *cobra.Command, args []string) { - n, err := cmd.Flags().GetString("namespace") + output, _ := cmd.Flags().GetString("output") + machineId, err := cmd.Flags().GetUint64("identifier") if err != nil { - log.Fatalf("Error getting namespace: %s", err) + ErrorOutput(err, fmt.Sprintf("Error getting machine id from flag: %s", err), output) + return } - o, _ := cmd.Flags().GetString("output") - - all, err := cmd.Flags().GetBool("all") + routes, err := cmd.Flags().GetStringSlice("route") if err != nil { - log.Fatalf("Error getting namespace: %s", err) + ErrorOutput(err, fmt.Sprintf("Error getting routes from flag: %s", err), output) + return } - h, err := getHeadscaleApp() - if err != nil { - log.Fatalf("Error initializing: %s", err) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + client, conn := getHeadscaleGRPCClient(ctx) + defer conn.Close() + + request := &v1.EnableMachineRoutesRequest{ + MachineId: machineId, + Routes: routes, } - if all { - availableRoutes, err := h.GetAdvertisedNodeRoutes(n, args[0]) - if err != nil { - fmt.Println(err) - return - } + response, err := client.EnableMachineRoutes(ctx, request) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Cannot register machine: %s\n", err), output) + return + } - for _, availableRoute := range *availableRoutes { - err = h.EnableNodeRoute(n, args[0], availableRoute.String()) - if err != nil { - fmt.Println(err) - return - } + if output != "" { + SuccessOutput(response.Routes, "", output) + return + } - if strings.HasPrefix(o, "json") { - JsonOutput(availableRoute, err, o) - } else { - fmt.Printf("Enabled route %s\n", availableRoute) - } - } - } else { - err = h.EnableNodeRoute(n, args[0], args[1]) + d := routesToPtables(response.Routes) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) + return + } - if strings.HasPrefix(o, "json") { - JsonOutput(args[1], err, o) - return - } - - if err != nil { - fmt.Println(err) - return - } - fmt.Printf("Enabled route %s\n", args[1]) + err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output) + return } }, } + +// routesToPtables converts the list of routes to a nice table +func routesToPtables(routes *v1.Routes) pterm.TableData { + d := pterm.TableData{{"Route", "Enabled"}} + + for _, route := range routes.GetAdvertisedRoutes() { + enabled := isStringInSlice(routes.EnabledRoutes, route) + + d = append(d, []string{route, strconv.FormatBool(enabled)}) + } + return d +} + +func isStringInSlice(strs []string, s string) bool { + for _, s2 := range strs { + if s == s2 { + return true + } + } + + return false +}