Merge pull request #92 from kradalby/enhance-route-command

Enhance route command with ptables and multiple routes
This commit is contained in:
Juan Font 2021-08-22 12:48:30 +02:00 committed by GitHub
commit ca8d814918
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 256 additions and 46 deletions

View file

@ -5,6 +5,7 @@ import (
"log" "log"
"strings" "strings"
"github.com/pterm/pterm"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -15,6 +16,9 @@ func init() {
if err != nil { if err != nil {
log.Fatalf(err.Error()) log.Fatalf(err.Error())
} }
enableRouteCmd.Flags().BoolP("all", "a", false, "Enable all routes advertised by the node")
routesCmd.AddCommand(listRoutesCmd) routesCmd.AddCommand(listRoutesCmd)
routesCmd.AddCommand(enableRouteCmd) routesCmd.AddCommand(enableRouteCmd)
} }
@ -44,19 +48,25 @@ var listRoutesCmd = &cobra.Command{
if err != nil { if err != nil {
log.Fatalf("Error initializing: %s", err) log.Fatalf("Error initializing: %s", err)
} }
routes, err := h.GetNodeRoutes(n, args[0])
if strings.HasPrefix(o, "json") {
JsonOutput(routes, err, o)
return
}
availableRoutes, err := h.GetAdvertisedNodeRoutes(n, args[0])
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println(routes) if strings.HasPrefix(o, "json") {
// TODO: Add enable/disabled information to this interface
JsonOutput(availableRoutes, err, o)
return
}
d := h.RoutesToPtables(n, args[0], *availableRoutes)
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
if err != nil {
log.Fatal(err)
}
}, },
} }
@ -64,25 +74,66 @@ var enableRouteCmd = &cobra.Command{
Use: "enable node-name route", Use: "enable node-name route",
Short: "Allows exposing a route declared by this node to the rest of the nodes", Short: "Allows exposing a route declared by this node to the rest of the nodes",
Args: func(cmd *cobra.Command, args []string) error { Args: func(cmd *cobra.Command, args []string) error {
all, err := cmd.Flags().GetBool("all")
if err != nil {
log.Fatalf("Error getting namespace: %s", err)
}
if all {
if len(args) < 1 {
return fmt.Errorf("Missing parameters")
}
return nil
} else {
if len(args) < 2 { if len(args) < 2 {
return fmt.Errorf("Missing parameters") return fmt.Errorf("Missing parameters")
} }
return nil return nil
}
}, },
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
n, err := cmd.Flags().GetString("namespace") n, err := cmd.Flags().GetString("namespace")
if err != nil { if err != nil {
log.Fatalf("Error getting namespace: %s", err) log.Fatalf("Error getting namespace: %s", err)
} }
o, _ := cmd.Flags().GetString("output") o, _ := cmd.Flags().GetString("output")
all, err := cmd.Flags().GetBool("all")
if err != nil {
log.Fatalf("Error getting namespace: %s", err)
}
h, err := getHeadscaleApp() h, err := getHeadscaleApp()
if err != nil { if err != nil {
log.Fatalf("Error initializing: %s", err) log.Fatalf("Error initializing: %s", err)
} }
route, err := h.EnableNodeRoute(n, args[0], args[1])
if all {
availableRoutes, err := h.GetAdvertisedNodeRoutes(n, args[0])
if err != nil {
fmt.Println(err)
return
}
for _, availableRoute := range *availableRoutes {
err = h.EnableNodeRoute(n, args[0], availableRoute.String())
if err != nil {
fmt.Println(err)
return
}
if strings.HasPrefix(o, "json") { if strings.HasPrefix(o, "json") {
JsonOutput(route, err, o) JsonOutput(availableRoute, err, o)
} else {
fmt.Printf("Enabled route %s\n", availableRoute)
}
}
} else {
err = h.EnableNodeRoute(n, args[0], args[1])
if strings.HasPrefix(o, "json") {
JsonOutput(args[1], err, o)
return return
} }
@ -90,6 +141,7 @@ var enableRouteCmd = &cobra.Command{
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Printf("Enabled route %s\n", route) fmt.Printf("Enabled route %s\n", args[1])
}
}, },
} }

113
routes.go
View file

