diff --git a/cmd/headscale/cli/routes.go b/cmd/headscale/cli/routes.go index 105175a..3f0c208 100644 --- a/cmd/headscale/cli/routes.go +++ b/cmd/headscale/cli/routes.go @@ -13,27 +13,22 @@ import ( func init() { rootCmd.AddCommand(routesCmd) - listRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") - err := listRoutesCmd.MarkFlagRequired("identifier") - if err != nil { - log.Fatalf(err.Error()) - } routesCmd.AddCommand(listRoutesCmd) - enableRouteCmd.Flags(). - StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to enable") - enableRouteCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") - enableRouteCmd.Flags().BoolP("all", "a", false, "All routes from host") - - err = enableRouteCmd.MarkFlagRequired("identifier") + enableRouteCmd.Flags().Uint64P("route", "r", 0, "Route identifier (ID)") + err := enableRouteCmd.MarkFlagRequired("route") if err != nil { log.Fatalf(err.Error()) } - routesCmd.AddCommand(enableRouteCmd) - nodeCmd.AddCommand(routesCmd) + disableRouteCmd.Flags().Uint64P("route", "r", 0, "Route identifier (ID)") + err = disableRouteCmd.MarkFlagRequired("route") + if err != nil { + log.Fatalf(err.Error()) + } + routesCmd.AddCommand(disableRouteCmd) } var routesCmd = &cobra.Command{ @@ -44,7 +39,7 @@ var routesCmd = &cobra.Command{ var listRoutesCmd = &cobra.Command{ Use: "list", - Short: "List routes advertised and enabled by a given node", + Short: "List all routes", Aliases: []string{"ls", "show"}, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") @@ -64,28 +59,39 @@ var listRoutesCmd = &cobra.Command{ defer cancel() defer conn.Close() - request := &v1.GetMachineRouteRequest{ - MachineId: machineID, + var routes []*v1.Route + + if machineID == 0 { + response, err := client.GetRoutes(ctx, &v1.GetRoutesRequest{}) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), + output, + ) + + return + } + + routes = response.Routes + } else { + response, err := client.GetMachineRoutes(ctx, &v1.GetMachineRoutesRequest{ + MachineId: machineID, + }) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Cannot get routes for machine %d: %s", machineID, status.Convert(err).Message()), + output, + ) + + return + } + + routes = response.Routes } - response, err := client.GetMachineRoute(ctx, request) - if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), - output, - ) - - return - } - - if output != "" { - SuccessOutput(response.Routes, "", output) - - return - } - - tableData := routesToPtables(response.Routes) + tableData := routesToPtables(routes) if err != nil { ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) @@ -107,16 +113,12 @@ var listRoutesCmd = &cobra.Command{ var enableRouteCmd = &cobra.Command{ Use: "enable", - Short: "Set the enabled routes for a given node", - Long: `This command will take a list of routes that will _replace_ -the current set of routes on a given node. -If you would like to disable a route, simply run the command again, but -omit the route you do not want to enable. - `, + Short: "Set a route as enabled", + Long: `This command will make as enabled a given route.`, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - machineID, err := cmd.Flags().GetUint64("identifier") + routeID, err := cmd.Flags().GetUint64("route") if err != nil { ErrorOutput( err, @@ -131,52 +133,13 @@ omit the route you do not want to enable. defer cancel() defer conn.Close() - var routes []string - - isAll, _ := cmd.Flags().GetBool("all") - if isAll { - response, err := client.GetMachineRoute(ctx, &v1.GetMachineRouteRequest{ - MachineId: machineID, - }) - if err != nil { - ErrorOutput( - err, - fmt.Sprintf( - "Cannot get machine routes: %s\n", - status.Convert(err).Message(), - ), - output, - ) - - return - } - routes = response.GetRoutes().GetAdvertisedRoutes() - } else { - routes, err = cmd.Flags().GetStringSlice("route") - if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Error getting routes from flag: %s", err), - output, - ) - - return - } - } - - request := &v1.EnableMachineRoutesRequest{ - MachineId: machineID, - Routes: routes, - } - - response, err := client.EnableMachineRoutes(ctx, request) + response, err := client.EnableRoute(ctx, &v1.EnableRouteRequest{ + RouteId: routeID, + }) if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Cannot register machine: %s\n", - status.Convert(err).Message(), - ), + fmt.Sprintf("Cannot enable route %d: %s", routeID, status.Convert(err).Message()), output, ) @@ -184,50 +147,71 @@ omit the route you do not want to enable. } if output != "" { - SuccessOutput(response.Routes, "", output) + SuccessOutput(response, "", output) return } + }, +} - tableData := routesToPtables(response.Routes) - if err != nil { - ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) +var disableRouteCmd = &cobra.Command{ + Use: "disable", + Short: "Set as disabled a given route", + Long: `This command will make as disabled a given route.`, + Run: func(cmd *cobra.Command, args []string) { + output, _ := cmd.Flags().GetString("output") - return - } - - err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() + routeID, err := cmd.Flags().GetUint64("route") if err != nil { ErrorOutput( err, - fmt.Sprintf("Failed to render pterm table: %s", err), + fmt.Sprintf("Error getting machine id from flag: %s", err), output, ) return } + + ctx, client, conn, cancel := getHeadscaleCLIClient() + defer cancel() + defer conn.Close() + + response, err := client.DisableRoute(ctx, &v1.DisableRouteRequest{ + RouteId: routeID, + }) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Cannot enable route %d: %s", routeID, status.Convert(err).Message()), + output, + ) + + return + } + + if output != "" { + SuccessOutput(response, "", output) + + return + } }, } // routesToPtables converts the list of routes to a nice table. -func routesToPtables(routes *v1.Routes) pterm.TableData { - tableData := pterm.TableData{{"Route", "Enabled"}} +func routesToPtables(routes []*v1.Route) pterm.TableData { + tableData := pterm.TableData{{"ID", "Machine", "Prefix", "Advertised", "Enabled", "Primary"}} - for _, route := range routes.GetAdvertisedRoutes() { - enabled := isStringInSlice(routes.EnabledRoutes, route) - - tableData = append(tableData, []string{route, strconv.FormatBool(enabled)}) + for _, route := range routes { + tableData = append(tableData, + []string{ + strconv.FormatUint(route.Id, 10), + route.Machine.GivenName, + route.Prefix, + strconv.FormatBool(route.Advertised), + strconv.FormatBool(route.Enabled), + strconv.FormatBool(route.IsPrimary), + }) } return tableData } - -func isStringInSlice(strs []string, s string) bool { - for _, s2 := range strs { - if s == s2 { - return true - } - } - - return false -} diff --git a/integration_cli_test.go b/integration_cli_test.go index 5a5cb0c..62df6b3 100644 --- a/integration_cli_test.go +++ b/integration_cli_test.go @@ -7,6 +7,7 @@ import ( "log" "net/http" "os" + "strconv" "testing" "time" @@ -1305,24 +1306,22 @@ func (s *IntegrationCLITestSuite) TestRouteCommand() { "list", "--output", "json", - "--identifier", - "0", }, []string{}, ) assert.Nil(s.T(), err) - var listAll v1.Routes - err = json.Unmarshal([]byte(listAllResult), &listAll) + var routes []v1.Route + err = json.Unmarshal([]byte(listAllResult), &routes) assert.Nil(s.T(), err) - assert.Len(s.T(), listAll.AdvertisedRoutes, 2) - assert.Contains(s.T(), listAll.AdvertisedRoutes, "10.0.0.0/8") - assert.Contains(s.T(), listAll.AdvertisedRoutes, "192.168.1.0/24") + assert.Len(s.T(), routes, 2) + assert.Equal(s.T(), routes[0].Enabled, false) + assert.Equal(s.T(), routes[1].Enabled, false) - assert.Empty(s.T(), listAll.EnabledRoutes) + routeIDToEnable := routes[1].Id - enableTwoRoutesResult, _, err := ExecuteCommand( + _, _, err = ExecuteCommand( &s.headscale, []string{ "headscale", @@ -1330,110 +1329,86 @@ func (s *IntegrationCLITestSuite) TestRouteCommand() { "enable", "--output", "json", - "--identifier", - "0", "--route", - "10.0.0.0/8", - "--route", - "192.168.1.0/24", + strconv.FormatUint(routeIDToEnable, 10), }, []string{}, ) assert.Nil(s.T(), err) - var enableTwoRoutes v1.Routes - err = json.Unmarshal([]byte(enableTwoRoutesResult), &enableTwoRoutes) + listAllResult, _, err = ExecuteCommand( + &s.headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + []string{}, + ) assert.Nil(s.T(), err) - assert.Len(s.T(), enableTwoRoutes.AdvertisedRoutes, 2) - assert.Contains(s.T(), enableTwoRoutes.AdvertisedRoutes, "10.0.0.0/8") - assert.Contains(s.T(), enableTwoRoutes.AdvertisedRoutes, "192.168.1.0/24") + assert.Nil(s.T(), err) - assert.Len(s.T(), enableTwoRoutes.EnabledRoutes, 2) - assert.Contains(s.T(), enableTwoRoutes.EnabledRoutes, "10.0.0.0/8") - assert.Contains(s.T(), enableTwoRoutes.EnabledRoutes, "192.168.1.0/24") + err = json.Unmarshal([]byte(listAllResult), &routes) + assert.Nil(s.T(), err) + + assert.Len(s.T(), routes, 2) + + for _, route := range routes { + if route.Id == routeIDToEnable { + assert.Equal(s.T(), route.Enabled, true) + assert.Equal(s.T(), route.IsPrimary, true) + } else { + assert.Equal(s.T(), route.Enabled, false) + } + } // Enable only one route, effectively disabling one of the routes - enableOneRouteResult, _, err := ExecuteCommand( + _, _, err = ExecuteCommand( &s.headscale, []string{ "headscale", "routes", - "enable", + "disable", "--output", "json", - "--identifier", - "0", "--route", - "10.0.0.0/8", + strconv.FormatUint(routeIDToEnable, 10), }, []string{}, ) assert.Nil(s.T(), err) - var enableOneRoute v1.Routes - err = json.Unmarshal([]byte(enableOneRouteResult), &enableOneRoute) - assert.Nil(s.T(), err) - - assert.Len(s.T(), enableOneRoute.AdvertisedRoutes, 2) - assert.Contains(s.T(), enableOneRoute.AdvertisedRoutes, "10.0.0.0/8") - assert.Contains(s.T(), enableOneRoute.AdvertisedRoutes, "192.168.1.0/24") - - assert.Len(s.T(), enableOneRoute.EnabledRoutes, 1) - assert.Contains(s.T(), enableOneRoute.EnabledRoutes, "10.0.0.0/8") - - // Enable only one route, effectively disabling one of the routes - failEnableNonAdvertisedRoute, _, err := ExecuteCommand( + listAllResult, _, err = ExecuteCommand( &s.headscale, []string{ "headscale", "routes", - "enable", + "list", "--output", "json", - "--identifier", - "0", - "--route", - "11.0.0.0/8", }, []string{}, ) assert.Nil(s.T(), err) - assert.Contains( - s.T(), - string(failEnableNonAdvertisedRoute), - "route (route-machine) is not available on node", - ) - - // Enable all routes on host - enableAllRouteResult, _, err := ExecuteCommand( - &s.headscale, - []string{ - "headscale", - "routes", - "enable", - "--output", - "json", - "--identifier", - "0", - "--all", - }, - []string{}, - ) assert.Nil(s.T(), err) - var enableAllRoute v1.Routes - err = json.Unmarshal([]byte(enableAllRouteResult), &enableAllRoute) + err = json.Unmarshal([]byte(listAllResult), &routes) assert.Nil(s.T(), err) - assert.Len(s.T(), enableAllRoute.AdvertisedRoutes, 2) - assert.Contains(s.T(), enableAllRoute.AdvertisedRoutes, "10.0.0.0/8") - assert.Contains(s.T(), enableAllRoute.AdvertisedRoutes, "192.168.1.0/24") + assert.Len(s.T(), routes, 2) - assert.Len(s.T(), enableAllRoute.EnabledRoutes, 2) - assert.Contains(s.T(), enableAllRoute.EnabledRoutes, "10.0.0.0/8") - assert.Contains(s.T(), enableAllRoute.EnabledRoutes, "192.168.1.0/24") + for _, route := range routes { + if route.Id == routeIDToEnable { + assert.Equal(s.T(), route.Enabled, false) + assert.Equal(s.T(), route.IsPrimary, false) + } else { + assert.Equal(s.T(), route.Enabled, false) + } + } } func (s *IntegrationCLITestSuite) TestApiKeyCommand() {