From f7f472ae07000d5a3d255c1421840c0685ba640f Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Fri, 26 May 2023 11:26:34 +0100 Subject: [PATCH] introduce mapper package The mapper package contains functions related to creating and marshalling reponses to machines. Signed-off-by: Kristoffer Dalby --- flake.nix | 313 +++++++++++---------- hscontrol/api_common.go | 115 -------- hscontrol/app.go | 4 +- hscontrol/db/machine.go | 71 +---- hscontrol/db/machine_test.go | 4 +- hscontrol/db/users.go | 31 --- hscontrol/db/users_test.go | 148 ---------- hscontrol/derp_server.go | 4 +- hscontrol/dns.go | 69 ----- hscontrol/dns_test.go | 295 -------------------- hscontrol/mapper/mapper.go | 418 +++++++++++++++++++++++++++++ hscontrol/mapper/mapper_test.go | 131 +++++++++ hscontrol/mapper/suite_test.go | 15 ++ hscontrol/protocol_common.go | 19 +- hscontrol/protocol_common_poll.go | 46 +++- hscontrol/protocol_common_utils.go | 156 ----------- hscontrol/protocol_legacy.go | 2 +- hscontrol/protocol_legacy_poll.go | 2 +- 18 files changed, 780 insertions(+), 1063 deletions(-) delete mode 100644 hscontrol/api_common.go create mode 100644 hscontrol/mapper/mapper.go create mode 100644 hscontrol/mapper/mapper_test.go create mode 100644 hscontrol/mapper/suite_test.go diff --git a/flake.nix b/flake.nix index 0363238..929e62c 100644 --- a/flake.nix +++ b/flake.nix @@ -6,177 +6,172 @@ flake-utils.url = "github:numtide/flake-utils"; }; - outputs = - { self - , nixpkgs - , flake-utils - , ... - }: - let - headscaleVersion = - if (self ? shortRev) - then self.shortRev - else "dev"; - in + outputs = { + self, + nixpkgs, + flake-utils, + ... + }: let + headscaleVersion = + if (self ? shortRev) + then self.shortRev + else "dev"; + in { - overlay = _: prev: - let - pkgs = nixpkgs.legacyPackages.${prev.system}; - in - rec { - headscale = pkgs.buildGo120Module rec { - pname = "headscale"; - version = headscaleVersion; - src = pkgs.lib.cleanSource self; + overlay = _: prev: let + pkgs = nixpkgs.legacyPackages.${prev.system}; + in rec { + headscale = pkgs.buildGo120Module rec { + pname = "headscale"; + version = headscaleVersion; + src = pkgs.lib.cleanSource self; - tags = [ "ts2019" ]; + tags = ["ts2019"]; - # Only run unit tests when testing a build - checkFlags = [ "-short" ]; + # Only run unit tests when testing a build + checkFlags = ["-short"]; - # When updating go.mod or go.sum, a new sha will need to be calculated, - # update this if you have a mismatch after doing a change to thos files. - vendorSha256 = "sha256-IOkbbFtE6+tNKnglE/8ZuNxhPSnloqM2sLgTvagMmnc="; + # When updating go.mod or go.sum, a new sha will need to be calculated, + # update this if you have a mismatch after doing a change to thos files. + vendorSha256 = "sha256-ui0V7a8bAAm5B7zfN9g2pWTyMpudtm10RgYWQwC6kcA="; - ldflags = [ "-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}" ]; - }; - - golines = pkgs.buildGoModule rec { - pname = "golines"; - version = "0.11.0"; - - src = pkgs.fetchFromGitHub { - owner = "segmentio"; - repo = "golines"; - rev = "v${version}"; - sha256 = "sha256-2K9KAg8iSubiTbujyFGN3yggrL+EDyeUCs9OOta/19A="; - }; - - vendorSha256 = "sha256-rxYuzn4ezAxaeDhxd8qdOzt+CKYIh03A9zKNdzILq18="; - - nativeBuildInputs = [ pkgs.installShellFiles ]; - }; - - golangci-lint = prev.golangci-lint.override { - # Override https://github.com/NixOS/nixpkgs/pull/166801 which changed this - # to buildGo118Module because it does not build on Darwin. - inherit (prev) buildGoModule; - }; - - protoc-gen-grpc-gateway = pkgs.buildGoModule rec { - pname = "grpc-gateway"; - version = "2.14.0"; - - src = pkgs.fetchFromGitHub { - owner = "grpc-ecosystem"; - repo = "grpc-gateway"; - rev = "v${version}"; - sha256 = "sha256-lnNdsDCpeSHtl2lC1IhUw11t3cnGF+37qSM7HDvKLls="; - }; - - vendorSha256 = "sha256-dGdnDuRbwg8fU7uB5GaHEWa/zI3w06onqjturvooJQA="; - - nativeBuildInputs = [ pkgs.installShellFiles ]; - - subPackages = [ "protoc-gen-grpc-gateway" "protoc-gen-openapiv2" ]; - }; + ldflags = ["-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}"]; }; + + golines = pkgs.buildGoModule rec { + pname = "golines"; + version = "0.11.0"; + + src = pkgs.fetchFromGitHub { + owner = "segmentio"; + repo = "golines"; + rev = "v${version}"; + sha256 = "sha256-2K9KAg8iSubiTbujyFGN3yggrL+EDyeUCs9OOta/19A="; + }; + + vendorSha256 = "sha256-rxYuzn4ezAxaeDhxd8qdOzt+CKYIh03A9zKNdzILq18="; + + nativeBuildInputs = [pkgs.installShellFiles]; + }; + + golangci-lint = prev.golangci-lint.override { + # Override https://github.com/NixOS/nixpkgs/pull/166801 which changed this + # to buildGo118Module because it does not build on Darwin. + inherit (prev) buildGoModule; + }; + + protoc-gen-grpc-gateway = pkgs.buildGoModule rec { + pname = "grpc-gateway"; + version = "2.14.0"; + + src = pkgs.fetchFromGitHub { + owner = "grpc-ecosystem"; + repo = "grpc-gateway"; + rev = "v${version}"; + sha256 = "sha256-lnNdsDCpeSHtl2lC1IhUw11t3cnGF+37qSM7HDvKLls="; + }; + + vendorSha256 = "sha256-dGdnDuRbwg8fU7uB5GaHEWa/zI3w06onqjturvooJQA="; + + nativeBuildInputs = [pkgs.installShellFiles]; + + subPackages = ["protoc-gen-grpc-gateway" "protoc-gen-openapiv2"]; + }; + }; } // flake-utils.lib.eachDefaultSystem - (system: - let - pkgs = import nixpkgs { - overlays = [ self.overlay ]; - inherit system; - }; - buildDeps = with pkgs; [ git go_1_20 gnumake ]; - devDeps = with pkgs; - buildDeps - ++ [ - golangci-lint - golines - nodePackages.prettier - goreleaser - nfpm - gotestsum - gotests + (system: let + pkgs = import nixpkgs { + overlays = [self.overlay]; + inherit system; + }; + buildDeps = with pkgs; [git go_1_20 gnumake]; + devDeps = with pkgs; + buildDeps + ++ [ + golangci-lint + golines + nodePackages.prettier + goreleaser + nfpm + gotestsum + gotests - # 'dot' is needed for pprof graphs - # go tool pprof -http=: - graphviz + # 'dot' is needed for pprof graphs + # go tool pprof -http=: + graphviz - # Protobuf dependencies - protobuf - protoc-gen-go - protoc-gen-go-grpc - protoc-gen-grpc-gateway - buf - clang-tools # clang-format - ]; + # Protobuf dependencies + protobuf + protoc-gen-go + protoc-gen-go-grpc + protoc-gen-grpc-gateway + buf + clang-tools # clang-format + ]; - # Add entry to build a docker image with headscale - # caveat: only works on Linux - # - # Usage: - # nix build .#headscale-docker - # docker load < result - headscale-docker = pkgs.dockerTools.buildLayeredImage { - name = "headscale"; - tag = headscaleVersion; - contents = [ pkgs.headscale ]; - config.Entrypoint = [ (pkgs.headscale + "/bin/headscale") ]; - }; - in - rec { - # `nix develop` - devShell = pkgs.mkShell { - buildInputs = devDeps; + # Add entry to build a docker image with headscale + # caveat: only works on Linux + # + # Usage: + # nix build .#headscale-docker + # docker load < result + headscale-docker = pkgs.dockerTools.buildLayeredImage { + name = "headscale"; + tag = headscaleVersion; + contents = [pkgs.headscale]; + config.Entrypoint = [(pkgs.headscale + "/bin/headscale")]; + }; + in rec { + # `nix develop` + devShell = pkgs.mkShell { + buildInputs = devDeps; - shellHook = '' - export GOFLAGS=-tags="ts2019" - export PATH="$PWD/result/bin:$PATH" + shellHook = '' + export GOFLAGS=-tags="ts2019" + export PATH="$PWD/result/bin:$PATH" - mkdir -p ./ignored - export HEADSCALE_PRIVATE_KEY_PATH="./ignored/private.key" - export HEADSCALE_NOISE_PRIVATE_KEY_PATH="./ignored/noise_private.key" - export HEADSCALE_DB_PATH="./ignored/db.sqlite" - export HEADSCALE_TLS_LETSENCRYPT_CACHE_DIR="./ignored/cache" - export HEADSCALE_UNIX_SOCKET="./ignored/headscale.sock" + mkdir -p ./ignored + export HEADSCALE_PRIVATE_KEY_PATH="./ignored/private.key" + export HEADSCALE_NOISE_PRIVATE_KEY_PATH="./ignored/noise_private.key" + export HEADSCALE_DB_PATH="./ignored/db.sqlite" + export HEADSCALE_TLS_LETSENCRYPT_CACHE_DIR="./ignored/cache" + export HEADSCALE_UNIX_SOCKET="./ignored/headscale.sock" + ''; + }; + + # `nix build` + packages = with pkgs; { + inherit headscale; + inherit headscale-docker; + }; + defaultPackage = pkgs.headscale; + + # `nix run` + apps.headscale = flake-utils.lib.mkApp { + drv = packages.headscale; + }; + apps.default = apps.headscale; + + checks = { + format = + pkgs.runCommand "check-format" + { + buildInputs = with pkgs; [ + gnumake + nixpkgs-fmt + golangci-lint + nodePackages.prettier + golines + clang-tools + ]; + } '' + ${pkgs.nixpkgs-fmt}/bin/nixpkgs-fmt ${./.} + ${pkgs.golangci-lint}/bin/golangci-lint run --fix --timeout 10m + ${pkgs.nodePackages.prettier}/bin/prettier --write '**/**.{ts,js,md,yaml,yml,sass,css,scss,html}' + ${pkgs.golines}/bin/golines --max-len=88 --base-formatter=gofumpt -w ${./.} + ${pkgs.clang-tools}/bin/clang-format -style="{BasedOnStyle: Google, IndentWidth: 4, AlignConsecutiveDeclarations: true, AlignConsecutiveAssignments: true, ColumnLimit: 0}" -i ${./.} ''; - }; - - # `nix build` - packages = with pkgs; { - inherit headscale; - inherit headscale-docker; - }; - defaultPackage = pkgs.headscale; - - # `nix run` - apps.headscale = flake-utils.lib.mkApp { - drv = packages.headscale; - }; - apps.default = apps.headscale; - - checks = { - format = - pkgs.runCommand "check-format" - { - buildInputs = with pkgs; [ - gnumake - nixpkgs-fmt - golangci-lint - nodePackages.prettier - golines - clang-tools - ]; - } '' - ${pkgs.nixpkgs-fmt}/bin/nixpkgs-fmt ${./.} - ${pkgs.golangci-lint}/bin/golangci-lint run --fix --timeout 10m - ${pkgs.nodePackages.prettier}/bin/prettier --write '**/**.{ts,js,md,yaml,yml,sass,css,scss,html}' - ${pkgs.golines}/bin/golines --max-len=88 --base-formatter=gofumpt -w ${./.} - ${pkgs.clang-tools}/bin/clang-format -style="{BasedOnStyle: Google, IndentWidth: 4, AlignConsecutiveDeclarations: true, AlignConsecutiveAssignments: true, ColumnLimit: 0}" -i ${./.} - ''; - }; - }); + }; + }); } diff --git a/hscontrol/api_common.go b/hscontrol/api_common.go deleted file mode 100644 index 4d40c1d..0000000 --- a/hscontrol/api_common.go +++ /dev/null @@ -1,115 +0,0 @@ -package hscontrol - -import ( - "time" - - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/rs/zerolog/log" - "tailscale.com/tailcfg" -) - -func (h *Headscale) generateMapResponse( - mapRequest tailcfg.MapRequest, - machine *types.Machine, -) (*tailcfg.MapResponse, error) { - log.Trace(). - Str("func", "generateMapResponse"). - Str("machine", mapRequest.Hostinfo.Hostname). - Msg("Creating Map response") - node, err := h.db.TailNode(*machine, h.ACLPolicy, h.cfg.DNSConfig) - if err != nil { - log.Error(). - Caller(). - Str("func", "generateMapResponse"). - Err(err). - Msg("Cannot convert to node") - - return nil, err - } - - peers, err := h.db.GetValidPeers(h.aclRules, machine) - if err != nil { - log.Error(). - Caller(). - Str("func", "generateMapResponse"). - Err(err). - Msg("Cannot fetch peers") - - return nil, err - } - - profiles := h.db.GetMapResponseUserProfiles(*machine, peers) - - nodePeers, err := h.db.TailNodes(peers, h.ACLPolicy, h.cfg.DNSConfig) - if err != nil { - log.Error(). - Caller(). - Str("func", "generateMapResponse"). - Err(err). - Msg("Failed to convert peers to Tailscale nodes") - - return nil, err - } - - dnsConfig := getMapResponseDNSConfig( - h.cfg.DNSConfig, - h.cfg.BaseDomain, - *machine, - peers, - ) - - now := time.Now() - - resp := tailcfg.MapResponse{ - KeepAlive: false, - Node: node, - - // TODO: Only send if updated - DERPMap: h.DERPMap, - - // TODO: Only send if updated - Peers: nodePeers, - - // TODO(kradalby): Implement: - // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L1351-L1374 - // PeersChanged - // PeersRemoved - // PeersChangedPatch - // PeerSeenChange - // OnlineChange - - // TODO: Only send if updated - DNSConfig: dnsConfig, - - // TODO: Only send if updated - Domain: h.cfg.BaseDomain, - - // Do not instruct clients to collect services, we do not - // support or do anything with them - CollectServices: "false", - - // TODO: Only send if updated - PacketFilter: h.aclRules, - - UserProfiles: profiles, - - // TODO: Only send if updated - SSHPolicy: h.sshPolicy, - - ControlTime: &now, - - Debug: &tailcfg.Debug{ - DisableLogTail: !h.cfg.LogTail.Enabled, - RandomizeClientPort: h.cfg.RandomizeClientPort, - }, - } - - log.Trace(). - Str("func", "generateMapResponse"). - Str("machine", mapRequest.Hostinfo.Hostname). - // Interface("payload", resp). - Msgf("Generated map response: %s", util.TailMapResponseToString(resp)) - - return &resp, nil -} diff --git a/hscontrol/app.go b/hscontrol/app.go index bb68ced..0588037 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -80,7 +80,7 @@ type Headscale struct { dbString string dbType string dbDebug bool - privateKey *key.MachinePrivate + privateKey2019 *key.MachinePrivate noisePrivateKey *key.MachinePrivate DERPMap *tailcfg.DERPMap @@ -166,7 +166,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) { cfg: cfg, dbType: cfg.DBtype, dbString: dbString, - privateKey: privateKey, + privateKey2019: privateKey, noisePrivateKey: noisePrivateKey, aclRules: tailcfg.FilterAllowAll, // default allowall registrationCache: registrationCache, diff --git a/hscontrol/db/machine.go b/hscontrol/db/machine.go index a8d3569..9bfe581 100644 --- a/hscontrol/db/machine.go +++ b/hscontrol/db/machine.go @@ -39,15 +39,7 @@ var ( ) ) -// filterMachinesByACL wrapper function to not have devs pass around locks and maps -// related to the application outside of tests. -func (hsdb *HSDatabase) filterMachinesByACL( - aclRules []tailcfg.FilterRule, - currentMachine *types.Machine, peers types.Machines, -) types.Machines { - return policy.FilterMachinesByACL(currentMachine, peers, aclRules) -} - +// ListPeers returns all peers of machine, regardless of any Policy. func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error) { log.Trace(). Caller(). @@ -72,67 +64,6 @@ func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error return machines, nil } -func (hsdb *HSDatabase) getPeers( - aclRules []tailcfg.FilterRule, - machine *types.Machine, -) (types.Machines, error) { - var peers types.Machines - var err error - - // If ACLs rules are defined, filter visible host list with the ACLs - // else use the classic user scope - if len(aclRules) > 0 { - var machines []types.Machine - machines, err = hsdb.ListMachines() - if err != nil { - log.Error().Err(err).Msg("Error retrieving list of machines") - - return types.Machines{}, err - } - peers = hsdb.filterMachinesByACL(aclRules, machine, machines) - } else { - peers, err = hsdb.ListPeers(machine) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot fetch peers") - - return types.Machines{}, err - } - } - - sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID }) - - log.Trace(). - Caller(). - Str("self", machine.Hostname). - Str("peers", peers.String()). - Msg("Peers returned to caller") - - return peers, nil -} - -func (hsdb *HSDatabase) GetValidPeers( - aclRules []tailcfg.FilterRule, - machine *types.Machine, -) (types.Machines, error) { - validPeers := make(types.Machines, 0) - - peers, err := hsdb.getPeers(aclRules, machine) - if err != nil { - return types.Machines{}, err - } - - for _, peer := range peers { - if !peer.IsExpired() { - validPeers = append(validPeers, peer) - } - } - - return validPeers, nil -} - func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) { machines := []types.Machine{} if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil { diff --git a/hscontrol/db/machine_test.go b/hscontrol/db/machine_test.go index f34f64d..7ca96d5 100644 --- a/hscontrol/db/machine_test.go +++ b/hscontrol/db/machine_test.go @@ -293,8 +293,8 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { aclRules, _, err := policy.GenerateFilterRules(aclPolicy, machines, false) c.Assert(err, check.IsNil) - peersOfTestMachine := db.filterMachinesByACL(aclRules, testMachine, machines) - peersOfAdminMachine := db.filterMachinesByACL(aclRules, adminMachine, machines) + peersOfTestMachine := policy.FilterMachinesByACL(testMachine, machines, aclRules) + peersOfAdminMachine := policy.FilterMachinesByACL(adminMachine, machines, aclRules) c.Log(peersOfTestMachine) c.Assert(len(peersOfTestMachine), check.Equals, 9) diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index e0ffd19..ce18675 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -2,13 +2,11 @@ package db import ( "errors" - "fmt" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "gorm.io/gorm" - "tailscale.com/tailcfg" ) var ( @@ -163,32 +161,3 @@ func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string) return nil } - -func (hsdb *HSDatabase) GetMapResponseUserProfiles( - machine types.Machine, - peers types.Machines, -) []tailcfg.UserProfile { - userMap := make(map[string]types.User) - userMap[machine.User.Name] = machine.User - for _, peer := range peers { - userMap[peer.User.Name] = peer.User // not worth checking if already is there - } - - profiles := []tailcfg.UserProfile{} - for _, user := range userMap { - displayName := user.Name - - if hsdb.baseDomain != "" { - displayName = fmt.Sprintf("%s@%s", user.Name, hsdb.baseDomain) - } - - profiles = append(profiles, - tailcfg.UserProfile{ - ID: tailcfg.UserID(user.ID), - LoginName: user.Name, - DisplayName: displayName, - }) - } - - return profiles -} diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 02c0a2a..bc468b2 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -1,13 +1,10 @@ package db import ( - "net/netip" - "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" "gorm.io/gorm" - "tailscale.com/tailcfg" ) func (s *Suite) TestCreateAndDestroyUser(c *check.C) { @@ -94,151 +91,6 @@ func (s *Suite) TestRenameUser(c *check.C) { c.Assert(err, check.Equals, ErrUserExists) } -func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { - userShared1, err := db.CreateUser("shared1") - c.Assert(err, check.IsNil) - - userShared2, err := db.CreateUser("shared2") - c.Assert(err, check.IsNil) - - userShared3, err := db.CreateUser("shared3") - c.Assert(err, check.IsNil) - - preAuthKeyShared1, err := db.CreatePreAuthKey( - userShared1.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - preAuthKeyShared2, err := db.CreatePreAuthKey( - userShared2.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - preAuthKeyShared3, err := db.CreatePreAuthKey( - userShared3.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - preAuthKey2Shared1, err := db.CreatePreAuthKey( - userShared1.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - _, err = db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") - c.Assert(err, check.NotNil) - - machineInShared1 := &types.Machine{ - ID: 1, - MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", - NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", - DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", - Hostname: "test_get_shared_nodes_1", - UserID: userShared1.ID, - User: *userShared1, - RegisterMethod: util.RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, - AuthKeyID: uint(preAuthKeyShared1.ID), - } - db.db.Save(machineInShared1) - - _, err = db.GetMachine(userShared1.Name, machineInShared1.Hostname) - c.Assert(err, check.IsNil) - - machineInShared2 := &types.Machine{ - ID: 2, - MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - Hostname: "test_get_shared_nodes_2", - UserID: userShared2.ID, - User: *userShared2, - RegisterMethod: util.RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, - AuthKeyID: uint(preAuthKeyShared2.ID), - } - db.db.Save(machineInShared2) - - _, err = db.GetMachine(userShared2.Name, machineInShared2.Hostname) - c.Assert(err, check.IsNil) - - machineInShared3 := &types.Machine{ - ID: 3, - MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - Hostname: "test_get_shared_nodes_3", - UserID: userShared3.ID, - User: *userShared3, - RegisterMethod: util.RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, - AuthKeyID: uint(preAuthKeyShared3.ID), - } - db.db.Save(machineInShared3) - - _, err = db.GetMachine(userShared3.Name, machineInShared3.Hostname) - c.Assert(err, check.IsNil) - - machine2InShared1 := &types.Machine{ - ID: 4, - MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - Hostname: "test_get_shared_nodes_4", - UserID: userShared1.ID, - User: *userShared1, - RegisterMethod: util.RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, - AuthKeyID: uint(preAuthKey2Shared1.ID), - } - db.db.Save(machine2InShared1) - - peersOfMachine1InShared1, err := db.getPeers([]tailcfg.FilterRule{}, machineInShared1) - c.Assert(err, check.IsNil) - - userProfiles := db.GetMapResponseUserProfiles( - *machineInShared1, - peersOfMachine1InShared1, - ) - - c.Assert(len(userProfiles), check.Equals, 3) - - found := false - for _, userProfiles := range userProfiles { - if userProfiles.DisplayName == userShared1.Name { - found = true - - break - } - } - c.Assert(found, check.Equals, true) - - found = false - for _, userProfile := range userProfiles { - if userProfile.DisplayName == userShared2.Name { - found = true - - break - } - } - c.Assert(found, check.Equals, true) -} - func (s *Suite) TestSetMachineUser(c *check.C) { oldUser, err := db.CreateUser("old") c.Assert(err, check.IsNil) diff --git a/hscontrol/derp_server.go b/hscontrol/derp_server.go index 9ca6eee..f178802 100644 --- a/hscontrol/derp_server.go +++ b/hscontrol/derp_server.go @@ -32,7 +32,7 @@ type DERPServer struct { func (h *Headscale) NewDERPServer() (*DERPServer, error) { log.Trace().Caller().Msg("Creating new embedded DERP server") - server := derp.NewServer(key.NodePrivate(*h.privateKey), log.Info().Msgf) + server := derp.NewServer(key.NodePrivate(*h.privateKey2019), log.Info().Msgf) region, err := h.generateRegionLocalDERP() if err != nil { return nil, err @@ -156,7 +156,7 @@ func (h *Headscale) DERPHandler( log.Trace().Caller().Msgf("Hijacked connection from %v", req.RemoteAddr) if !fastStart { - pubKey := h.privateKey.Public() + pubKey := h.privateKey2019.Public() pubKeyStr, _ := pubKey.MarshalText() //nolint fmt.Fprintf(conn, "HTTP/1.1 101 Switching Protocols\r\n"+ "Upgrade: DERP\r\n"+ diff --git a/hscontrol/dns.go b/hscontrol/dns.go index 2c611f1..dcab04d 100644 --- a/hscontrol/dns.go +++ b/hscontrol/dns.go @@ -3,14 +3,9 @@ package hscontrol import ( "fmt" "net/netip" - "net/url" "strings" - mapset "github.com/deckarep/golang-set/v2" - "github.com/juanfont/headscale/hscontrol/types" "go4.org/netipx" - "tailscale.com/tailcfg" - "tailscale.com/types/dnstype" "tailscale.com/util/dnsname" ) @@ -23,10 +18,6 @@ const ( ipv6AddressLength = 128 ) -const ( - nextDNSDoHPrefix = "https://dns.nextdns.io" -) - // generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`. // This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS // server (listening in 100.100.100.100 udp/53) should be used for. @@ -158,63 +149,3 @@ func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { return fqdns } - -// If any nextdns DoH resolvers are present in the list of resolvers it will -// take metadata from the machine metadata and instruct tailscale to add it -// to the requests. This makes it possible to identify from which device the -// requests come in the NextDNS dashboard. -// -// This will produce a resolver like: -// `https://dns.nextdns.io/?device_name=node-name&device_model=linux&device_ip=100.64.0.1` -func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) { - for _, resolver := range resolvers { - if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { - attrs := url.Values{ - "device_name": []string{machine.Hostname}, - "device_model": []string{machine.HostInfo.OS}, - } - - if len(machine.IPAddresses) > 0 { - attrs.Add("device_ip", machine.IPAddresses[0].String()) - } - - resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode()) - } - } -} - -func getMapResponseDNSConfig( - dnsConfigOrig *tailcfg.DNSConfig, - baseDomain string, - machine types.Machine, - peers types.Machines, -) *tailcfg.DNSConfig { - var dnsConfig *tailcfg.DNSConfig = dnsConfigOrig.Clone() - if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled - // Only inject the Search Domain of the current user - shared nodes should use their full FQDN - dnsConfig.Domains = append( - dnsConfig.Domains, - fmt.Sprintf( - "%s.%s", - machine.User.Name, - baseDomain, - ), - ) - - userSet := mapset.NewSet[types.User]() - userSet.Add(machine.User) - for _, p := range peers { - userSet.Add(p.User) - } - for _, user := range userSet.ToSlice() { - dnsRoute := fmt.Sprintf("%v.%v", user.Name, baseDomain) - dnsConfig.Routes[dnsRoute] = nil - } - } else { - dnsConfig = dnsConfigOrig - } - - addNextDNSMetadata(dnsConfig.Resolvers, machine) - - return dnsConfig -} diff --git a/hscontrol/dns_test.go b/hscontrol/dns_test.go index 6bee0ea..aae243c 100644 --- a/hscontrol/dns_test.go +++ b/hscontrol/dns_test.go @@ -1,14 +1,9 @@ package hscontrol import ( - "fmt" "net/netip" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" - "tailscale.com/tailcfg" - "tailscale.com/types/dnstype" ) func (s *Suite) TestMagicDNSRootDomains100(c *check.C) { @@ -112,293 +107,3 @@ func (s *Suite) TestMagicDNSRootDomainsIPv6SingleMultiple(c *check.C) { c.Assert(yieldsRoot("2.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true) c.Assert(yieldsRoot("3.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true) } - -func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) { - userShared1, err := app.db.CreateUser("shared1") - c.Assert(err, check.IsNil) - - userShared2, err := app.db.CreateUser("shared2") - c.Assert(err, check.IsNil) - - userShared3, err := app.db.CreateUser("shared3") - c.Assert(err, check.IsNil) - - preAuthKeyInShared1, err := app.db.CreatePreAuthKey( - userShared1.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - preAuthKeyInShared2, err := app.db.CreatePreAuthKey( - userShared2.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - preAuthKeyInShared3, err := app.db.CreatePreAuthKey( - userShared3.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - PreAuthKey2InShared1, err := app.db.CreatePreAuthKey( - userShared1.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") - c.Assert(err, check.NotNil) - - machineInShared1 := &types.Machine{ - ID: 1, - MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", - NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", - DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", - Hostname: "test_get_shared_nodes_1", - UserID: userShared1.ID, - User: *userShared1, - RegisterMethod: util.RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, - AuthKeyID: uint(preAuthKeyInShared1.ID), - } - err = app.db.MachineSave(machineInShared1) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) - c.Assert(err, check.IsNil) - - machineInShared2 := &types.Machine{ - ID: 2, - MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - Hostname: "test_get_shared_nodes_2", - UserID: userShared2.ID, - User: *userShared2, - RegisterMethod: util.RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, - AuthKeyID: uint(preAuthKeyInShared2.ID), - } - err = app.db.MachineSave(machineInShared2) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) - c.Assert(err, check.IsNil) - - machineInShared3 := &types.Machine{ - ID: 3, - MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - Hostname: "test_get_shared_nodes_3", - UserID: userShared3.ID, - User: *userShared3, - RegisterMethod: util.RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, - AuthKeyID: uint(preAuthKeyInShared3.ID), - } - err = app.db.MachineSave(machineInShared3) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) - c.Assert(err, check.IsNil) - - machine2InShared1 := &types.Machine{ - ID: 4, - MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - Hostname: "test_get_shared_nodes_4", - UserID: userShared1.ID, - User: *userShared1, - RegisterMethod: util.RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, - AuthKeyID: uint(PreAuthKey2InShared1.ID), - } - err = app.db.MachineSave(machine2InShared1) - c.Assert(err, check.IsNil) - - baseDomain := "foobar.headscale.net" - dnsConfigOrig := tailcfg.DNSConfig{ - Routes: make(map[string][]*dnstype.Resolver), - Domains: []string{baseDomain}, - Proxied: true, - } - - peersOfMachineInShared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1) - c.Assert(err, check.IsNil) - - dnsConfig := getMapResponseDNSConfig( - &dnsConfigOrig, - baseDomain, - *machineInShared1, - peersOfMachineInShared1, - ) - c.Assert(dnsConfig, check.NotNil) - - c.Assert(len(dnsConfig.Routes), check.Equals, 3) - - domainRouteShared1 := fmt.Sprintf("%s.%s", userShared1.Name, baseDomain) - _, ok := dnsConfig.Routes[domainRouteShared1] - c.Assert(ok, check.Equals, true) - - domainRouteShared2 := fmt.Sprintf("%s.%s", userShared2.Name, baseDomain) - _, ok = dnsConfig.Routes[domainRouteShared2] - c.Assert(ok, check.Equals, true) - - domainRouteShared3 := fmt.Sprintf("%s.%s", userShared3.Name, baseDomain) - _, ok = dnsConfig.Routes[domainRouteShared3] - c.Assert(ok, check.Equals, true) -} - -func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) { - userShared1, err := app.db.CreateUser("shared1") - c.Assert(err, check.IsNil) - - userShared2, err := app.db.CreateUser("shared2") - c.Assert(err, check.IsNil) - - userShared3, err := app.db.CreateUser("shared3") - c.Assert(err, check.IsNil) - - preAuthKeyInShared1, err := app.db.CreatePreAuthKey( - userShared1.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - preAuthKeyInShared2, err := app.db.CreatePreAuthKey( - userShared2.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - preAuthKeyInShared3, err := app.db.CreatePreAuthKey( - userShared3.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - preAuthKey2InShared1, err := app.db.CreatePreAuthKey( - userShared1.Name, - false, - false, - nil, - nil, - ) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1") - c.Assert(err, check.NotNil) - - machineInShared1 := &types.Machine{ - ID: 1, - MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", - NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", - DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", - Hostname: "test_get_shared_nodes_1", - UserID: userShared1.ID, - User: *userShared1, - RegisterMethod: util.RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, - AuthKeyID: uint(preAuthKeyInShared1.ID), - } - err = app.db.MachineSave(machineInShared1) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname) - c.Assert(err, check.IsNil) - - machineInShared2 := &types.Machine{ - ID: 2, - MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - Hostname: "test_get_shared_nodes_2", - UserID: userShared2.ID, - User: *userShared2, - RegisterMethod: util.RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, - AuthKeyID: uint(preAuthKeyInShared2.ID), - } - err = app.db.MachineSave(machineInShared2) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname) - c.Assert(err, check.IsNil) - - machineInShared3 := &types.Machine{ - ID: 3, - MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - Hostname: "test_get_shared_nodes_3", - UserID: userShared3.ID, - User: *userShared3, - RegisterMethod: util.RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, - AuthKeyID: uint(preAuthKeyInShared3.ID), - } - err = app.db.MachineSave(machineInShared3) - c.Assert(err, check.IsNil) - - _, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname) - c.Assert(err, check.IsNil) - - machine2InShared1 := &types.Machine{ - ID: 4, - MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", - Hostname: "test_get_shared_nodes_4", - UserID: userShared1.ID, - User: *userShared1, - RegisterMethod: util.RegisterMethodAuthKey, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, - AuthKeyID: uint(preAuthKey2InShared1.ID), - } - err = app.db.MachineSave(machine2InShared1) - c.Assert(err, check.IsNil) - - baseDomain := "foobar.headscale.net" - dnsConfigOrig := tailcfg.DNSConfig{ - Routes: make(map[string][]*dnstype.Resolver), - Domains: []string{baseDomain}, - Proxied: false, - } - - peersOfMachine1Shared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1) - c.Assert(err, check.IsNil) - - dnsConfig := getMapResponseDNSConfig( - &dnsConfigOrig, - baseDomain, - *machineInShared1, - peersOfMachine1Shared1, - ) - c.Assert(dnsConfig, check.NotNil) - c.Assert(len(dnsConfig.Routes), check.Equals, 0) - c.Assert(len(dnsConfig.Domains), check.Equals, 1) -} diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go new file mode 100644 index 0000000..5dfa949 --- /dev/null +++ b/hscontrol/mapper/mapper.go @@ -0,0 +1,418 @@ +package mapper + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "net/url" + "strings" + "sync" + "time" + + mapset "github.com/deckarep/golang-set/v2" + "github.com/juanfont/headscale/hscontrol/db" + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/klauspost/compress/zstd" + "github.com/rs/zerolog/log" + "tailscale.com/smallzstd" + "tailscale.com/tailcfg" + "tailscale.com/types/dnstype" + "tailscale.com/types/key" +) + +const ( + nextDNSDoHPrefix = "https://dns.nextdns.io" + reservedResponseHeaderSize = 4 +) + +type Mapper struct { + db *db.HSDatabase + + privateKey2019 *key.MachinePrivate + isNoise bool + + // Configuration + // TODO(kradalby): figure out if this is the format we want this in + derpMap *tailcfg.DERPMap + baseDomain string + dnsCfg *tailcfg.DNSConfig + logtail bool + randomClientPort bool + stripEmailDomain bool +} + +func NewMapper( + db *db.HSDatabase, + privateKey *key.MachinePrivate, + isNoise bool, + derpMap *tailcfg.DERPMap, + baseDomain string, + dnsCfg *tailcfg.DNSConfig, + logtail bool, + randomClientPort bool, + stripEmailDomain bool, +) *Mapper { + return &Mapper{ + db: db, + + privateKey2019: privateKey, + isNoise: isNoise, + + derpMap: derpMap, + baseDomain: baseDomain, + dnsCfg: dnsCfg, + logtail: logtail, + randomClientPort: randomClientPort, + stripEmailDomain: stripEmailDomain, + } +} + +func (m Mapper) fullMapResponse( + mapRequest tailcfg.MapRequest, + machine *types.Machine, + pol *policy.ACLPolicy, +) (*tailcfg.MapResponse, error) { + log.Trace(). + Caller(). + Str("machine", mapRequest.Hostinfo.Hostname). + Msg("Creating Map response") + + // TODO(kradalby): Decouple this from DB? + node, err := m.db.TailNode(*machine, pol, m.dnsCfg) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Cannot convert to node") + + return nil, err + } + + peers, err := m.db.ListPeers(machine) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Cannot fetch peers") + + return nil, err + } + + rules, sshPolicy, err := policy.GenerateFilterRules(pol, peers, m.stripEmailDomain) + if err != nil { + return nil, err + } + + if len(rules) > 0 { + peers = policy.FilterMachinesByACL(machine, peers, rules) + } + + profiles := generateUserProfiles(machine, peers, m.baseDomain) + + // TODO(kradalby): Decouple this from DB? + nodePeers, err := m.db.TailNodes(peers, pol, m.dnsCfg) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to convert peers to Tailscale nodes") + + return nil, err + } + + // TODO(kradalby): Shold this mutation happen before TailNode(s) is called? + dnsConfig := generateDNSConfig( + m.dnsCfg, + m.baseDomain, + *machine, + peers, + ) + + now := time.Now() + + resp := tailcfg.MapResponse{ + KeepAlive: false, + Node: node, + + // TODO: Only send if updated + DERPMap: m.derpMap, + + // TODO: Only send if updated + Peers: nodePeers, + + // TODO(kradalby): Implement: + // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L1351-L1374 + // PeersChanged + // PeersRemoved + // PeersChangedPatch + // PeerSeenChange + // OnlineChange + + // TODO: Only send if updated + DNSConfig: dnsConfig, + + // TODO: Only send if updated + Domain: m.baseDomain, + + // Do not instruct clients to collect services, we do not + // support or do anything with them + CollectServices: "false", + + // TODO: Only send if updated + PacketFilter: rules, + + UserProfiles: profiles, + + // TODO: Only send if updated + SSHPolicy: sshPolicy, + + ControlTime: &now, + + Debug: &tailcfg.Debug{ + DisableLogTail: !m.logtail, + RandomizeClientPort: m.randomClientPort, + }, + } + + log.Trace(). + Caller(). + Str("machine", mapRequest.Hostinfo.Hostname). + // Interface("payload", resp). + Msgf("Generated map response: %s", util.TailMapResponseToString(resp)) + + return &resp, nil +} + +func generateUserProfiles( + machine *types.Machine, + peers types.Machines, + baseDomain string, +) []tailcfg.UserProfile { + userMap := make(map[string]types.User) + userMap[machine.User.Name] = machine.User + for _, peer := range peers { + userMap[peer.User.Name] = peer.User // not worth checking if already is there + } + + profiles := []tailcfg.UserProfile{} + for _, user := range userMap { + displayName := user.Name + + if baseDomain != "" { + displayName = fmt.Sprintf("%s@%s", user.Name, baseDomain) + } + + profiles = append(profiles, + tailcfg.UserProfile{ + ID: tailcfg.UserID(user.ID), + LoginName: user.Name, + DisplayName: displayName, + }) + } + + return profiles +} + +func generateDNSConfig( + base *tailcfg.DNSConfig, + baseDomain string, + machine types.Machine, + peers types.Machines, +) *tailcfg.DNSConfig { + dnsConfig := base.Clone() + + // if MagicDNS is enabled + if base != nil && base.Proxied { + // Only inject the Search Domain of the current user + // shared nodes should use their full FQDN + dnsConfig.Domains = append( + dnsConfig.Domains, + fmt.Sprintf( + "%s.%s", + machine.User.Name, + baseDomain, + ), + ) + + userSet := mapset.NewSet[types.User]() + userSet.Add(machine.User) + for _, p := range peers { + userSet.Add(p.User) + } + for _, user := range userSet.ToSlice() { + dnsRoute := fmt.Sprintf("%v.%v", user.Name, baseDomain) + dnsConfig.Routes[dnsRoute] = nil + } + } else { + dnsConfig = base + } + + addNextDNSMetadata(dnsConfig.Resolvers, machine) + + return dnsConfig +} + +// If any nextdns DoH resolvers are present in the list of resolvers it will +// take metadata from the machine metadata and instruct tailscale to add it +// to the requests. This makes it possible to identify from which device the +// requests come in the NextDNS dashboard. +// +// This will produce a resolver like: +// `https://dns.nextdns.io/?device_name=node-name&device_model=linux&device_ip=100.64.0.1` +func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) { + for _, resolver := range resolvers { + if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { + attrs := url.Values{ + "device_name": []string{machine.Hostname}, + "device_model": []string{machine.HostInfo.OS}, + } + + if len(machine.IPAddresses) > 0 { + attrs.Add("device_ip", machine.IPAddresses[0].String()) + } + + resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode()) + } + } +} + +func (m Mapper) CreateMapResponse( + mapRequest tailcfg.MapRequest, + machine *types.Machine, + pol *policy.ACLPolicy, +) ([]byte, error) { + mapResponse, err := m.fullMapResponse(mapRequest, machine, pol) + if err != nil { + return nil, err + } + + if m.isNoise { + return m.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress) + } + + var machineKey key.MachinePublic + err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Cannot parse client key") + + return nil, err + } + + return m.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress) +} + +func (m Mapper) CreateKeepAliveResponse( + mapRequest tailcfg.MapRequest, + machine *types.Machine, +) ([]byte, error) { + keepAliveResponse := tailcfg.MapResponse{ + KeepAlive: true, + } + + if m.isNoise { + return m.marshalMapResponse( + keepAliveResponse, + key.MachinePublic{}, + mapRequest.Compress, + ) + } + + var machineKey key.MachinePublic + err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Cannot parse client key") + + return nil, err + } + + return m.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress) +} + +func MarshalResponse( + resp interface{}, + privateKey2019 *key.MachinePrivate, + machineKey key.MachinePublic, +) ([]byte, error) { + jsonBody, err := json.Marshal(resp) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Cannot marshal response") + + return nil, err + } + + if privateKey2019 != nil { + return privateKey2019.SealTo(machineKey, jsonBody), nil + } + + return jsonBody, nil +} + +func (m Mapper) marshalMapResponse( + resp interface{}, + machineKey key.MachinePublic, + compression string, +) ([]byte, error) { + jsonBody, err := json.Marshal(resp) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Cannot marshal map response") + } + + var respBody []byte + if compression == util.ZstdCompression { + respBody = zstdEncode(jsonBody) + if !m.isNoise { // if legacy protocol + respBody = m.privateKey2019.SealTo(machineKey, respBody) + } + } else { + if !m.isNoise { // if legacy protocol + respBody = m.privateKey2019.SealTo(machineKey, jsonBody) + } else { + respBody = jsonBody + } + } + + data := make([]byte, reservedResponseHeaderSize) + binary.LittleEndian.PutUint32(data, uint32(len(respBody))) + data = append(data, respBody...) + + return data, nil +} + +func zstdEncode(in []byte) []byte { + encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder) + if !ok { + panic("invalid type in sync pool") + } + out := encoder.EncodeAll(in, nil) + _ = encoder.Close() + zstdEncoderPool.Put(encoder) + + return out +} + +var zstdEncoderPool = &sync.Pool{ + New: func() any { + encoder, err := smallzstd.NewEncoder( + nil, + zstd.WithEncoderLevel(zstd.SpeedFastest)) + if err != nil { + panic(err) + } + + return encoder + }, +} diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go new file mode 100644 index 0000000..a5d65c9 --- /dev/null +++ b/hscontrol/mapper/mapper_test.go @@ -0,0 +1,131 @@ +package mapper + +import ( + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/juanfont/headscale/hscontrol/types" + "gopkg.in/check.v1" + "tailscale.com/tailcfg" + "tailscale.com/types/dnstype" +) + +func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { + mach := func(hostname, username string, userid uint) types.Machine { + return types.Machine{ + Hostname: hostname, + UserID: userid, + User: types.User{ + Name: username, + }, + } + } + + machineInShared1 := mach("test_get_shared_nodes_1", "user1", 1) + machineInShared2 := mach("test_get_shared_nodes_2", "user2", 2) + machineInShared3 := mach("test_get_shared_nodes_3", "user3", 3) + machine2InShared1 := mach("test_get_shared_nodes_4", "user1", 1) + + userProfiles := generateUserProfiles( + &machineInShared1, + types.Machines{ + machineInShared2, machineInShared3, machine2InShared1, + }, + "", + ) + + c.Assert(len(userProfiles), check.Equals, 3) + + users := []string{ + "user1", "user2", "user3", + } + + for _, user := range users { + found := false + for _, userProfile := range userProfiles { + if userProfile.DisplayName == user { + found = true + + break + } + } + c.Assert(found, check.Equals, true) + } +} + +func TestDNSConfigMapResponse(t *testing.T) { + tests := []struct { + magicDNS bool + want *tailcfg.DNSConfig + }{ + { + magicDNS: true, + want: &tailcfg.DNSConfig{ + Routes: map[string][]*dnstype.Resolver{ + "shared1.foobar.headscale.net": {}, + "shared2.foobar.headscale.net": {}, + "shared3.foobar.headscale.net": {}, + }, + Domains: []string{ + "foobar.headscale.net", + "shared1.foobar.headscale.net", + }, + Proxied: true, + }, + }, + { + magicDNS: false, + want: &tailcfg.DNSConfig{ + Domains: []string{"foobar.headscale.net"}, + Proxied: false, + }, + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("with-magicdns-%v", tt.magicDNS), func(t *testing.T) { + mach := func(hostname, username string, userid uint) types.Machine { + return types.Machine{ + Hostname: hostname, + UserID: userid, + User: types.User{ + Name: username, + }, + } + } + + baseDomain := "foobar.headscale.net" + + dnsConfigOrig := tailcfg.DNSConfig{ + Routes: make(map[string][]*dnstype.Resolver), + Domains: []string{baseDomain}, + Proxied: tt.magicDNS, + } + + machineInShared1 := mach("test_get_shared_nodes_1", "shared1", 1) + machineInShared2 := mach("test_get_shared_nodes_2", "shared2", 2) + machineInShared3 := mach("test_get_shared_nodes_3", "shared3", 3) + machine2InShared1 := mach("test_get_shared_nodes_4", "shared1", 1) + + peersOfMachineInShared1 := types.Machines{ + machineInShared1, + machineInShared2, + machineInShared3, + machine2InShared1, + } + + got := generateDNSConfig( + &dnsConfigOrig, + baseDomain, + machineInShared1, + peersOfMachineInShared1, + ) + + if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("expandAlias() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/hscontrol/mapper/suite_test.go b/hscontrol/mapper/suite_test.go new file mode 100644 index 0000000..c9b1a58 --- /dev/null +++ b/hscontrol/mapper/suite_test.go @@ -0,0 +1,15 @@ +package mapper + +import ( + "testing" + + "gopkg.in/check.v1" +) + +func Test(t *testing.T) { + check.TestingT(t) +} + +var _ = check.Suite(&Suite{}) + +type Suite struct{} diff --git a/hscontrol/protocol_common.go b/hscontrol/protocol_common.go index ae034fb..c0ba924 100644 --- a/hscontrol/protocol_common.go +++ b/hscontrol/protocol_common.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -61,7 +62,7 @@ func (h *Headscale) KeyHandler( // TS2021 (Tailscale v2 protocol) requires to have a different key if clientCapabilityVersion >= NoiseCapabilityVersion { resp := tailcfg.OverTLSPublicKeyResponse{ - LegacyPublicKey: h.privateKey.Public(), + LegacyPublicKey: h.privateKey2019.Public(), PublicKey: h.noisePrivateKey.Public(), } writer.Header().Set("Content-Type", "application/json") @@ -84,7 +85,7 @@ func (h *Headscale) KeyHandler( // Old clients don't send a 'v' parameter, so we send the legacy public key writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusOK) - _, err := writer.Write([]byte(util.MachinePublicKeyStripPrefix(h.privateKey.Public()))) + _, err := writer.Write([]byte(util.MachinePublicKeyStripPrefix(h.privateKey2019.Public()))) if err != nil { log.Error(). Caller(). @@ -323,7 +324,7 @@ func (h *Headscale) handleAuthKeyCommon( Msg("Failed authentication via AuthKey") resp.MachineAuthorized = false - respBody, err := h.marshalResponse(resp, machineKey, isNoise) + respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey) if err != nil { log.Error(). Caller(). @@ -483,7 +484,7 @@ func (h *Headscale) handleAuthKeyCommon( // Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName* resp.Login = *pak.User.TailscaleLogin() - respBody, err := h.marshalResponse(resp, machineKey, isNoise) + respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey) if err != nil { log.Error(). Caller(). @@ -548,7 +549,7 @@ func (h *Headscale) handleNewMachineCommon( registerRequest.NodeKey) } - respBody, err := h.marshalResponse(resp, machineKey, isNoise) + respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey) if err != nil { log.Error(). Caller(). @@ -609,7 +610,7 @@ func (h *Headscale) handleMachineLogOutCommon( resp.MachineAuthorized = false resp.NodeKeyExpired = true resp.User = *machine.User.TailscaleUser() - respBody, err := h.marshalResponse(resp, machineKey, isNoise) + respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey) if err != nil { log.Error(). Caller(). @@ -673,7 +674,7 @@ func (h *Headscale) handleMachineValidRegistrationCommon( resp.User = *machine.User.TailscaleUser() resp.Login = *machine.User.TailscaleLogin() - respBody, err := h.marshalResponse(resp, machineKey, isNoise) + respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey) if err != nil { log.Error(). Caller(). @@ -735,7 +736,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon( resp.AuthURL = "" resp.User = *machine.User.TailscaleUser() - respBody, err := h.marshalResponse(resp, machineKey, isNoise) + respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey) if err != nil { log.Error(). Caller(). @@ -802,7 +803,7 @@ func (h *Headscale) handleMachineExpiredOrLoggedOutCommon( registerRequest.NodeKey) } - respBody, err := h.marshalResponse(resp, machineKey, isNoise) + respBody, err := mapper.MarshalResponse(resp, h.privateKey2019, machineKey) if err != nil { log.Error(). Caller(). diff --git a/hscontrol/protocol_common_poll.go b/hscontrol/protocol_common_poll.go index 3d43238..27c5d82 100644 --- a/hscontrol/protocol_common_poll.go +++ b/hscontrol/protocol_common_poll.go @@ -6,6 +6,7 @@ import ( "net/http" "time" + "github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -29,6 +30,19 @@ func (h *Headscale) handlePollCommon( mapRequest tailcfg.MapRequest, isNoise bool, ) { + // TODO(kradalby): This is a stepping stone, mapper should be initiated once + // per client or something similar + mapp := mapper.NewMapper(h.db, + h.privateKey2019, + isNoise, + h.DERPMap, + h.cfg.BaseDomain, + h.cfg.DNSConfig, + h.cfg.LogTail.Enabled, + h.cfg.RandomizeClientPort, + h.cfg.OIDC.StripEmaildomain, + ) + machine.Hostname = mapRequest.Hostinfo.Hostname machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo) machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) @@ -87,7 +101,7 @@ func (h *Headscale) handlePollCommon( return } - mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise) + mapResp, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy) if err != nil { log.Error(). Str("handler", "PollNetMap"). @@ -245,6 +259,19 @@ func (h *Headscale) pollNetMapStream( updateChan chan struct{}, isNoise bool, ) { + // TODO(kradalby): This is a stepping stone, mapper should be initiated once + // per client or something similar + mapp := mapper.NewMapper(h.db, + h.privateKey2019, + isNoise, + h.DERPMap, + h.cfg.BaseDomain, + h.cfg.DNSConfig, + h.cfg.LogTail.Enabled, + h.cfg.RandomizeClientPort, + h.cfg.OIDC.StripEmaildomain, + ) + h.pollNetMapStreamWG.Add(1) defer h.pollNetMapStreamWG.Done() @@ -463,7 +490,7 @@ func (h *Headscale) pollNetMapStream( Time("last_successful_update", lastUpdate). Time("last_state_change", h.getLastStateChange(machine.User)). Msgf("There has been updates since the last successful update to %s", machine.Hostname) - data, err := h.getMapResponseData(mapRequest, machine, isNoise) + data, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). @@ -623,6 +650,19 @@ func (h *Headscale) scheduledPollWorker( machine *types.Machine, isNoise bool, ) { + // TODO(kradalby): This is a stepping stone, mapper should be initiated once + // per client or something similar + mapp := mapper.NewMapper(h.db, + h.privateKey2019, + isNoise, + h.DERPMap, + h.cfg.BaseDomain, + h.cfg.DNSConfig, + h.cfg.LogTail.Enabled, + h.cfg.RandomizeClientPort, + h.cfg.OIDC.StripEmaildomain, + ) + keepAliveTicker := time.NewTicker(keepAliveInterval) updateCheckerTicker := time.NewTicker(h.cfg.NodeUpdateCheckInterval) @@ -643,7 +683,7 @@ func (h *Headscale) scheduledPollWorker( return case <-keepAliveTicker.C: - data, err := h.getMapKeepAliveResponseData(mapRequest, machine, isNoise) + data, err := mapp.CreateKeepAliveResponse(mapRequest, machine) if err != nil { log.Error(). Str("func", "keepAlive"). diff --git a/hscontrol/protocol_common_utils.go b/hscontrol/protocol_common_utils.go index 8990eeb..7b0b0ac 100644 --- a/hscontrol/protocol_common_utils.go +++ b/hscontrol/protocol_common_utils.go @@ -1,157 +1 @@ package hscontrol - -import ( - "encoding/binary" - "encoding/json" - "sync" - - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/klauspost/compress/zstd" - "github.com/rs/zerolog/log" - "tailscale.com/smallzstd" - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -func (h *Headscale) getMapResponseData( - mapRequest tailcfg.MapRequest, - machine *types.Machine, - isNoise bool, -) ([]byte, error) { - mapResponse, err := h.generateMapResponse(mapRequest, machine) - if err != nil { - return nil, err - } - - if isNoise { - return h.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress, isNoise) - } - - var machineKey key.MachinePublic - err = machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot parse client key") - - return nil, err - } - - return h.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress, isNoise) -} - -func (h *Headscale) getMapKeepAliveResponseData( - mapRequest tailcfg.MapRequest, - machine *types.Machine, - isNoise bool, -) ([]byte, error) { - keepAliveResponse := tailcfg.MapResponse{ - KeepAlive: true, - } - - if isNoise { - return h.marshalMapResponse( - keepAliveResponse, - key.MachinePublic{}, - mapRequest.Compress, - isNoise, - ) - } - - var machineKey key.MachinePublic - err := machineKey.UnmarshalText([]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey))) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot parse client key") - - return nil, err - } - - return h.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress, isNoise) -} - -func (h *Headscale) marshalResponse( - resp interface{}, - machineKey key.MachinePublic, - isNoise bool, -) ([]byte, error) { - jsonBody, err := json.Marshal(resp) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot marshal response") - - return nil, err - } - - if isNoise { - return jsonBody, nil - } - - return h.privateKey.SealTo(machineKey, jsonBody), nil -} - -func (h *Headscale) marshalMapResponse( - resp interface{}, - machineKey key.MachinePublic, - compression string, - isNoise bool, -) ([]byte, error) { - jsonBody, err := json.Marshal(resp) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot marshal map response") - } - - var respBody []byte - if compression == util.ZstdCompression { - respBody = zstdEncode(jsonBody) - if !isNoise { // if legacy protocol - respBody = h.privateKey.SealTo(machineKey, respBody) - } - } else { - if !isNoise { // if legacy protocol - respBody = h.privateKey.SealTo(machineKey, jsonBody) - } else { - respBody = jsonBody - } - } - - data := make([]byte, reservedResponseHeaderSize) - binary.LittleEndian.PutUint32(data, uint32(len(respBody))) - data = append(data, respBody...) - - return data, nil -} - -func zstdEncode(in []byte) []byte { - encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder) - if !ok { - panic("invalid type in sync pool") - } - out := encoder.EncodeAll(in, nil) - _ = encoder.Close() - zstdEncoderPool.Put(encoder) - - return out -} - -var zstdEncoderPool = &sync.Pool{ - New: func() any { - encoder, err := smallzstd.NewEncoder( - nil, - zstd.WithEncoderLevel(zstd.SpeedFastest)) - if err != nil { - panic(err) - } - - return encoder - }, -} diff --git a/hscontrol/protocol_legacy.go b/hscontrol/protocol_legacy.go index f443eba..06eb314 100644 --- a/hscontrol/protocol_legacy.go +++ b/hscontrol/protocol_legacy.go @@ -45,7 +45,7 @@ func (h *Headscale) RegistrationHandler( return } registerRequest := tailcfg.RegisterRequest{} - err = util.DecodeAndUnmarshalNaCl(body, ®isterRequest, &machineKey, h.privateKey) + err = util.DecodeAndUnmarshalNaCl(body, ®isterRequest, &machineKey, h.privateKey2019) if err != nil { log.Error(). Caller(). diff --git a/hscontrol/protocol_legacy_poll.go b/hscontrol/protocol_legacy_poll.go index 3755faf..27e38a9 100644 --- a/hscontrol/protocol_legacy_poll.go +++ b/hscontrol/protocol_legacy_poll.go @@ -57,7 +57,7 @@ func (h *Headscale) PollNetMapHandler( return } mapRequest := tailcfg.MapRequest{} - err = util.DecodeAndUnmarshalNaCl(body, &mapRequest, &machineKey, h.privateKey) + err = util.DecodeAndUnmarshalNaCl(body, &mapRequest, &machineKey, h.privateKey2019) if err != nil { log.Error(). Str("handler", "PollNetMap").