@ -2,55 +2,140 @@ package headscale
import ( import (
"encoding/json" "encoding/json"
"errors" "fmt"
"strconv"
"github.com/pterm/pterm"
"gorm.io/datatypes" "gorm.io/datatypes"
"inet.af/netaddr" "inet.af/netaddr"
) )
// GetNodeRoutes returns the subnet routes advertised by a node (identified by // GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by
// namespace and node name) // namespace and node name)
func (h *Headscale) GetNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) { func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
m, err := h.GetMachine(namespace, nodeName) m, err := h.GetMachine(namespace, nodeName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
hi, err := m.GetHostInfo() hostInfo, err := m.GetHostInfo()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &hi.RoutableIPs, nil return &hostInfo.RoutableIPs, nil
} }
// EnableNodeRoute enables a subnet route advertised by a node (identified by // GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by
// namespace and node name) // namespace and node name)
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) (*netaddr.IPPrefix, error) { func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]netaddr.IPPrefix, error) {
m, err := h.GetMachine(namespace, nodeName) m, err := h.GetMachine(namespace, nodeName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
hi, err := m.GetHostInfo()
data, err := m.EnabledRoutes.MarshalJSON()
if err != nil { if err != nil {
return nil, err return nil, err
} }
routesStr := []string{}
err = json.Unmarshal(data, &routesStr)
if err != nil {
return nil, err
}
routes := make([]netaddr.IPPrefix, len(routesStr))
for index, routeStr := range routesStr {
route, err := netaddr.ParseIPPrefix(routeStr) route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
routes[index] = route
}
return routes, nil
}
func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeStr string) bool {
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return false
}
enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName)
if err != nil {
return false
}
for _, enabledRoute := range enabledRoutes {
if route == enabledRoute {
return true
}
}
return false
}
// EnableNodeRoute enables a subnet route advertised by a node (identified by
// namespace and node name)
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error {
m, err := h.GetMachine(namespace, nodeName)
if err != nil {
return err
}
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return err
}
availableRoutes, err := h.GetAdvertisedNodeRoutes(namespace, nodeName)
if err != nil {
return err
}
enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName)
if err != nil {
return err
}
available := false
for _, availableRoute := range *availableRoutes {
// If the route is available, and not yet enabled, add it to the new routing table
if route == availableRoute {
available = true
if !h.IsNodeRouteEnabled(namespace, nodeName, routeStr) {
enabledRoutes = append(enabledRoutes, route)
}
}
}
if !available {
return fmt.Errorf("route (%s) is not available on node %s", nodeName, routeStr)
}
routes, err := json.Marshal(enabledRoutes)
if err != nil {
return err
}
for _, rIP := range hi.RoutableIPs {
if rIP == route {
routes, _ := json.Marshal([]string{routeStr}) // TODO: only one for the time being, so overwriting the rest
m.EnabledRoutes = datatypes.JSON(routes) m.EnabledRoutes = datatypes.JSON(routes)
h.db.Save(&m) h.db.Save(&m)
err = h.RequestMapUpdates(m.NamespaceID) err = h.RequestMapUpdates(m.NamespaceID)
if err != nil { if err != nil {
return nil, err return err
} }
return &rIP, nil
return nil
} }
func (h *Headscale) RoutesToPtables(namespace string, nodeName string, availableRoutes []netaddr.IPPrefix) pterm.TableData {
d := pterm.TableData{{"Route", "Enabled"}}
for _, route := range availableRoutes {
enabled := h.IsNodeRouteEnabled(namespace, nodeName, route.String())
d = append(d, []string{route.String(), strconv.FormatBool(enabled)})
} }
return nil, errors.New("could not find routable range") return d
} }

View file

@ -16,7 +16,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine") _, err = h.GetMachine("test", "test_get_route_machine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
route, err := netaddr.ParseIPPrefix("10.0.0.0/24") route, err := netaddr.ParseIPPrefix("10.0.0.0/24")
@ -33,7 +33,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
MachineKey: "foo", MachineKey: "foo",
NodeKey: "bar", NodeKey: "bar",
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "test_get_route_machine",
NamespaceID: n.ID, NamespaceID: n.ID,
Registered: true, Registered: true,
RegisterMethod: "authKey", RegisterMethod: "authKey",
@ -42,14 +42,87 @@ func (s *Suite) TestGetRoutes(c *check.C) {
} }
h.db.Save(&m) h.db.Save(&m)
r, err := h.GetNodeRoutes("test", "testmachine") r, err := h.GetAdvertisedNodeRoutes("test", "test_get_route_machine")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(*r), check.Equals, 1) c.Assert(len(*r), check.Equals, 1)
_, err = h.EnableNodeRoute("test", "testmachine", "192.168.0.0/24") err = h.EnableNodeRoute("test", "test_get_route_machine", "192.168.0.0/24")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
_, err = h.EnableNodeRoute("test", "testmachine", "10.0.0.0/24") err = h.EnableNodeRoute("test", "test_get_route_machine", "10.0.0.0/24")
c.Assert(err, check.IsNil)
}
func (s *Suite) TestGetEnableRoutes(c *check.C) {
n, err := h.CreateNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "test_enable_route_machine")
c.Assert(err, check.NotNil)
route, err := netaddr.ParseIPPrefix(
"10.0.0.0/24",
)
c.Assert(err, check.IsNil)
route2, err := netaddr.ParseIPPrefix(
"150.0.10.0/25",
)
c.Assert(err, check.IsNil)
hi := tailcfg.Hostinfo{
RoutableIPs: []netaddr.IPPrefix{route, route2},
}
hostinfo, err := json.Marshal(hi)
c.Assert(err, check.IsNil)
m := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "test_enable_route_machine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
HostInfo: datatypes.JSON(hostinfo),
}
h.db.Save(&m)
availableRoutes, err := h.GetAdvertisedNodeRoutes("test", "test_enable_route_machine")
c.Assert(err, check.IsNil)
c.Assert(len(*availableRoutes), check.Equals, 2)
enabledRoutes, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes), check.Equals, 0)
err = h.EnableNodeRoute("test", "test_enable_route_machine", "192.168.0.0/24")
c.Assert(err, check.NotNil)
err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
c.Assert(err, check.IsNil)
enabledRoutes1, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 1)
// Adding it twice will just let it pass through
err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
c.Assert(err, check.IsNil)
enabledRoutes2, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes2), check.Equals, 1)
err = h.EnableNodeRoute("test", "test_enable_route_machine", "150.0.10.0/25")
c.Assert(err, check.IsNil)
enabledRoutes3, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes3), check.Equals, 2)
} }