diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 9498bc6..6d77877 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -2,6 +2,7 @@ package db import ( "errors" + "fmt" "net/netip" "github.com/juanfont/headscale/hscontrol/policy" @@ -252,20 +253,20 @@ func DeleteRoute( func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected types.NodeConnectedMap) ([]types.NodeID, error) { routes, err := GetNodeRoutes(tx, node) if err != nil { - return nil, err + return nil, fmt.Errorf("getting node routes: %w", err) } var changed []types.NodeID for i := range routes { if err := tx.Unscoped().Delete(&routes[i]).Error; err != nil { - return nil, err + return nil, fmt.Errorf("deleting route(%d): %w", &routes[i].ID, err) } // TODO(kradalby): This is a bit too aggressive, we could probably // figure out which routes needs to be failed over rather than all. chn, err := failoverRouteTx(tx, isConnected, &routes[i]) if err != nil { - return changed, err + return changed, fmt.Errorf("failing over route after delete: %w", err) } if chn != nil { @@ -410,10 +411,8 @@ func FailoverRouteIfAvailable( isConnected types.NodeConnectedMap, node *types.Node, ) (*types.StateUpdate, error) { - log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Msgf("ROUTE DEBUG ENTERED FAILOVER") nodeRoutes, err := GetNodeRoutes(tx, node) if err != nil { - log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Interface("nodeRoutes", nodeRoutes).Msgf("ROUTE DEBUG NO ROUTES") return nil, nil } @@ -421,34 +420,31 @@ func FailoverRouteIfAvailable( for _, nodeRoute := range nodeRoutes { routes, err := getRoutesByPrefix(tx, netip.Prefix(nodeRoute.Prefix)) if err != nil { - return nil, err + return nil, fmt.Errorf("getting routes by prefix: %w", err) } for _, route := range routes { if route.IsPrimary { // if we have a primary route, and the node is connected // nothing needs to be done. - log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Uint64("route.node.id", route.Node.ID.Uint64()).Msgf("ROUTE DEBUG CHECKING IF ONLINE") if isConnected[route.Node.ID] { - log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Uint64("route.node.id", route.Node.ID.Uint64()).Msgf("ROUTE DEBUG IS ONLINE") return nil, nil } - log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Uint64("route.node.id", route.Node.ID.Uint64()).Msgf("ROUTE DEBUG NOT ONLINE, FAILING OVER") // if not, we need to failover the route - changedIDs, err := failoverRouteTx(tx, isConnected, &route) - if err != nil { - return nil, err - } + failover := failoverRoute(isConnected, &route, routes) + if failover != nil { + failover.save(tx) + if err != nil { + return nil, fmt.Errorf("saving failover routes: %w", err) + } - if changedIDs != nil { - changedNodes = append(changedNodes, changedIDs...) + changedNodes = append(changedNodes, failover.old.Node.ID, failover.new.Node.ID) } } } } - log.Debug().Caller().Uint64("node.id", node.ID.Uint64()).Interface("changedNodes", changedNodes).Msgf("ROUTE DEBUG") if len(changedNodes) != 0 { return &types.StateUpdate{ Type: types.StatePeerChanged, @@ -490,7 +486,7 @@ func failoverRouteTx( routes, err := getRoutesByPrefix(tx, netip.Prefix(r.Prefix)) if err != nil { - return nil, err + return nil, fmt.Errorf("getting routes by prefix: %w", err) } fo := failoverRoute(isConnected, r, routes) @@ -498,18 +494,9 @@ func failoverRouteTx( return nil, nil } - err = tx.Save(fo.old).Error + err = fo.save(tx) if err != nil { - log.Error().Err(err).Msg("disabling old primary route") - - return nil, err - } - - err = tx.Save(fo.new).Error - if err != nil { - log.Error().Err(err).Msg("saving new primary route") - - return nil, err + return nil, fmt.Errorf("saving failover route: %w", err) } log.Trace(). @@ -525,6 +512,20 @@ type failover struct { new *types.Route } +func (f *failover) save(tx *gorm.DB) error { + err := tx.Save(f.old).Error + if err != nil { + return fmt.Errorf("saving old primary: %w", err) + } + + err = tx.Save(f.new).Error + if err != nil { + return fmt.Errorf("saving new primary: %w", err) + } + + return nil +} + func failoverRoute( isConnected types.NodeConnectedMap, routeToReplace *types.Route, @@ -603,13 +604,7 @@ func EnableAutoApprovedRoutes( routes, err := GetNodeAdvertisedRoutes(tx, node) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - log.Error(). - Caller(). - Err(err). - Str("node", node.Hostname). - Msg("Could not get advertised routes for node") - - return err + return fmt.Errorf("getting advertised routes for node(%s %d): %w", node.Hostname, node.ID, err) } log.Trace().Interface("routes", routes).Msg("routes for autoapproving") @@ -625,12 +620,7 @@ func EnableAutoApprovedRoutes( netip.Prefix(advertisedRoute.Prefix), ) if err != nil { - log.Err(err). - Str("advertisedRoute", advertisedRoute.String()). - Uint64("nodeId", node.ID.Uint64()). - Msg("Failed to resolve autoApprovers for advertised route") - - return err + return fmt.Errorf("failed to resolve autoApprovers for route(%d) for node(%s %d): %w", advertisedRoute.ID, node.Hostname, node.ID, err) } log.Trace(). @@ -647,11 +637,7 @@ func EnableAutoApprovedRoutes( // TODO(kradalby): figure out how to get this to depend on less stuff approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias) if err != nil { - log.Err(err). - Str("alias", approvedAlias). - Msg("Failed to expand alias when processing autoApprovers policy") - - return err + return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err) } // approvedIPs should contain all of node's IPs if it matches the rule, so check for first @@ -665,12 +651,7 @@ func EnableAutoApprovedRoutes( for _, approvedRoute := range approvedRoutes { _, err := EnableRoute(tx, uint64(approvedRoute.ID)) if err != nil { - log.Err(err). - Str("approvedRoute", approvedRoute.String()). - Uint64("nodeId", node.ID.Uint64()). - Msg("Failed to enable approved route") - - return err + return fmt.Errorf("enabling approved route(%d): %w", approvedRoute.ID, err) } }