diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ca4d4cf..963663a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -18,12 +18,11 @@ jobs: - name: Setup Go uses: actions/setup-go@v2 with: - go-version: "1.16.3" + go-version: "1.17.3" - name: Install dependencies run: | go version - go install golang.org/x/lint/golint@latest sudo apt update sudo apt install -y make diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 6b561d2..b3c6400 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,20 +1,37 @@ +--- name: CI on: [push, pull_request] jobs: - # The "build" workflow - lint: - # The type of runner that the job will run on + golangci-lint: runs-on: ubuntu-latest - - # Steps represent a sequence of tasks that will be executed as part of the job steps: - # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - uses: actions/checkout@v2 - # Install and run golangci-lint as a separate step, it's much faster this - # way because this action has caching. It'll get run again in `make lint` - # below, but it's still much faster in the end than installing - # golangci-lint manually in the `Run lint` step. - - uses: golangci/golangci-lint-action@v2 + - name: golangci-lint + uses: golangci/golangci-lint-action@v2 + with: + version: latest + + prettier-lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Prettify code + uses: creyD/prettier_action@v4.0 + with: + prettier_options: >- + --check **/*.{ts,js,md,yaml,yml,sass,css,scss,html} + only_changed: false + dry: true + + proto-lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: bufbuild/buf-setup-action@v0.7.0 + - uses: bufbuild/buf-lint-action@v1 + with: + input: "proto" diff --git a/.golangci.yaml b/.golangci.yaml index a97c2bb..2defdc8 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -1,7 +1,53 @@ --- run: - timeout: 5m + timeout: 10m issues: skip-dirs: - gen +linters: + enable-all: true + disable: + - exhaustivestruct + - revive + - lll + - interfacer + - scopelint + - maligned + - golint + - gofmt + - gochecknoglobals + - gochecknoinits + - gocognit + - funlen + - exhaustivestruct + - tagliatelle + - godox + - ireturn + + # In progress + - gocritic + + # We should strive to enable these: + - wrapcheck + - dupl + - makezero + + # We might want to enable this, but it might be a lot of work + - cyclop + - nestif + - wsl # might be incompatible with gofumpt + - testpackage + - paralleltest + +linters-settings: + varnamelen: + ignore-type-assert-ok: true + ignore-map-index-ok: true + ignore-names: + - err + - db + - id + - ip + - ok + - c diff --git a/Makefile b/Makefile index 92beaef..060d3b9 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,14 @@ # Calculate version version = $(shell ./scripts/version-at-commit.sh) +rwildcard=$(foreach d,$(wildcard $1*),$(call rwildcard,$d/,$2) $(filter $(subst *,%,$2),$d)) + +# GO_SOURCES = $(wildcard *.go) +# PROTO_SOURCES = $(wildcard **/*.proto) +GO_SOURCES = $(call rwildcard,,*.go) +PROTO_SOURCES = $(call rwildcard,,*.proto) + + build: go build -ldflags "-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=$(version)" cmd/headscale/headscale.go @@ -19,7 +27,12 @@ coverprofile_html: go tool cover -html=coverage.out lint: - golangci-lint run --fix + golangci-lint run --fix --timeout 10m + +fmt: + prettier --write '**/**.{ts,js,md,yaml,yml,sass,css,scss,html}' + golines --max-len=88 --base-formatter=gofumpt -w $(GO_SOURCES) + clang-format -style="{BasedOnStyle: Google, IndentWidth: 4, AlignConsecutiveDeclarations: true, AlignConsecutiveAssignments: true, ColumnLimit: 0}" -i $(PROTO_SOURCES) proto-lint: cd proto/ && buf lint diff --git a/README.md b/README.md index e9f7990..41f388e 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,6 @@ Suggestions/PRs welcomed! Please have a look at the documentation under [`docs/`](docs/). - ## Disclaimer 1. We have nothing to do with Tailscale, or Tailscale Inc. @@ -64,13 +63,30 @@ Please have a look at the documentation under [`docs/`](docs/). To contribute to Headscale you would need the lastest version of [Go](https://golang.org) and [Buf](https://buf.build)(Protobuf generator). +### Code style + +To ensure we have some consistency with a growing number of contributes, this project has adopted linting and style/formatting rules: + +The **Go** code is linted with [`golangci-lint`](https://golangci-lint.run) and +formatted with [`golines`](https://github.com/segmentio/golines) (width 88) and +[`gofumpt`](https://github.com/mvdan/gofumpt). +Please configure your editor to run the tools while developing and make sure to +run `make lint` and `make fmt` before committing any code. + +The **Proto** code is linted with [`buf`](https://docs.buf.build/lint/overview) and +formatted with [`clang-format`](https://clang.llvm.org/docs/ClangFormat.html). + +The **rest** (markdown, yaml, etc) is formatted with [`prettier`](https://prettier.io). + +Check out the `.golangci.yaml` and `Makefile` to see the specific configuration. + ### Install development tools - Go - Buf - Protobuf tools: -```shell +```shell make install-protobuf-plugins ``` @@ -81,6 +97,7 @@ Some parts of the project requires the generation of Go code from Protobuf (if c ```shell make generate ``` + **Note**: Please check in changes from `gen/` in a separate commit to make it easier to review. To run the tests: @@ -261,5 +278,3 @@ make build - - diff --git a/acls.go b/acls.go index cb02e04..1550c34 100644 --- a/acls.go +++ b/acls.go @@ -9,23 +9,30 @@ import ( "strings" "github.com/rs/zerolog/log" - "github.com/tailscale/hujson" "inet.af/netaddr" "tailscale.com/tailcfg" ) const ( - errorEmptyPolicy = Error("empty policy") - errorInvalidAction = Error("invalid action") - errorInvalidUserSection = Error("invalid user section") - errorInvalidGroup = Error("invalid group") - errorInvalidTag = Error("invalid tag") - errorInvalidNamespace = Error("invalid namespace") - errorInvalidPortFormat = Error("invalid port format") + errEmptyPolicy = Error("empty policy") + errInvalidAction = Error("invalid action") + errInvalidUserSection = Error("invalid user section") + errInvalidGroup = Error("invalid group") + errInvalidTag = Error("invalid tag") + errInvalidNamespace = Error("invalid namespace") + errInvalidPortFormat = Error("invalid port format") ) -// LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules +const ( + Base10 = 10 + BitSize16 = 16 + portRangeBegin = 0 + portRangeEnd = 65535 + expectedTokenItems = 2 +) + +// LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules. func (h *Headscale) LoadACLPolicy(path string) error { policyFile, err := os.Open(path) if err != nil { @@ -34,23 +41,23 @@ func (h *Headscale) LoadACLPolicy(path string) error { defer policyFile.Close() var policy ACLPolicy - b, err := io.ReadAll(policyFile) + policyBytes, err := io.ReadAll(policyFile) if err != nil { return err } - ast, err := hujson.Parse(b) + ast, err := hujson.Parse(policyBytes) if err != nil { return err } ast.Standardize() - b = ast.Pack() - err = json.Unmarshal(b, &policy) + policyBytes = ast.Pack() + err = json.Unmarshal(policyBytes, &policy) if err != nil { return err } if policy.IsZero() { - return errorEmptyPolicy + return errEmptyPolicy } h.aclPolicy = &policy @@ -59,37 +66,40 @@ func (h *Headscale) LoadACLPolicy(path string) error { return err } h.aclRules = rules + return nil } func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) { rules := []tailcfg.FilterRule{} - for i, a := range h.aclPolicy.ACLs { - if a.Action != "accept" { - return nil, errorInvalidAction + for index, acl := range h.aclPolicy.ACLs { + if acl.Action != "accept" { + return nil, errInvalidAction } - r := tailcfg.FilterRule{} + filterRule := tailcfg.FilterRule{} srcIPs := []string{} - for j, u := range a.Users { - srcs, err := h.generateACLPolicySrcIP(u) + for innerIndex, user := range acl.Users { + srcs, err := h.generateACLPolicySrcIP(user) if err != nil { log.Error(). - Msgf("Error parsing ACL %d, User %d", i, j) + Msgf("Error parsing ACL %d, User %d", index, innerIndex) + return nil, err } srcIPs = append(srcIPs, srcs...) } - r.SrcIPs = srcIPs + filterRule.SrcIPs = srcIPs destPorts := []tailcfg.NetPortRange{} - for j, d := range a.Ports { - dests, err := h.generateACLPolicyDestPorts(d) + for innerIndex, ports := range acl.Ports { + dests, err := h.generateACLPolicyDestPorts(ports) if err != nil { log.Error(). - Msgf("Error parsing ACL %d, Port %d", i, j) + Msgf("Error parsing ACL %d, Port %d", index, innerIndex) + return nil, err } destPorts = append(destPorts, dests...) @@ -108,10 +118,12 @@ func (h *Headscale) generateACLPolicySrcIP(u string) ([]string, error) { return h.expandAlias(u) } -func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange, error) { +func (h *Headscale) generateACLPolicyDestPorts( + d string, +) ([]tailcfg.NetPortRange, error) { tokens := strings.Split(d, ":") - if len(tokens) < 2 || len(tokens) > 3 { - return nil, errorInvalidPortFormat + if len(tokens) < expectedTokenItems || len(tokens) > 3 { + return nil, errInvalidPortFormat } var alias string @@ -121,7 +133,7 @@ func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange // tag:montreal-webserver:80,443 // tag:api-server:443 // example-host-1:* - if len(tokens) == 2 { + if len(tokens) == expectedTokenItems { alias = tokens[0] } else { alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1]) @@ -146,34 +158,36 @@ func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange dests = append(dests, pr) } } + return dests, nil } -func (h *Headscale) expandAlias(s string) ([]string, error) { - if s == "*" { +func (h *Headscale) expandAlias(alias string) ([]string, error) { + if alias == "*" { return []string{"*"}, nil } - if strings.HasPrefix(s, "group:") { - if _, ok := h.aclPolicy.Groups[s]; !ok { - return nil, errorInvalidGroup + if strings.HasPrefix(alias, "group:") { + if _, ok := h.aclPolicy.Groups[alias]; !ok { + return nil, errInvalidGroup } ips := []string{} - for _, n := range h.aclPolicy.Groups[s] { + for _, n := range h.aclPolicy.Groups[alias] { nodes, err := h.ListMachinesInNamespace(n) if err != nil { - return nil, errorInvalidNamespace + return nil, errInvalidNamespace } for _, node := range nodes { ips = append(ips, node.IPAddress) } } + return ips, nil } - if strings.HasPrefix(s, "tag:") { - if _, ok := h.aclPolicy.TagOwners[s]; !ok { - return nil, errorInvalidTag + if strings.HasPrefix(alias, "tag:") { + if _, ok := h.aclPolicy.TagOwners[alias]; !ok { + return nil, errInvalidTag } // This will have HORRIBLE performance. @@ -183,10 +197,10 @@ func (h *Headscale) expandAlias(s string) ([]string, error) { return nil, err } ips := []string{} - for _, m := range machines { + for _, machine := range machines { hostinfo := tailcfg.Hostinfo{} - if len(m.HostInfo) != 0 { - hi, err := m.HostInfo.MarshalJSON() + if len(machine.HostInfo) != 0 { + hi, err := machine.HostInfo.MarshalJSON() if err != nil { return nil, err } @@ -197,17 +211,19 @@ func (h *Headscale) expandAlias(s string) ([]string, error) { // FIXME: Check TagOwners allows this for _, t := range hostinfo.RequestTags { - if s[4:] == t { - ips = append(ips, m.IPAddress) + if alias[4:] == t { + ips = append(ips, machine.IPAddress) + break } } } } + return ips, nil } - n, err := h.GetNamespace(s) + n, err := h.GetNamespace(alias) if err == nil { nodes, err := h.ListMachinesInNamespace(n.Name) if err != nil { @@ -217,49 +233,54 @@ func (h *Headscale) expandAlias(s string) ([]string, error) { for _, n := range nodes { ips = append(ips, n.IPAddress) } + return ips, nil } - if h, ok := h.aclPolicy.Hosts[s]; ok { + if h, ok := h.aclPolicy.Hosts[alias]; ok { return []string{h.String()}, nil } - ip, err := netaddr.ParseIP(s) + ip, err := netaddr.ParseIP(alias) if err == nil { return []string{ip.String()}, nil } - cidr, err := netaddr.ParseIPPrefix(s) + cidr, err := netaddr.ParseIPPrefix(alias) if err == nil { return []string{cidr.String()}, nil } - return nil, errorInvalidUserSection + return nil, errInvalidUserSection } -func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) { - if s == "*" { - return &[]tailcfg.PortRange{{First: 0, Last: 65535}}, nil +func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) { + if portsStr == "*" { + return &[]tailcfg.PortRange{ + {First: portRangeBegin, Last: portRangeEnd}, + }, nil } ports := []tailcfg.PortRange{} - for _, p := range strings.Split(s, ",") { - rang := strings.Split(p, "-") - if len(rang) == 1 { - pi, err := strconv.ParseUint(rang[0], 10, 16) + for _, portStr := range strings.Split(portsStr, ",") { + rang := strings.Split(portStr, "-") + switch len(rang) { + case 1: + port, err := strconv.ParseUint(rang[0], Base10, BitSize16) if err != nil { return nil, err } ports = append(ports, tailcfg.PortRange{ - First: uint16(pi), - Last: uint16(pi), + First: uint16(port), + Last: uint16(port), }) - } else if len(rang) == 2 { - start, err := strconv.ParseUint(rang[0], 10, 16) + + case expectedTokenItems: + start, err := strconv.ParseUint(rang[0], Base10, BitSize16) if err != nil { return nil, err } - last, err := strconv.ParseUint(rang[1], 10, 16) + last, err := strconv.ParseUint(rang[1], Base10, BitSize16) if err != nil { return nil, err } @@ -267,9 +288,11 @@ func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) { First: uint16(start), Last: uint16(last), }) - } else { - return nil, errorInvalidPortFormat + + default: + return nil, errInvalidPortFormat } } + return &ports, nil } diff --git a/acls_test.go b/acls_test.go index da7f3ec..3e051f5 100644 --- a/acls_test.go +++ b/acls_test.go @@ -5,54 +5,58 @@ import ( ) func (s *Suite) TestWrongPath(c *check.C) { - err := h.LoadACLPolicy("asdfg") + err := app.LoadACLPolicy("asdfg") c.Assert(err, check.NotNil) } func (s *Suite) TestBrokenHuJson(c *check.C) { - err := h.LoadACLPolicy("./tests/acls/broken.hujson") + err := app.LoadACLPolicy("./tests/acls/broken.hujson") c.Assert(err, check.NotNil) } func (s *Suite) TestInvalidPolicyHuson(c *check.C) { - err := h.LoadACLPolicy("./tests/acls/invalid.hujson") + err := app.LoadACLPolicy("./tests/acls/invalid.hujson") c.Assert(err, check.NotNil) - c.Assert(err, check.Equals, errorEmptyPolicy) + c.Assert(err, check.Equals, errEmptyPolicy) } func (s *Suite) TestParseHosts(c *check.C) { - var hs Hosts - err := hs.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100","example-host-2": "100.100.101.100/24"}`)) - c.Assert(hs, check.NotNil) + var hosts Hosts + err := hosts.UnmarshalJSON( + []byte( + `{"example-host-1": "100.100.100.100","example-host-2": "100.100.101.100/24"}`, + ), + ) + c.Assert(hosts, check.NotNil) c.Assert(err, check.IsNil) } func (s *Suite) TestParseInvalidCIDR(c *check.C) { - var hs Hosts - err := hs.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100/42"}`)) - c.Assert(hs, check.IsNil) + var hosts Hosts + err := hosts.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100/42"}`)) + c.Assert(hosts, check.IsNil) c.Assert(err, check.NotNil) } func (s *Suite) TestRuleInvalidGeneration(c *check.C) { - err := h.LoadACLPolicy("./tests/acls/acl_policy_invalid.hujson") + err := app.LoadACLPolicy("./tests/acls/acl_policy_invalid.hujson") c.Assert(err, check.NotNil) } func (s *Suite) TestBasicRule(c *check.C) { - err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson") + err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson") c.Assert(err, check.IsNil) - rules, err := h.generateACLRules() + rules, err := app.generateACLRules() c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) } func (s *Suite) TestPortRange(c *check.C) { - err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson") + err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson") c.Assert(err, check.IsNil) - rules, err := h.generateACLRules() + rules, err := app.generateACLRules() c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -63,10 +67,10 @@ func (s *Suite) TestPortRange(c *check.C) { } func (s *Suite) TestPortWildcard(c *check.C) { - err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson") + err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson") c.Assert(err, check.IsNil) - rules, err := h.generateACLRules() + rules, err := app.generateACLRules() c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -79,33 +83,35 @@ func (s *Suite) TestPortWildcard(c *check.C) { } func (s *Suite) TestPortNamespace(c *check.C) { - n, err := h.CreateNamespace("testnamespace") + namespace, err := app.CreateNamespace("testnamespace") c.Assert(err, check.IsNil) - pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) c.Assert(err, check.IsNil) - _, err = h.GetMachine("testnamespace", "testmachine") + _, err = app.GetMachine("testnamespace", "testmachine") c.Assert(err, check.NotNil) - ip, _ := h.getAvailableIP() - m := Machine{ + ip, _ := app.getAvailableIP() + machine := Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Name: "testmachine", - NamespaceID: n.ID, + NamespaceID: namespace.ID, Registered: true, RegisterMethod: "authKey", IPAddress: ip.String(), AuthKeyID: uint(pak.ID), } - h.db.Save(&m) + app.db.Save(&machine) - err = h.LoadACLPolicy("./tests/acls/acl_policy_basic_namespace_as_user.hujson") + err = app.LoadACLPolicy( + "./tests/acls/acl_policy_basic_namespace_as_user.hujson", + ) c.Assert(err, check.IsNil) - rules, err := h.generateACLRules() + rules, err := app.generateACLRules() c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) @@ -119,33 +125,33 @@ func (s *Suite) TestPortNamespace(c *check.C) { } func (s *Suite) TestPortGroup(c *check.C) { - n, err := h.CreateNamespace("testnamespace") + namespace, err := app.CreateNamespace("testnamespace") c.Assert(err, check.IsNil) - pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) c.Assert(err, check.IsNil) - _, err = h.GetMachine("testnamespace", "testmachine") + _, err = app.GetMachine("testnamespace", "testmachine") c.Assert(err, check.NotNil) - ip, _ := h.getAvailableIP() - m := Machine{ + ip, _ := app.getAvailableIP() + machine := Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Name: "testmachine", - NamespaceID: n.ID, + NamespaceID: namespace.ID, Registered: true, RegisterMethod: "authKey", IPAddress: ip.String(), AuthKeyID: uint(pak.ID), } - h.db.Save(&m) + app.db.Save(&machine) - err = h.LoadACLPolicy("./tests/acls/acl_policy_basic_groups.hujson") + err = app.LoadACLPolicy("./tests/acls/acl_policy_basic_groups.hujson") c.Assert(err, check.IsNil) - rules, err := h.generateACLRules() + rules, err := app.generateACLRules() c.Assert(err, check.IsNil) c.Assert(rules, check.NotNil) diff --git a/acls_types.go b/acls_types.go index 67b74e7..08e650f 100644 --- a/acls_types.go +++ b/acls_types.go @@ -8,7 +8,7 @@ import ( "inet.af/netaddr" ) -// ACLPolicy represents a Tailscale ACL Policy +// ACLPolicy represents a Tailscale ACL Policy. type ACLPolicy struct { Groups Groups `json:"Groups"` Hosts Hosts `json:"Hosts"` @@ -17,61 +17,63 @@ type ACLPolicy struct { Tests []ACLTest `json:"Tests"` } -// ACL is a basic rule for the ACL Policy +// ACL is a basic rule for the ACL Policy. type ACL struct { Action string `json:"Action"` Users []string `json:"Users"` Ports []string `json:"Ports"` } -// Groups references a series of alias in the ACL rules +// Groups references a series of alias in the ACL rules. type Groups map[string][]string -// Hosts are alias for IP addresses or subnets +// Hosts are alias for IP addresses or subnets. type Hosts map[string]netaddr.IPPrefix -// TagOwners specify what users (namespaces?) are allow to use certain tags +// TagOwners specify what users (namespaces?) are allow to use certain tags. type TagOwners map[string][]string -// ACLTest is not implemented, but should be use to check if a certain rule is allowed +// ACLTest is not implemented, but should be use to check if a certain rule is allowed. type ACLTest struct { User string `json:"User"` Allow []string `json:"Allow"` Deny []string `json:"Deny,omitempty"` } -// UnmarshalJSON allows to parse the Hosts directly into netaddr objects -func (h *Hosts) UnmarshalJSON(data []byte) error { - hosts := Hosts{} - hs := make(map[string]string) +// UnmarshalJSON allows to parse the Hosts directly into netaddr objects. +func (hosts *Hosts) UnmarshalJSON(data []byte) error { + newHosts := Hosts{} + hostIPPrefixMap := make(map[string]string) ast, err := hujson.Parse(data) if err != nil { return err } ast.Standardize() data = ast.Pack() - err = json.Unmarshal(data, &hs) + err = json.Unmarshal(data, &hostIPPrefixMap) if err != nil { return err } - for k, v := range hs { - if !strings.Contains(v, "/") { - v = v + "/32" + for host, prefixStr := range hostIPPrefixMap { + if !strings.Contains(prefixStr, "/") { + prefixStr += "/32" } - prefix, err := netaddr.ParseIPPrefix(v) + prefix, err := netaddr.ParseIPPrefix(prefixStr) if err != nil { return err } - hosts[k] = prefix + newHosts[host] = prefix } - *h = hosts + *hosts = newHosts + return nil } -// IsZero is perhaps a bit naive here -func (p ACLPolicy) IsZero() bool { - if len(p.Groups) == 0 && len(p.Hosts) == 0 && len(p.ACLs) == 0 { +// IsZero is perhaps a bit naive here. +func (policy ACLPolicy) IsZero() bool { + if len(policy.Groups) == 0 && len(policy.Hosts) == 0 && len(policy.ACLs) == 0 { return true } + return false } diff --git a/api.go b/api.go index 490ce25..85c28e3 100644 --- a/api.go +++ b/api.go @@ -10,31 +10,37 @@ import ( "strings" "time" - "github.com/rs/zerolog/log" - "github.com/gin-gonic/gin" "github.com/klauspost/compress/zstd" + "github.com/rs/zerolog/log" "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/wgkey" ) +const reservedResponseHeaderSize = 4 + // KeyHandler provides the Headscale pub key -// Listens in /key -func (h *Headscale) KeyHandler(c *gin.Context) { - c.Data(200, "text/plain; charset=utf-8", []byte(h.publicKey.HexString())) +// Listens in /key. +func (h *Headscale) KeyHandler(ctx *gin.Context) { + ctx.Data( + http.StatusOK, + "text/plain; charset=utf-8", + []byte(h.publicKey.HexString()), + ) } // RegisterWebAPI shows a simple message in the browser to point to the CLI -// Listens in /register -func (h *Headscale) RegisterWebAPI(c *gin.Context) { - mKeyStr := c.Query("key") - if mKeyStr == "" { - c.String(http.StatusBadRequest, "Wrong params") +// Listens in /register. +func (h *Headscale) RegisterWebAPI(ctx *gin.Context) { + machineKeyStr := ctx.Query("key") + if machineKeyStr == "" { + ctx.String(http.StatusBadRequest, "Wrong params") + return } - c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` + ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`

headscale

@@ -51,43 +57,45 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) { - `, mKeyStr))) + `, machineKeyStr))) } // RegistrationHandler handles the actual registration process of a machine -// Endpoint /machine/:id -func (h *Headscale) RegistrationHandler(c *gin.Context) { - body, _ := io.ReadAll(c.Request.Body) - mKeyStr := c.Param("id") - mKey, err := wgkey.ParseHex(mKeyStr) +// Endpoint /machine/:id. +func (h *Headscale) RegistrationHandler(ctx *gin.Context) { + body, _ := io.ReadAll(ctx.Request.Body) + machineKeyStr := ctx.Param("id") + machineKey, err := wgkey.ParseHex(machineKeyStr) if err != nil { log.Error(). Str("handler", "Registration"). Err(err). Msg("Cannot parse machine key") machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() - c.String(http.StatusInternalServerError, "Sad!") + ctx.String(http.StatusInternalServerError, "Sad!") + return } req := tailcfg.RegisterRequest{} - err = decode(body, &req, &mKey, h.privateKey) + err = decode(body, &req, &machineKey, h.privateKey) if err != nil { log.Error(). Str("handler", "Registration"). Err(err). Msg("Cannot decode message") machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() - c.String(http.StatusInternalServerError, "Very sad!") + ctx.String(http.StatusInternalServerError, "Very sad!") + return } now := time.Now().UTC() - m, err := h.GetMachineByMachineKey(mKey.HexString()) + machine, err := h.GetMachineByMachineKey(machineKey.HexString()) if errors.Is(err, gorm.ErrRecordNotFound) { log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") newMachine := Machine{ Expiry: &time.Time{}, - MachineKey: mKey.HexString(), + MachineKey: machineKey.HexString(), Name: req.Hostinfo.Hostname, } if err := h.db.Create(&newMachine).Error; err != nil { @@ -95,88 +103,96 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { Str("handler", "Registration"). Err(err). Msg("Could not create row") - machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).Inc() + machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name). + Inc() + return } - m = &newMachine + machine = &newMachine } - if !m.Registered && req.Auth.AuthKey != "" { - h.handleAuthKey(c, h.db, mKey, req, *m) + if !machine.Registered && req.Auth.AuthKey != "" { + h.handleAuthKey(ctx, h.db, machineKey, req, *machine) + return } resp := tailcfg.RegisterResponse{} // We have the updated key! - if m.NodeKey == wgkey.Key(req.NodeKey).HexString() { - + if machine.NodeKey == wgkey.Key(req.NodeKey).HexString() { // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) { log.Info(). Str("handler", "Registration"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Client requested logout") - m.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired - h.db.Save(&m) + machine.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired + h.db.Save(&machine) resp.AuthURL = "" resp.MachineAuthorized = false - resp.User = *m.Namespace.toUser() - respBody, err := encode(resp, &mKey, h.privateKey) + resp.User = *machine.Namespace.toUser() + respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). Str("handler", "Registration"). Err(err). Msg("Cannot encode message") - c.String(http.StatusInternalServerError, "") + ctx.String(http.StatusInternalServerError, "") + return } - c.Data(200, "application/json; charset=utf-8", respBody) + ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) + return } - if m.Registered && m.Expiry.UTC().After(now) { + if machine.Registered && machine.Expiry.UTC().After(now) { // The machine registration is valid, respond with redirect to /map log.Debug(). Str("handler", "Registration"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Client is registered and we have the current NodeKey. All clear to /map") resp.AuthURL = "" resp.MachineAuthorized = true - resp.User = *m.Namespace.toUser() - resp.Login = *m.Namespace.toLogin() + resp.User = *machine.Namespace.toUser() + resp.Login = *machine.Namespace.toLogin() - respBody, err := encode(resp, &mKey, h.privateKey) + respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). Str("handler", "Registration"). Err(err). Msg("Cannot encode message") - machineRegistrations.WithLabelValues("update", "web", "error", m.Namespace.Name).Inc() - c.String(http.StatusInternalServerError, "") + machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name). + Inc() + ctx.String(http.StatusInternalServerError, "") + return } - machineRegistrations.WithLabelValues("update", "web", "success", m.Namespace.Name).Inc() - c.Data(200, "application/json; charset=utf-8", respBody) + machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name). + Inc() + ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) + return } // The client has registered before, but has expired log.Debug(). Str("handler", "Registration"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Machine registration has expired. Sending a authurl to register") if h.cfg.OIDC.Issuer != "" { resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", - strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) + strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString()) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", - strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) + strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString()) } // When a client connects, it may request a specific expiry time in its @@ -185,102 +201,120 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) { // into two steps (which cant pass arbitrary data between them easily) and needs to be // retrieved again after the user has authenticated. After the authentication flow // completes, RequestedExpiry is copied into Expiry. - m.RequestedExpiry = &req.Expiry + machine.RequestedExpiry = &req.Expiry - h.db.Save(&m) + h.db.Save(&machine) - respBody, err := encode(resp, &mKey, h.privateKey) + respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). Str("handler", "Registration"). Err(err). Msg("Cannot encode message") - machineRegistrations.WithLabelValues("new", "web", "error", m.Namespace.Name).Inc() - c.String(http.StatusInternalServerError, "") + machineRegistrations.WithLabelValues("new", "web", "error", machine.Namespace.Name). + Inc() + ctx.String(http.StatusInternalServerError, "") + return } - machineRegistrations.WithLabelValues("new", "web", "success", m.Namespace.Name).Inc() - c.Data(200, "application/json; charset=utf-8", respBody) + machineRegistrations.WithLabelValues("new", "web", "success", machine.Namespace.Name). + Inc() + ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) + return } // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration - if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) { + if machine.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && + machine.Expiry.UTC().After(now) { log.Debug(). Str("handler", "Registration"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("We have the OldNodeKey in the database. This is a key refresh") - m.NodeKey = wgkey.Key(req.NodeKey).HexString() - h.db.Save(&m) + machine.NodeKey = wgkey.Key(req.NodeKey).HexString() + h.db.Save(&machine) resp.AuthURL = "" - resp.User = *m.Namespace.toUser() - respBody, err := encode(resp, &mKey, h.privateKey) + resp.User = *machine.Namespace.toUser() + respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). Str("handler", "Registration"). Err(err). Msg("Cannot encode message") - c.String(http.StatusInternalServerError, "Extremely sad!") + ctx.String(http.StatusInternalServerError, "Extremely sad!") + return } - c.Data(200, "application/json; charset=utf-8", respBody) + ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) + return } // The machine registration is new, redirect the client to the registration URL log.Debug(). Str("handler", "Registration"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("The node is sending us a new NodeKey, sending auth url") if h.cfg.OIDC.Issuer != "" { - resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) + resp.AuthURL = fmt.Sprintf( + "%s/oidc/register/%s", + strings.TrimSuffix(h.cfg.ServerURL, "/"), + machineKey.HexString(), + ) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", - strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) + strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString()) } // save the requested expiry time for retrieval later in the authentication flow - m.RequestedExpiry = &req.Expiry - m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey - h.db.Save(&m) + machine.RequestedExpiry = &req.Expiry + machine.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey + h.db.Save(&machine) - respBody, err := encode(resp, &mKey, h.privateKey) + respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). Str("handler", "Registration"). Err(err). Msg("Cannot encode message") - c.String(http.StatusInternalServerError, "") + ctx.String(http.StatusInternalServerError, "") + return } - c.Data(200, "application/json; charset=utf-8", respBody) + ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) } -func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Machine) ([]byte, error) { +func (h *Headscale) getMapResponse( + machineKey wgkey.Key, + req tailcfg.MapRequest, + machine *Machine, +) ([]byte, error) { log.Trace(). Str("func", "getMapResponse"). Str("machine", req.Hostinfo.Hostname). Msg("Creating Map response") - node, err := m.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true) + node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true) if err != nil { log.Error(). Str("func", "getMapResponse"). Err(err). Msg("Cannot convert to node") + return nil, err } - peers, err := h.getPeers(m) + peers, err := h.getPeers(machine) if err != nil { log.Error(). Str("func", "getMapResponse"). Err(err). Msg("Cannot fetch peers") + return nil, err } - profiles := getMapResponseUserProfiles(*m, peers) + profiles := getMapResponseUserProfiles(*machine, peers) nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true) if err != nil { @@ -288,17 +322,16 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma Str("func", "getMapResponse"). Err(err). Msg("Failed to convert peers to Tailscale nodes") + return nil, err } - dnsConfig, err := getMapResponseDNSConfig(h.cfg.DNSConfig, h.cfg.BaseDomain, *m, peers) - if err != nil { - log.Error(). - Str("func", "getMapResponse"). - Err(err). - Msg("Failed generate the DNSConfig") - return nil, err - } + dnsConfig := getMapResponseDNSConfig( + h.cfg.DNSConfig, + h.cfg.BaseDomain, + *machine, + peers, + ) resp := tailcfg.MapResponse{ KeepAlive: false, @@ -323,66 +356,71 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma encoder, _ := zstd.NewWriter(nil) srcCompressed := encoder.EncodeAll(src, nil) - respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey) + respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey) if err != nil { return nil, err } } else { - respBody, err = encode(resp, &mKey, h.privateKey) + respBody, err = encode(resp, &machineKey, h.privateKey) if err != nil { return nil, err } } // declare the incoming size on the first 4 bytes - data := make([]byte, 4) + data := make([]byte, reservedResponseHeaderSize) binary.LittleEndian.PutUint32(data, uint32(len(respBody))) data = append(data, respBody...) + return data, nil } -func (h *Headscale) getMapKeepAliveResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Machine) ([]byte, error) { - resp := tailcfg.MapResponse{ +func (h *Headscale) getMapKeepAliveResponse( + machineKey wgkey.Key, + mapRequest tailcfg.MapRequest, +) ([]byte, error) { + mapResponse := tailcfg.MapResponse{ KeepAlive: true, } var respBody []byte var err error - if req.Compress == "zstd" { - src, _ := json.Marshal(resp) + if mapRequest.Compress == "zstd" { + src, _ := json.Marshal(mapResponse) encoder, _ := zstd.NewWriter(nil) srcCompressed := encoder.EncodeAll(src, nil) - respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey) + respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey) if err != nil { return nil, err } } else { - respBody, err = encode(resp, &mKey, h.privateKey) + respBody, err = encode(mapResponse, &machineKey, h.privateKey) if err != nil { return nil, err } } - data := make([]byte, 4) + data := make([]byte, reservedResponseHeaderSize) binary.LittleEndian.PutUint32(data, uint32(len(respBody))) data = append(data, respBody...) + return data, nil } func (h *Headscale) handleAuthKey( - c *gin.Context, + ctx *gin.Context, db *gorm.DB, idKey wgkey.Key, - req tailcfg.RegisterRequest, - m Machine, + reqisterRequest tailcfg.RegisterRequest, + machine Machine, ) { log.Debug(). Str("func", "handleAuthKey"). - Str("machine", req.Hostinfo.Hostname). - Msgf("Processing auth key for %s", req.Hostinfo.Hostname) + Str("machine", reqisterRequest.Hostinfo.Hostname). + Msgf("Processing auth key for %s", reqisterRequest.Hostinfo.Hostname) resp := tailcfg.RegisterResponse{} - pak, err := h.checkKeyValidity(req.Auth.AuthKey) + pak, err := h.checkKeyValidity(reqisterRequest.Auth.AuthKey) if err != nil { log.Error(). Str("func", "handleAuthKey"). - Str("machine", m.Name). + Str("machine", machine.Name). Err(err). Msg("Failed authentication via AuthKey") resp.MachineAuthorized = false @@ -390,48 +428,56 @@ func (h *Headscale) handleAuthKey( if err != nil { log.Error(). Str("func", "handleAuthKey"). - Str("machine", m.Name). + Str("machine", machine.Name). Err(err). Msg("Cannot encode message") - c.String(http.StatusInternalServerError, "") - machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc() + ctx.String(http.StatusInternalServerError, "") + machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). + Inc() + return } - c.Data(401, "application/json; charset=utf-8", respBody) + ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody) log.Error(). Str("func", "handleAuthKey"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Failed authentication via AuthKey") - machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc() + machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). + Inc() + return } log.Debug(). Str("func", "handleAuthKey"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Authentication key was valid, proceeding to acquire an IP address") ip, err := h.getAvailableIP() if err != nil { log.Error(). Str("func", "handleAuthKey"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Failed to find an available IP") - machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc() + machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). + Inc() + return } log.Info(). Str("func", "handleAuthKey"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("ip", ip.String()). - Msgf("Assigning %s to %s", ip, m.Name) + Msgf("Assigning %s to %s", ip, machine.Name) - m.AuthKeyID = uint(pak.ID) - m.IPAddress = ip.String() - m.NamespaceID = pak.NamespaceID - m.NodeKey = wgkey.Key(req.NodeKey).HexString() // we update it just in case - m.Registered = true - m.RegisterMethod = "authKey" - db.Save(&m) + machine.AuthKeyID = uint(pak.ID) + machine.IPAddress = ip.String() + machine.NamespaceID = pak.NamespaceID + machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey). + HexString() + // we update it just in case + machine.Registered = true + machine.RegisterMethod = "authKey" + db.Save(&machine) pak.Used = true db.Save(&pak) @@ -442,18 +488,21 @@ func (h *Headscale) handleAuthKey( if err != nil { log.Error(). Str("func", "handleAuthKey"). - Str("machine", m.Name). + Str("machine", machine.Name). Err(err). Msg("Cannot encode message") - machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc() - c.String(http.StatusInternalServerError, "Extremely sad!") + machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). + Inc() + ctx.String(http.StatusInternalServerError, "Extremely sad!") + return } - machineRegistrations.WithLabelValues("new", "authkey", "success", m.Namespace.Name).Inc() - c.Data(200, "application/json; charset=utf-8", respBody) + machineRegistrations.WithLabelValues("new", "authkey", "success", machine.Namespace.Name). + Inc() + ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) log.Info(). Str("func", "handleAuthKey"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("ip", ip.String()). Msg("Successfully authenticated via AuthKey") } diff --git a/app.go b/app.go index c226062..08b67fe 100644 --- a/app.go +++ b/app.go @@ -18,20 +18,19 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" - "github.com/patrickmn/go-cache" - "golang.org/x/oauth2" - "github.com/gin-gonic/gin" - "github.com/grpc-ecosystem/go-grpc-middleware" + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/philip-bui/grpc-zerolog" + "github.com/patrickmn/go-cache" + zerolog "github.com/philip-bui/grpc-zerolog" zl "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/soheilhy/cmux" ginprometheus "github.com/zsais/go-gin-prometheus" "golang.org/x/crypto/acme" "golang.org/x/crypto/acme/autocert" + "golang.org/x/oauth2" "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -48,7 +47,16 @@ import ( ) const ( - AUTH_PREFIX = "Bearer " + AuthPrefix = "Bearer " + Postgres = "postgresql" + Sqlite = "sqlite3" + updateInterval = 5000 + HTTPReadTimeout = 30 * time.Second + + errUnsupportedDatabase = Error("unsupported DB") + errUnsupportedLetsEncryptChallengeType = Error( + "unknown value for Lets Encrypt challenge type", + ) ) // Config contains the initial Headscale configuration. @@ -151,16 +159,22 @@ func NewHeadscale(cfg Config) (*Headscale, error) { var dbString string switch cfg.DBtype { - case "postgres": - dbString = fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=disable", cfg.DBhost, - cfg.DBport, cfg.DBname, cfg.DBuser, cfg.DBpass) - case "sqlite3": + case Postgres: + dbString = fmt.Sprintf( + "host=%s port=%d dbname=%s user=%s password=%s sslmode=disable", + cfg.DBhost, + cfg.DBport, + cfg.DBname, + cfg.DBuser, + cfg.DBpass, + ) + case Sqlite: dbString = cfg.DBpath default: - return nil, errors.New("unsupported DB") + return nil, errUnsupportedDatabase } - h := Headscale{ + app := Headscale{ cfg: cfg, dbType: cfg.DBtype, dbString: dbString, @@ -169,33 +183,32 @@ func NewHeadscale(cfg Config) (*Headscale, error) { aclRules: tailcfg.FilterAllowAll, // default allowall } - err = h.initDB() + err = app.initDB() if err != nil { return nil, err } if cfg.OIDC.Issuer != "" { - err = h.initOIDC() + err = app.initOIDC() if err != nil { return nil, err } } - if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS - magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain) - if err != nil { - return nil, err - } + if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS + magicDNSDomains := generateMagicDNSRootDomains( + app.cfg.IPPrefix, + ) // we might have routes already from Split DNS - if h.cfg.DNSConfig.Routes == nil { - h.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver) + if app.cfg.DNSConfig.Routes == nil { + app.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver) } for _, d := range magicDNSDomains { - h.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil + app.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil } } - return &h, nil + return &app, nil } // Redirect to our TLS url. @@ -221,30 +234,37 @@ func (h *Headscale) expireEphemeralNodesWorker() { return } - for _, ns := range namespaces { - machines, err := h.ListMachinesInNamespace(ns.Name) + for _, namespace := range namespaces { + machines, err := h.ListMachinesInNamespace(namespace.Name) if err != nil { - log.Error().Err(err).Str("namespace", ns.Name).Msg("Error listing machines in namespace") + log.Error(). + Err(err). + Str("namespace", namespace.Name). + Msg("Error listing machines in namespace") return } - for _, m := range machines { - if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral && - time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) { - log.Info().Str("machine", m.Name).Msg("Ephemeral client removed from database") + for _, machine := range machines { + if machine.AuthKey != nil && machine.LastSeen != nil && + machine.AuthKey.Ephemeral && + time.Now(). + After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) { + log.Info(). + Str("machine", machine.Name). + Msg("Ephemeral client removed from database") - err = h.db.Unscoped().Delete(m).Error + err = h.db.Unscoped().Delete(machine).Error if err != nil { log.Error(). Err(err). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("🤮 Cannot delete ephemeral machine from the database") } } } - h.setLastStateChangeToNow(ns.Name) + h.setLastStateChangeToNow(namespace.Name) } } @@ -266,36 +286,56 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - // Check if the request is coming from the on-server client. // This is not secure, but it is to maintain maintainability // with the "legacy" database-based client // It is also neede for grpc-gateway to be able to connect to // the server - p, _ := peer.FromContext(ctx) + client, _ := peer.FromContext(ctx) - log.Trace().Caller().Str("client_address", p.Addr.String()).Msg("Client is trying to authenticate") + log.Trace(). + Caller(). + Str("client_address", client.Addr.String()). + Msg("Client is trying to authenticate") - md, ok := metadata.FromIncomingContext(ctx) + meta, ok := metadata.FromIncomingContext(ctx) if !ok { - log.Error().Caller().Str("client_address", p.Addr.String()).Msg("Retrieving metadata is failed") - return ctx, status.Errorf(codes.InvalidArgument, "Retrieving metadata is failed") + log.Error(). + Caller(). + Str("client_address", client.Addr.String()). + Msg("Retrieving metadata is failed") + + return ctx, status.Errorf( + codes.InvalidArgument, + "Retrieving metadata is failed", + ) } - authHeader, ok := md["authorization"] + authHeader, ok := meta["authorization"] if !ok { - log.Error().Caller().Str("client_address", p.Addr.String()).Msg("Authorization token is not supplied") - return ctx, status.Errorf(codes.Unauthenticated, "Authorization token is not supplied") + log.Error(). + Caller(). + Str("client_address", client.Addr.String()). + Msg("Authorization token is not supplied") + + return ctx, status.Errorf( + codes.Unauthenticated, + "Authorization token is not supplied", + ) } token := authHeader[0] - if !strings.HasPrefix(token, AUTH_PREFIX) { + if !strings.HasPrefix(token, AuthPrefix) { log.Error(). Caller(). - Str("client_address", p.Addr.String()). + Str("client_address", client.Addr.String()). Msg(`missing "Bearer " prefix in "Authorization" header`) - return ctx, status.Error(codes.Unauthenticated, `missing "Bearer " prefix in "Authorization" header`) + + return ctx, status.Error( + codes.Unauthenticated, + `missing "Bearer " prefix in "Authorization" header`, + ) } // TODO(kradalby): Implement API key backend: @@ -307,35 +347,38 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, // Currently all other than localhost traffic is unauthorized, this is intentional to allow // us to make use of gRPC for our CLI, but not having to implement any of the remote capabilities // and API key auth - return ctx, status.Error(codes.Unauthenticated, "Authentication is not implemented yet") + return ctx, status.Error( + codes.Unauthenticated, + "Authentication is not implemented yet", + ) - //if strings.TrimPrefix(token, AUTH_PREFIX) != a.Token { - // log.Error().Caller().Str("client_address", p.Addr.String()).Msg("invalid token") - // return ctx, status.Error(codes.Unauthenticated, "invalid token") - //} + // if strings.TrimPrefix(token, AUTH_PREFIX) != a.Token { + // log.Error().Caller().Str("client_address", p.Addr.String()).Msg("invalid token") + // return ctx, status.Error(codes.Unauthenticated, "invalid token") + // } // return handler(ctx, req) } -func (h *Headscale) httpAuthenticationMiddleware(c *gin.Context) { +func (h *Headscale) httpAuthenticationMiddleware(ctx *gin.Context) { log.Trace(). Caller(). - Str("client_address", c.ClientIP()). + Str("client_address", ctx.ClientIP()). Msg("HTTP authentication invoked") - authHeader := c.GetHeader("authorization") + authHeader := ctx.GetHeader("authorization") - if !strings.HasPrefix(authHeader, AUTH_PREFIX) { + if !strings.HasPrefix(authHeader, AuthPrefix) { log.Error(). Caller(). - Str("client_address", c.ClientIP()). + Str("client_address", ctx.ClientIP()). Msg(`missing "Bearer " prefix in "Authorization" header`) - c.AbortWithStatus(http.StatusUnauthorized) + ctx.AbortWithStatus(http.StatusUnauthorized) return } - c.AbortWithStatus(http.StatusUnauthorized) + ctx.AbortWithStatus(http.StatusUnauthorized) // TODO(kradalby): Implement API key backend // Currently all traffic is unauthorized, this is intentional to allow @@ -359,6 +402,7 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error { if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) { return nil } + return os.Remove(h.cfg.UnixSocket) } @@ -401,14 +445,17 @@ func (h *Headscale) Serve() error { // Create the cmux object that will multiplex 2 protocols on the same port. // The two following listeners will be served on the same port below gracefully. - m := cmux.New(networkListener) + networkMutex := cmux.New(networkListener) // Match gRPC requests here - grpcListener := m.MatchWithWriters( + grpcListener := networkMutex.MatchWithWriters( cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"), - cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc+proto"), + cmux.HTTP2MatchHeaderFieldSendSettings( + "content-type", + "application/grpc+proto", + ), ) // Otherwise match regular http requests. - httpListener := m.Match(cmux.Any()) + httpListener := networkMutex.Match(cmux.Any()) grpcGatewayMux := runtime.NewServeMux() @@ -431,30 +478,33 @@ func (h *Headscale) Serve() error { return err } - r := gin.Default() + router := gin.Default() - p := ginprometheus.NewPrometheus("gin") - p.Use(r) + prometheus := ginprometheus.NewPrometheus("gin") + prometheus.Use(router) - r.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) }) - r.GET("/key", h.KeyHandler) - r.GET("/register", h.RegisterWebAPI) - r.POST("/machine/:id/map", h.PollNetMapHandler) - r.POST("/machine/:id", h.RegistrationHandler) - r.GET("/oidc/register/:mkey", h.RegisterOIDC) - r.GET("/oidc/callback", h.OIDCCallback) - r.GET("/apple", h.AppleMobileConfig) - r.GET("/apple/:platform", h.ApplePlatformConfig) - r.GET("/swagger", SwaggerUI) - r.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1) + router.GET( + "/health", + func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) }, + ) + router.GET("/key", h.KeyHandler) + router.GET("/register", h.RegisterWebAPI) + router.POST("/machine/:id/map", h.PollNetMapHandler) + router.POST("/machine/:id", h.RegistrationHandler) + router.GET("/oidc/register/:mkey", h.RegisterOIDC) + router.GET("/oidc/callback", h.OIDCCallback) + router.GET("/apple", h.AppleMobileConfig) + router.GET("/apple/:platform", h.ApplePlatformConfig) + router.GET("/swagger", SwaggerUI) + router.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1) - api := r.Group("/api") + api := router.Group("/api") api.Use(h.httpAuthenticationMiddleware) { api.Any("/v1/*any", gin.WrapF(grpcGatewayMux.ServeHTTP)) } - r.NoRoute(stdoutHandler) + router.NoRoute(stdoutHandler) // Fetch an initial DERP Map before we start serving h.DERPMap = GetDERPMap(h.cfg.DERP) @@ -466,14 +516,13 @@ func (h *Headscale) Serve() error { } // I HATE THIS - updateMillisecondsWait := int64(5000) - go h.watchForKVUpdates(updateMillisecondsWait) - go h.expireEphemeralNodes(updateMillisecondsWait) + go h.watchForKVUpdates(updateInterval) + go h.expireEphemeralNodes(updateInterval) httpServer := &http.Server{ Addr: h.cfg.Addr, - Handler: r, - ReadTimeout: 30 * time.Second, + Handler: router, + ReadTimeout: HTTPReadTimeout, // Go does not handle timeouts in HTTP very well, and there is // no good way to handle streaming timeouts, therefore we need to // keep this at unlimited and be careful to clean up connections @@ -519,36 +568,40 @@ func (h *Headscale) Serve() error { reflection.Register(grpcServer) reflection.Register(grpcSocket) - g := new(errgroup.Group) + errorGroup := new(errgroup.Group) - g.Go(func() error { return grpcSocket.Serve(socketListener) }) + errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) }) // TODO(kradalby): Verify if we need the same TLS setup for gRPC as HTTP - g.Go(func() error { return grpcServer.Serve(grpcListener) }) + errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) }) if tlsConfig != nil { - g.Go(func() error { + errorGroup.Go(func() error { tlsl := tls.NewListener(httpListener, tlsConfig) + return httpServer.Serve(tlsl) }) } else { - g.Go(func() error { return httpServer.Serve(httpListener) }) + errorGroup.Go(func() error { return httpServer.Serve(httpListener) }) } - g.Go(func() error { return m.Serve() }) + errorGroup.Go(func() error { return networkMutex.Serve() }) - log.Info().Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr) + log.Info(). + Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr) - return g.Wait() + return errorGroup.Wait() } func (h *Headscale) getTLSSettings() (*tls.Config, error) { + var err error if h.cfg.TLSLetsEncryptHostname != "" { if !strings.HasPrefix(h.cfg.ServerURL, "https://") { - log.Warn().Msg("Listening with TLS but ServerURL does not start with https://") + log.Warn(). + Msg("Listening with TLS but ServerURL does not start with https://") } - m := autocert.Manager{ + certManager := autocert.Manager{ Prompt: autocert.AcceptTOS, HostPolicy: autocert.HostWhitelist(h.cfg.TLSLetsEncryptHostname), Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir), @@ -558,40 +611,44 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { Email: h.cfg.ACMEEmail, } - if h.cfg.TLSLetsEncryptChallengeType == "TLS-ALPN-01" { + switch h.cfg.TLSLetsEncryptChallengeType { + case "TLS-ALPN-01": // Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737) // The RFC requires that the validation is done on port 443; in other words, headscale // must be reachable on port 443. - return m.TLSConfig(), nil - } else if h.cfg.TLSLetsEncryptChallengeType == "HTTP-01" { + return certManager.TLSConfig(), nil + + case "HTTP-01": // Configuration via autocert with HTTP-01. This requires listening on // port 80 for the certificate validation in addition to the headscale // service, which can be configured to run on any other port. go func() { log.Fatal(). - Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, m.HTTPHandler(http.HandlerFunc(h.redirect)))). + Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))). Msg("failed to set up a HTTP server") }() - return m.TLSConfig(), nil - } else { - return nil, errors.New("unknown value for TLSLetsEncryptChallengeType") + return certManager.TLSConfig(), nil + + default: + return nil, errUnsupportedLetsEncryptChallengeType } } else if h.cfg.TLSCertPath == "" { if !strings.HasPrefix(h.cfg.ServerURL, "http://") { log.Warn().Msg("Listening without TLS but ServerURL does not start with http://") } - return nil, nil + return nil, err } else { if !strings.HasPrefix(h.cfg.ServerURL, "https://") { log.Warn().Msg("Listening with TLS but ServerURL does not start with https://") } - var err error - tlsConfig := &tls.Config{} - tlsConfig.ClientAuth = tls.RequireAnyClientCert - tlsConfig.NextProtos = []string{"http/1.1"} - tlsConfig.Certificates = make([]tls.Certificate, 1) + tlsConfig := &tls.Config{ + ClientAuth: tls.RequireAnyClientCert, + NextProtos: []string{"http/1.1"}, + Certificates: make([]tls.Certificate, 1), + MinVersion: tls.VersionTLS12, + } tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(h.cfg.TLSCertPath, h.cfg.TLSKeyPath) return tlsConfig, err @@ -628,13 +685,13 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time { } } -func stdoutHandler(c *gin.Context) { - b, _ := io.ReadAll(c.Request.Body) +func stdoutHandler(ctx *gin.Context) { + body, _ := io.ReadAll(ctx.Request.Body) log.Trace(). - Interface("header", c.Request.Header). - Interface("proto", c.Request.Proto). - Interface("url", c.Request.URL). - Bytes("body", b). + Interface("header", ctx.Request.Header). + Interface("proto", ctx.Request.Proto). + Interface("url", ctx.Request.URL). + Bytes("body", body). Msg("Request did not match") } diff --git a/app_test.go b/app_test.go index 5e53f1c..947062b 100644 --- a/app_test.go +++ b/app_test.go @@ -17,8 +17,10 @@ var _ = check.Suite(&Suite{}) type Suite struct{} -var tmpDir string -var h Headscale +var ( + tmpDir string + app Headscale +) func (s *Suite) SetUpTest(c *check.C) { s.ResetDB(c) @@ -41,18 +43,18 @@ func (s *Suite) ResetDB(c *check.C) { IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"), } - h = Headscale{ + app = Headscale{ cfg: cfg, dbType: "sqlite3", dbString: tmpDir + "/headscale_test.db", } - err = h.initDB() + err = app.initDB() if err != nil { c.Fatal(err) } - db, err := h.openDB() + db, err := app.openDB() if err != nil { c.Fatal(err) } - h.db = db + app.db = db } diff --git a/apple_mobileconfig.go b/apple_mobileconfig.go index f3956e3..2e454df 100644 --- a/apple_mobileconfig.go +++ b/apple_mobileconfig.go @@ -5,16 +5,15 @@ import ( "net/http" "text/template" - "github.com/rs/zerolog/log" - "github.com/gin-gonic/gin" "github.com/gofrs/uuid" + "github.com/rs/zerolog/log" ) // AppleMobileConfig shows a simple message in the browser to point to the CLI -// Listens in /register -func (h *Headscale) AppleMobileConfig(c *gin.Context) { - t := template.Must(template.New("apple").Parse(` +// Listens in /register. +func (h *Headscale) AppleMobileConfig(ctx *gin.Context) { + appleTemplate := template.Must(template.New("apple").Parse(`

Apple configuration profiles

@@ -56,7 +55,7 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {

Or

Use your terminal to configure the default setting for Tailscale by issuing:

- defaults write io.tailscale.ipn.macos ControlURL {{.Url}} + defaults write io.tailscale.ipn.macos ControlURL {{.URL}}

Restart Tailscale.app and log in.

@@ -64,24 +63,29 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) { `)) config := map[string]interface{}{ - "Url": h.cfg.ServerURL, + "URL": h.cfg.ServerURL, } var payload bytes.Buffer - if err := t.Execute(&payload, config); err != nil { + if err := appleTemplate.Execute(&payload, config); err != nil { log.Error(). Str("handler", "AppleMobileConfig"). Err(err). Msg("Could not render Apple index template") - c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple index template")) + ctx.Data( + http.StatusInternalServerError, + "text/html; charset=utf-8", + []byte("Could not render Apple index template"), + ) + return } - c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes()) + ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes()) } -func (h *Headscale) ApplePlatformConfig(c *gin.Context) { - platform := c.Param("platform") +func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) { + platform := ctx.Param("platform") id, err := uuid.NewV4() if err != nil { @@ -89,23 +93,33 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) { Str("handler", "ApplePlatformConfig"). Err(err). Msg("Failed not create UUID") - c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Failed to create UUID")) + ctx.Data( + http.StatusInternalServerError, + "text/html; charset=utf-8", + []byte("Failed to create UUID"), + ) + return } - contentId, err := uuid.NewV4() + contentID, err := uuid.NewV4() if err != nil { log.Error(). Str("handler", "ApplePlatformConfig"). Err(err). Msg("Failed not create UUID") - c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Failed to create UUID")) + ctx.Data( + http.StatusInternalServerError, + "text/html; charset=utf-8", + []byte("Failed to create UUID"), + ) + return } platformConfig := AppleMobilePlatformConfig{ - UUID: contentId, - Url: h.cfg.ServerURL, + UUID: contentID, + URL: h.cfg.ServerURL, } var payload bytes.Buffer @@ -117,7 +131,12 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) { Str("handler", "ApplePlatformConfig"). Err(err). Msg("Could not render Apple macOS template") - c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple macOS template")) + ctx.Data( + http.StatusInternalServerError, + "text/html; charset=utf-8", + []byte("Could not render Apple macOS template"), + ) + return } case "ios": @@ -126,17 +145,27 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) { Str("handler", "ApplePlatformConfig"). Err(err). Msg("Could not render Apple iOS template") - c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple iOS template")) + ctx.Data( + http.StatusInternalServerError, + "text/html; charset=utf-8", + []byte("Could not render Apple iOS template"), + ) + return } default: - c.Data(http.StatusOK, "text/html; charset=utf-8", []byte("Invalid platform, only ios and macos is supported")) + ctx.Data( + http.StatusOK, + "text/html; charset=utf-8", + []byte("Invalid platform, only ios and macos is supported"), + ) + return } config := AppleMobileConfig{ UUID: id, - Url: h.cfg.ServerURL, + URL: h.cfg.ServerURL, Payload: payload.String(), } @@ -146,25 +175,35 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) { Str("handler", "ApplePlatformConfig"). Err(err). Msg("Could not render Apple platform template") - c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple platform template")) + ctx.Data( + http.StatusInternalServerError, + "text/html; charset=utf-8", + []byte("Could not render Apple platform template"), + ) + return } - c.Data(http.StatusOK, "application/x-apple-aspen-config; charset=utf-8", content.Bytes()) + ctx.Data( + http.StatusOK, + "application/x-apple-aspen-config; charset=utf-8", + content.Bytes(), + ) } type AppleMobileConfig struct { UUID uuid.UUID - Url string + URL string Payload string } type AppleMobilePlatformConfig struct { UUID uuid.UUID - Url string + URL string } -var commonTemplate = template.Must(template.New("mobileconfig").Parse(` +var commonTemplate = template.Must( + template.New("mobileconfig").Parse(` @@ -173,7 +212,7 @@ var commonTemplate = template.Must(template.New("mobileconfig").Parse(`PayloadDisplayName Headscale PayloadDescription - Configure Tailscale login server to: {{.Url}} + Configure Tailscale login server to: {{.URL}} PayloadIdentifier com.github.juanfont.headscale PayloadRemovalDisallowed @@ -187,7 +226,8 @@ var commonTemplate = template.Must(template.New("mobileconfig").Parse(` -`)) +`), +) var iosTemplate = template.Must(template.New("iosTemplate").Parse(` @@ -203,7 +243,7 @@ var iosTemplate = template.Must(template.New("iosTemplate").Parse(` ControlURL - {{.Url}} + {{.URL}} `)) @@ -221,6 +261,6 @@ var macosTemplate = template.Must(template.New("macosTemplate").Parse(` ControlURL - {{.Url}} + {{.URL}} `)) diff --git a/cli_test.go b/cli_test.go index 291b5df..44ef9f0 100644 --- a/cli_test.go +++ b/cli_test.go @@ -7,31 +7,34 @@ import ( ) func (s *Suite) TestRegisterMachine(c *check.C) { - n, err := h.CreateNamespace("test") + namespace, err := app.CreateNamespace("test") c.Assert(err, check.IsNil) now := time.Now().UTC() - m := Machine{ + machine := Machine{ ID: 0, MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", NodeKey: "bar", DiscoKey: "faa", Name: "testmachine", - NamespaceID: n.ID, + NamespaceID: namespace.ID, IPAddress: "10.0.0.1", Expiry: &now, RequestedExpiry: &now, } - h.db.Save(&m) + app.db.Save(&machine) - _, err = h.GetMachine("test", "testmachine") + _, err = app.GetMachine("test", "testmachine") c.Assert(err, check.IsNil) - m2, err := h.RegisterMachine("8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", n.Name) + machineAfterRegistering, err := app.RegisterMachine( + "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", + namespace.Name, + ) c.Assert(err, check.IsNil) - c.Assert(m2.Registered, check.Equals, true) + c.Assert(machineAfterRegistering.Registered, check.Equals, true) - _, err = m2.GetHostInfo() + _, err = machineAfterRegistering.GetHostInfo() c.Assert(err, check.IsNil) } diff --git a/cmd/headscale/cli/debug.go b/cmd/headscale/cli/debug.go index e140156..46bdb9e 100644 --- a/cmd/headscale/cli/debug.go +++ b/cmd/headscale/cli/debug.go @@ -27,7 +27,8 @@ func init() { if err != nil { log.Fatal().Err(err).Msg("") } - createNodeCmd.Flags().StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise") + createNodeCmd.Flags(). + StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise") debugCmd.AddCommand(createNodeCmd) } @@ -47,6 +48,7 @@ var createNodeCmd = &cobra.Command{ namespace, err := cmd.Flags().GetString("namespace") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) + return } @@ -56,19 +58,34 @@ var createNodeCmd = &cobra.Command{ name, err := cmd.Flags().GetString("name") if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting node from flag: %s", err), output) + ErrorOutput( + err, + fmt.Sprintf("Error getting node from flag: %s", err), + output, + ) + return } machineKey, err := cmd.Flags().GetString("key") if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting key from flag: %s", err), output) + ErrorOutput( + err, + fmt.Sprintf("Error getting key from flag: %s", err), + output, + ) + return } routes, err := cmd.Flags().GetStringSlice("route") if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting routes from flag: %s", err), output) + ErrorOutput( + err, + fmt.Sprintf("Error getting routes from flag: %s", err), + output, + ) + return } @@ -81,7 +98,12 @@ var createNodeCmd = &cobra.Command{ response, err := client.DebugCreateMachine(ctx, request) if err != nil { - ErrorOutput(err, fmt.Sprintf("Cannot create machine: %s", status.Convert(err).Message()), output) + ErrorOutput( + err, + fmt.Sprintf("Cannot create machine: %s", status.Convert(err).Message()), + output, + ) + return } diff --git a/cmd/headscale/cli/namespaces.go b/cmd/headscale/cli/namespaces.go index 0e5eeb4..361e9be 100644 --- a/cmd/headscale/cli/namespaces.go +++ b/cmd/headscale/cli/namespaces.go @@ -4,6 +4,7 @@ import ( "fmt" survey "github.com/AlecAivazis/survey/v2" + "github.com/juanfont/headscale" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/pterm/pterm" "github.com/rs/zerolog/log" @@ -19,6 +20,10 @@ func init() { namespaceCmd.AddCommand(renameNamespaceCmd) } +const ( + errMissingParameter = headscale.Error("missing parameters") +) + var namespaceCmd = &cobra.Command{ Use: "namespaces", Short: "Manage the namespaces of Headscale", @@ -29,8 +34,9 @@ var createNamespaceCmd = &cobra.Command{ Short: "Creates a new namespace", Args: func(cmd *cobra.Command, args []string) error { if len(args) < 1 { - return fmt.Errorf("Missing parameters") + return errMissingParameter } + return nil }, Run: func(cmd *cobra.Command, args []string) { @@ -49,7 +55,15 @@ var createNamespaceCmd = &cobra.Command{ log.Trace().Interface("request", request).Msg("Sending CreateNamespace request") response, err := client.CreateNamespace(ctx, request) if err != nil { - ErrorOutput(err, fmt.Sprintf("Cannot create namespace: %s", status.Convert(err).Message()), output) + ErrorOutput( + err, + fmt.Sprintf( + "Cannot create namespace: %s", + status.Convert(err).Message(), + ), + output, + ) + return } @@ -62,8 +76,9 @@ var destroyNamespaceCmd = &cobra.Command{ Short: "Destroys a namespace", Args: func(cmd *cobra.Command, args []string) error { if len(args) < 1 { - return fmt.Errorf("Missing parameters") + return errMissingParameter } + return nil }, Run: func(cmd *cobra.Command, args []string) { @@ -81,7 +96,12 @@ var destroyNamespaceCmd = &cobra.Command{ _, err := client.GetNamespace(ctx, request) if err != nil { - ErrorOutput(err, fmt.Sprintf("Error: %s", status.Convert(err).Message()), output) + ErrorOutput( + err, + fmt.Sprintf("Error: %s", status.Convert(err).Message()), + output, + ) + return } @@ -89,7 +109,10 @@ var destroyNamespaceCmd = &cobra.Command{ force, _ := cmd.Flags().GetBool("force") if !force { prompt := &survey.Confirm{ - Message: fmt.Sprintf("Do you want to remove the namespace '%s' and any associated preauthkeys?", namespaceName), + Message: fmt.Sprintf( + "Do you want to remove the namespace '%s' and any associated preauthkeys?", + namespaceName, + ), } err := survey.AskOne(prompt, &confirm) if err != nil { @@ -102,7 +125,15 @@ var destroyNamespaceCmd = &cobra.Command{ response, err := client.DeleteNamespace(ctx, request) if err != nil { - ErrorOutput(err, fmt.Sprintf("Cannot destroy namespace: %s", status.Convert(err).Message()), output) + ErrorOutput( + err, + fmt.Sprintf( + "Cannot destroy namespace: %s", + status.Convert(err).Message(), + ), + output, + ) + return } SuccessOutput(response, "Namespace destroyed", output) @@ -126,19 +157,25 @@ var listNamespacesCmd = &cobra.Command{ response, err := client.ListNamespaces(ctx, request) if err != nil { - ErrorOutput(err, fmt.Sprintf("Cannot get namespaces: %s", status.Convert(err).Message()), output) + ErrorOutput( + err, + fmt.Sprintf("Cannot get namespaces: %s", status.Convert(err).Message()), + output, + ) + return } if output != "" { SuccessOutput(response.Namespaces, "", output) + return } - d := pterm.TableData{{"ID", "Name", "Created"}} + tableData := pterm.TableData{{"ID", "Name", "Created"}} for _, namespace := range response.GetNamespaces() { - d = append( - d, + tableData = append( + tableData, []string{ namespace.GetId(), namespace.GetName(), @@ -146,9 +183,14 @@ var listNamespacesCmd = &cobra.Command{ }, ) } - err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() + err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() if err != nil { - ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output) + ErrorOutput( + err, + fmt.Sprintf("Failed to render pterm table: %s", err), + output, + ) + return } }, @@ -158,9 +200,11 @@ var renameNamespaceCmd = &cobra.Command{ Use: "rename OLD_NAME NEW_NAME", Short: "Renames a namespace", Args: func(cmd *cobra.Command, args []string) error { - if len(args) < 2 { - return fmt.Errorf("Missing parameters") + expectedArguments := 2 + if len(args) < expectedArguments { + return errMissingParameter } + return nil }, Run: func(cmd *cobra.Command, args []string) { @@ -177,7 +221,15 @@ var renameNamespaceCmd = &cobra.Command{ response, err := client.RenameNamespace(ctx, request) if err != nil { - ErrorOutput(err, fmt.Sprintf("Cannot rename namespace: %s", status.Convert(err).Message()), output) + ErrorOutput( + err, + fmt.Sprintf( + "Cannot rename namespace: %s", + status.Convert(err).Message(), + ), + output, + ) + return } diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 9b3f424..218d25b 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -7,6 +7,7 @@ import ( "time" survey "github.com/AlecAivazis/survey/v2" + "github.com/juanfont/headscale" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/pterm/pterm" "github.com/spf13/cobra" @@ -77,6 +78,7 @@ var registerNodeCmd = &cobra.Command{ namespace, err := cmd.Flags().GetString("namespace") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) + return } @@ -86,7 +88,12 @@ var registerNodeCmd = &cobra.Command{ machineKey, err := cmd.Flags().GetString("key") if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting machine key from flag: %s", err), output) + ErrorOutput( + err, + fmt.Sprintf("Error getting machine key from flag: %s", err), + output, + ) + return } @@ -97,7 +104,15 @@ var registerNodeCmd = &cobra.Command{ response, err := client.RegisterMachine(ctx, request) if err != nil { - ErrorOutput(err, fmt.Sprintf("Cannot register machine: %s\n", status.Convert(err).Message()), output) + ErrorOutput( + err, + fmt.Sprintf( + "Cannot register machine: %s\n", + status.Convert(err).Message(), + ), + output, + ) + return } @@ -113,6 +128,7 @@ var listNodesCmd = &cobra.Command{ namespace, err := cmd.Flags().GetString("namespace") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) + return } @@ -126,24 +142,36 @@ var listNodesCmd = &cobra.Command{ response, err := client.ListMachines(ctx, request) if err != nil { - ErrorOutput(err, fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), output) + ErrorOutput( + err, + fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), + output, + ) + return } if output != "" { SuccessOutput(response.Machines, "", output) + return } - d, err := nodesToPtables(namespace, response.Machines) + tableData, err := nodesToPtables(namespace, response.Machines) if err != nil { ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) + return } - err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() + err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() if err != nil { - ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output) + ErrorOutput( + err, + fmt.Sprintf("Failed to render pterm table: %s", err), + output, + ) + return } }, @@ -155,9 +183,14 @@ var deleteNodeCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - id, err := cmd.Flags().GetInt("identifier") + identifier, err := cmd.Flags().GetInt("identifier") if err != nil { - ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output) + ErrorOutput( + err, + fmt.Sprintf("Error converting ID to integer: %s", err), + output, + ) + return } @@ -166,24 +199,35 @@ var deleteNodeCmd = &cobra.Command{ defer conn.Close() getRequest := &v1.GetMachineRequest{ - MachineId: uint64(id), + MachineId: uint64(identifier), } getResponse, err := client.GetMachine(ctx, getRequest) if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output) + ErrorOutput( + err, + fmt.Sprintf( + "Error getting node node: %s", + status.Convert(err).Message(), + ), + output, + ) + return } deleteRequest := &v1.DeleteMachineRequest{ - MachineId: uint64(id), + MachineId: uint64(identifier), } confirm := false force, _ := cmd.Flags().GetBool("force") if !force { prompt := &survey.Confirm{ - Message: fmt.Sprintf("Do you want to remove the node %s?", getResponse.GetMachine().Name), + Message: fmt.Sprintf( + "Do you want to remove the node %s?", + getResponse.GetMachine().Name, + ), } err = survey.AskOne(prompt, &confirm) if err != nil { @@ -195,13 +239,26 @@ var deleteNodeCmd = &cobra.Command{ response, err := client.DeleteMachine(ctx, deleteRequest) if output != "" { SuccessOutput(response, "", output) + return } if err != nil { - ErrorOutput(err, fmt.Sprintf("Error deleting node: %s", status.Convert(err).Message()), output) + ErrorOutput( + err, + fmt.Sprintf( + "Error deleting node: %s", + status.Convert(err).Message(), + ), + output, + ) + return } - SuccessOutput(map[string]string{"Result": "Node deleted"}, "Node deleted", output) + SuccessOutput( + map[string]string{"Result": "Node deleted"}, + "Node deleted", + output, + ) } else { SuccessOutput(map[string]string{"Result": "Node not deleted"}, "Node not deleted", output) } @@ -210,12 +267,12 @@ var deleteNodeCmd = &cobra.Command{ func sharingWorker( cmd *cobra.Command, - args []string, ) (string, *v1.Machine, *v1.Namespace, error) { output, _ := cmd.Flags().GetString("output") namespaceStr, err := cmd.Flags().GetString("namespace") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) + return "", nil, nil, err } @@ -223,19 +280,25 @@ func sharingWorker( defer cancel() defer conn.Close() - id, err := cmd.Flags().GetInt("identifier") + identifier, err := cmd.Flags().GetInt("identifier") if err != nil { ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output) + return "", nil, nil, err } machineRequest := &v1.GetMachineRequest{ - MachineId: uint64(id), + MachineId: uint64(identifier), } machineResponse, err := client.GetMachine(ctx, machineRequest) if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output) + ErrorOutput( + err, + fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), + output, + ) + return "", nil, nil, err } @@ -245,7 +308,12 @@ func sharingWorker( namespaceResponse, err := client.GetNamespace(ctx, namespaceRequest) if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output) + ErrorOutput( + err, + fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), + output, + ) + return "", nil, nil, err } @@ -256,9 +324,14 @@ var shareMachineCmd = &cobra.Command{ Use: "share", Short: "Shares a node from the current namespace to the specified one", Run: func(cmd *cobra.Command, args []string) { - output, machine, namespace, err := sharingWorker(cmd, args) + output, machine, namespace, err := sharingWorker(cmd) if err != nil { - ErrorOutput(err, fmt.Sprintf("Failed to fetch namespace or machine: %s", err), output) + ErrorOutput( + err, + fmt.Sprintf("Failed to fetch namespace or machine: %s", err), + output, + ) + return } @@ -273,7 +346,12 @@ var shareMachineCmd = &cobra.Command{ response, err := client.ShareMachine(ctx, request) if err != nil { - ErrorOutput(err, fmt.Sprintf("Error sharing node: %s", status.Convert(err).Message()), output) + ErrorOutput( + err, + fmt.Sprintf("Error sharing node: %s", status.Convert(err).Message()), + output, + ) + return } @@ -285,9 +363,14 @@ var unshareMachineCmd = &cobra.Command{ Use: "unshare", Short: "Unshares a node from the specified namespace", Run: func(cmd *cobra.Command, args []string) { - output, machine, namespace, err := sharingWorker(cmd, args) + output, machine, namespace, err := sharingWorker(cmd) if err != nil { - ErrorOutput(err, fmt.Sprintf("Failed to fetch namespace or machine: %s", err), output) + ErrorOutput( + err, + fmt.Sprintf("Failed to fetch namespace or machine: %s", err), + output, + ) + return } @@ -302,7 +385,12 @@ var unshareMachineCmd = &cobra.Command{ response, err := client.UnshareMachine(ctx, request) if err != nil { - ErrorOutput(err, fmt.Sprintf("Error unsharing node: %s", status.Convert(err).Message()), output) + ErrorOutput( + err, + fmt.Sprintf("Error unsharing node: %s", status.Convert(err).Message()), + output, + ) + return } @@ -310,8 +398,22 @@ var unshareMachineCmd = &cobra.Command{ }, } -func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.TableData, error) { - d := pterm.TableData{{"ID", "Name", "NodeKey", "Namespace", "IP address", "Ephemeral", "Last seen", "Online"}} +func nodesToPtables( + currentNamespace string, + machines []*v1.Machine, +) (pterm.TableData, error) { + tableData := pterm.TableData{ + { + "ID", + "Name", + "NodeKey", + "Namespace", + "IP address", + "Ephemeral", + "Last seen", + "Online", + }, + } for _, machine := range machines { var ephemeral bool @@ -331,7 +433,9 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl nodeKey := tailcfg.NodeKey(nKey) var online string - if lastSeen.After(time.Now().Add(-5 * time.Minute)) { // TODO: Find a better way to reliably show if online + if lastSeen.After( + time.Now().Add(-5 * time.Minute), + ) { // TODO: Find a better way to reliably show if online online = pterm.LightGreen("true") } else { online = pterm.LightRed("false") @@ -344,10 +448,10 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl // Shared into this namespace namespace = pterm.LightYellow(machine.Namespace.Name) } - d = append( - d, + tableData = append( + tableData, []string{ - strconv.FormatUint(machine.Id, 10), + strconv.FormatUint(machine.Id, headscale.Base10), machine.Name, nodeKey.ShortString(), namespace, @@ -358,5 +462,6 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl }, ) } - return d, nil + + return tableData, nil } diff --git a/cmd/headscale/cli/preauthkeys.go b/cmd/headscale/cli/preauthkeys.go index a07e961..9d5c838 100644 --- a/cmd/headscale/cli/preauthkeys.go +++ b/cmd/headscale/cli/preauthkeys.go @@ -12,6 +12,10 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) +const ( + DefaultPreAuthKeyExpiry = 24 * time.Hour +) + func init() { rootCmd.AddCommand(preauthkeysCmd) preauthkeysCmd.PersistentFlags().StringP("namespace", "n", "", "Namespace") @@ -22,10 +26,12 @@ func init() { preauthkeysCmd.AddCommand(listPreAuthKeys) preauthkeysCmd.AddCommand(createPreAuthKeyCmd) preauthkeysCmd.AddCommand(expirePreAuthKeyCmd) - createPreAuthKeyCmd.PersistentFlags().Bool("reusable", false, "Make the preauthkey reusable") - createPreAuthKeyCmd.PersistentFlags().Bool("ephemeral", false, "Preauthkey for ephemeral nodes") + createPreAuthKeyCmd.PersistentFlags(). + Bool("reusable", false, "Make the preauthkey reusable") + createPreAuthKeyCmd.PersistentFlags(). + Bool("ephemeral", false, "Preauthkey for ephemeral nodes") createPreAuthKeyCmd.Flags(). - DurationP("expiration", "e", 24*time.Hour, "Human-readable expiration of the key (30m, 24h, 365d...)") + DurationP("expiration", "e", DefaultPreAuthKeyExpiry, "Human-readable expiration of the key (30m, 24h, 365d...)") } var preauthkeysCmd = &cobra.Command{ @@ -39,9 +45,10 @@ var listPreAuthKeys = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - n, err := cmd.Flags().GetString("namespace") + namespace, err := cmd.Flags().GetString("namespace") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) + return } @@ -50,48 +57,61 @@ var listPreAuthKeys = &cobra.Command{ defer conn.Close() request := &v1.ListPreAuthKeysRequest{ - Namespace: n, + Namespace: namespace, } response, err := client.ListPreAuthKeys(ctx, request) if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting the list of keys: %s", err), output) + ErrorOutput( + err, + fmt.Sprintf("Error getting the list of keys: %s", err), + output, + ) + return } if output != "" { SuccessOutput(response.PreAuthKeys, "", output) + return } - d := pterm.TableData{{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"}} - for _, k := range response.PreAuthKeys { + tableData := pterm.TableData{ + {"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"}, + } + for _, key := range response.PreAuthKeys { expiration := "-" - if k.GetExpiration() != nil { - expiration = k.Expiration.AsTime().Format("2006-01-02 15:04:05") + if key.GetExpiration() != nil { + expiration = key.Expiration.AsTime().Format("2006-01-02 15:04:05") } var reusable string - if k.GetEphemeral() { + if key.GetEphemeral() { reusable = "N/A" } else { - reusable = fmt.Sprintf("%v", k.GetReusable()) + reusable = fmt.Sprintf("%v", key.GetReusable()) } - d = append(d, []string{ - k.GetId(), - k.GetKey(), + tableData = append(tableData, []string{ + key.GetId(), + key.GetKey(), reusable, - strconv.FormatBool(k.GetEphemeral()), - strconv.FormatBool(k.GetUsed()), + strconv.FormatBool(key.GetEphemeral()), + strconv.FormatBool(key.GetUsed()), expiration, - k.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"), + key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"), }) } - err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() + err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() if err != nil { - ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output) + ErrorOutput( + err, + fmt.Sprintf("Failed to render pterm table: %s", err), + output, + ) + return } }, @@ -106,6 +126,7 @@ var createPreAuthKeyCmd = &cobra.Command{ namespace, err := cmd.Flags().GetString("namespace") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) + return } @@ -139,7 +160,12 @@ var createPreAuthKeyCmd = &cobra.Command{ response, err := client.CreatePreAuthKey(ctx, request) if err != nil { - ErrorOutput(err, fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err), output) + ErrorOutput( + err, + fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err), + output, + ) + return } @@ -152,8 +178,9 @@ var expirePreAuthKeyCmd = &cobra.Command{ Short: "Expire a preauthkey", Args: func(cmd *cobra.Command, args []string) error { if len(args) < 1 { - return fmt.Errorf("missing parameters") + return errMissingParameter } + return nil }, Run: func(cmd *cobra.Command, args []string) { @@ -161,6 +188,7 @@ var expirePreAuthKeyCmd = &cobra.Command{ namespace, err := cmd.Flags().GetString("namespace") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output) + return } @@ -175,7 +203,12 @@ var expirePreAuthKeyCmd = &cobra.Command{ response, err := client.ExpirePreAuthKey(ctx, request) if err != nil { - ErrorOutput(err, fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err), output) + ErrorOutput( + err, + fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err), + output, + ) + return } diff --git a/cmd/headscale/cli/root.go b/cmd/headscale/cli/root.go index e5115f3..99b1514 100644 --- a/cmd/headscale/cli/root.go +++ b/cmd/headscale/cli/root.go @@ -10,7 +10,8 @@ import ( func init() { rootCmd.PersistentFlags(). StringP("output", "o", "", "Output format. Empty for human-readable, 'json', 'json-line' or 'yaml'") - rootCmd.PersistentFlags().Bool("force", false, "Disable prompts and forces the execution") + rootCmd.PersistentFlags(). + Bool("force", false, "Disable prompts and forces the execution") } var rootCmd = &cobra.Command{ diff --git a/cmd/headscale/cli/routes.go b/cmd/headscale/cli/routes.go index 9f2e4e2..ced1a0b 100644 --- a/cmd/headscale/cli/routes.go +++ b/cmd/headscale/cli/routes.go @@ -21,7 +21,8 @@ func init() { } routesCmd.AddCommand(listRoutesCmd) - enableRouteCmd.Flags().StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to enable") + enableRouteCmd.Flags(). + StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to enable") enableRouteCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") err = enableRouteCmd.MarkFlagRequired("identifier") if err != nil { @@ -44,9 +45,14 @@ var listRoutesCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - machineId, err := cmd.Flags().GetUint64("identifier") + machineID, err := cmd.Flags().GetUint64("identifier") if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting machine id from flag: %s", err), output) + ErrorOutput( + err, + fmt.Sprintf("Error getting machine id from flag: %s", err), + output, + ) + return } @@ -55,29 +61,41 @@ var listRoutesCmd = &cobra.Command{ defer conn.Close() request := &v1.GetMachineRouteRequest{ - MachineId: machineId, + MachineId: machineID, } response, err := client.GetMachineRoute(ctx, request) if err != nil { - ErrorOutput(err, fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), output) + ErrorOutput( + err, + fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), + output, + ) + return } if output != "" { SuccessOutput(response.Routes, "", output) + return } - d := routesToPtables(response.Routes) + tableData := routesToPtables(response.Routes) if err != nil { ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) + return } - err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() + err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() if err != nil { - ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output) + ErrorOutput( + err, + fmt.Sprintf("Failed to render pterm table: %s", err), + output, + ) + return } }, @@ -93,15 +111,26 @@ omit the route you do not want to enable. `, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - machineId, err := cmd.Flags().GetUint64("identifier") + + machineID, err := cmd.Flags().GetUint64("identifier") if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting machine id from flag: %s", err), output) + ErrorOutput( + err, + fmt.Sprintf("Error getting machine id from flag: %s", err), + output, + ) + return } routes, err := cmd.Flags().GetStringSlice("route") if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting routes from flag: %s", err), output) + ErrorOutput( + err, + fmt.Sprintf("Error getting routes from flag: %s", err), + output, + ) + return } @@ -110,45 +139,61 @@ omit the route you do not want to enable. defer conn.Close() request := &v1.EnableMachineRoutesRequest{ - MachineId: machineId, + MachineId: machineID, Routes: routes, } response, err := client.EnableMachineRoutes(ctx, request) if err != nil { - ErrorOutput(err, fmt.Sprintf("Cannot register machine: %s\n", status.Convert(err).Message()), output) + ErrorOutput( + err, + fmt.Sprintf( + "Cannot register machine: %s\n", + status.Convert(err).Message(), + ), + output, + ) + return } if output != "" { SuccessOutput(response.Routes, "", output) + return } - d := routesToPtables(response.Routes) + tableData := routesToPtables(response.Routes) if err != nil { ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) + return } - err = pterm.DefaultTable.WithHasHeader().WithData(d).Render() + err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() if err != nil { - ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output) + ErrorOutput( + err, + fmt.Sprintf("Failed to render pterm table: %s", err), + output, + ) + return } }, } -// routesToPtables converts the list of routes to a nice table +// routesToPtables converts the list of routes to a nice table. func routesToPtables(routes *v1.Routes) pterm.TableData { - d := pterm.TableData{{"Route", "Enabled"}} + tableData := pterm.TableData{{"Route", "Enabled"}} for _, route := range routes.GetAdvertisedRoutes() { enabled := isStringInSlice(routes.EnabledRoutes, route) - d = append(d, []string{route, strconv.FormatBool(enabled)}) + tableData = append(tableData, []string{route, strconv.FormatBool(enabled)}) } - return d + + return tableData } func isStringInSlice(strs []string, s string) bool { diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 30f1a78..958fb89 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -52,9 +52,8 @@ func LoadConfig(path string) error { viper.SetDefault("cli.insecure", false) viper.SetDefault("cli.timeout", "5s") - err := viper.ReadInConfig() - if err != nil { - return fmt.Errorf("Fatal error reading config file: %s \n", err) + if err := viper.ReadInConfig(); err != nil { + return fmt.Errorf("fatal error reading config file: %w", err) } // Collect any validation errors and return them all at once @@ -82,6 +81,7 @@ func LoadConfig(path string) error { errorText += "Fatal config error: server_url must start with https:// or http://\n" } if errorText != "" { + //nolint return errors.New(strings.TrimSuffix(errorText, "\n")) } else { return nil @@ -149,9 +149,14 @@ func GetDNSConfig() (*tailcfg.DNSConfig, string) { if viper.IsSet("dns_config.restricted_nameservers") { if len(dnsConfig.Nameservers) > 0 { dnsConfig.Routes = make(map[string][]dnstype.Resolver) - restrictedDNS := viper.GetStringMapStringSlice("dns_config.restricted_nameservers") + restrictedDNS := viper.GetStringMapStringSlice( + "dns_config.restricted_nameservers", + ) for domain, restrictedNameservers := range restrictedDNS { - restrictedResolvers := make([]dnstype.Resolver, len(restrictedNameservers)) + restrictedResolvers := make( + []dnstype.Resolver, + len(restrictedNameservers), + ) for index, nameserverStr := range restrictedNameservers { nameserver, err := netaddr.ParseIP(nameserverStr) if err != nil { @@ -208,6 +213,7 @@ func absPath(path string) string { path = filepath.Join(dir, path) } } + return path } @@ -219,7 +225,9 @@ func getHeadscaleConfig() headscale.Config { "10h", ) // use 10h here because it is the length of a standard business day plus a small amount of leeway if viper.GetDuration("max_machine_registration_duration") >= time.Second { - maxMachineRegistrationDuration = viper.GetDuration("max_machine_registration_duration") + maxMachineRegistrationDuration = viper.GetDuration( + "max_machine_registration_duration", + ) } // defaultMachineRegistrationDuration is the default time assigned to a machine registration if one is not @@ -229,7 +237,9 @@ func getHeadscaleConfig() headscale.Config { "8h", ) // use 8h here because it's the length of a standard business day if viper.GetDuration("default_machine_registration_duration") >= time.Second { - defaultMachineRegistrationDuration = viper.GetDuration("default_machine_registration_duration") + defaultMachineRegistrationDuration = viper.GetDuration( + "default_machine_registration_duration", + ) } dnsConfig, baseDomain := GetDNSConfig() @@ -244,7 +254,9 @@ func getHeadscaleConfig() headscale.Config { DERP: derpConfig, - EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"), + EphemeralNodeInactivityTimeout: viper.GetDuration( + "ephemeral_node_inactivity_timeout", + ), DBtype: viper.GetString("db_type"), DBpath: absPath(viper.GetString("db_path")), @@ -254,9 +266,11 @@ func getHeadscaleConfig() headscale.Config { DBuser: viper.GetString("db_user"), DBpass: viper.GetString("db_pass"), - TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"), - TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"), - TLSLetsEncryptCacheDir: absPath(viper.GetString("tls_letsencrypt_cache_dir")), + TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"), + TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"), + TLSLetsEncryptCacheDir: absPath( + viper.GetString("tls_letsencrypt_cache_dir"), + ), TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"), TLSCertPath: absPath(viper.GetString("tls_cert_path")), @@ -292,11 +306,14 @@ func getHeadscaleApp() (*headscale.Headscale, error) { // to avoid races minInactivityTimeout, _ := time.ParseDuration("65s") if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout { + // TODO: Find a better way to return this text + //nolint err := fmt.Errorf( - "ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s\n", + "ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s", viper.GetString("ephemeral_node_inactivity_timeout"), minInactivityTimeout, ) + return nil, err } @@ -304,7 +321,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) { cfg.OIDC.MatchMap = loadOIDCMatchMap() - h, err := headscale.NewHeadscale(cfg) + app, err := headscale.NewHeadscale(cfg) if err != nil { return nil, err } @@ -313,7 +330,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) { if viper.GetString("acl_policy_path") != "" { aclPath := absPath(viper.GetString("acl_policy_path")) - err = h.LoadACLPolicy(aclPath) + err = app.LoadACLPolicy(aclPath) if err != nil { log.Error(). Str("path", aclPath). @@ -322,7 +339,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) { } } - return h, nil + return app, nil } func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) { @@ -342,7 +359,6 @@ func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc. // If the address is not set, we assume that we are on the server hosting headscale. if address == "" { - log.Debug(). Str("socket", cfg.UnixSocket). Msgf("HEADSCALE_CLI_ADDRESS environment is not set, connecting to unix socket.") @@ -402,10 +418,13 @@ func SuccessOutput(result interface{}, override string, outputFormat string) { log.Fatal().Err(err) } default: + //nolint fmt.Println(override) + return } + //nolint fmt.Println(string(j)) } @@ -423,6 +442,7 @@ func HasMachineOutputFlag() bool { return true } } + return false } @@ -431,7 +451,10 @@ type tokenAuth struct { } // Return value is mapped to request headers. -func (t tokenAuth) GetRequestMetadata(ctx context.Context, in ...string) (map[string]string, error) { +func (t tokenAuth) GetRequestMetadata( + ctx context.Context, + in ...string, +) (map[string]string, error) { return map[string]string{ "authorization": "Bearer " + t.token, }, nil diff --git a/cmd/headscale/headscale.go b/cmd/headscale/headscale.go index ed4644f..d6bf216 100644 --- a/cmd/headscale/headscale.go +++ b/cmd/headscale/headscale.go @@ -23,6 +23,8 @@ func main() { colors = true case termcolor.LevelBasic: colors = true + case termcolor.LevelNone: + colors = false default: // no color, return text as is. colors = false @@ -41,8 +43,7 @@ func main() { NoColor: !colors, }) - err := cli.LoadConfig("") - if err != nil { + if err := cli.LoadConfig(""); err != nil { log.Fatal().Err(err) } @@ -63,13 +64,15 @@ func main() { } if !viper.GetBool("disable_check_updates") && !machineOutput { - if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && cli.Version != "dev" { + if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && + cli.Version != "dev" { githubTag := &latest.GithubTag{ Owner: "juanfont", Repository: "headscale", } res, err := latest.Check(githubTag, cli.Version) if err == nil && res.Outdated { + //nolint fmt.Printf( "An updated version of Headscale has been found (%s vs. your current %s). Check it out https://github.com/juanfont/headscale/releases\n", res.Current, diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index e3a5713..1166d48 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -1,7 +1,6 @@ package main import ( - "fmt" "io/ioutil" "os" "path/filepath" @@ -40,7 +39,10 @@ func (*Suite) TestConfigLoading(c *check.C) { } // Symlink the example config file - err = os.Symlink(filepath.Clean(path+"/../../config-example.yaml"), filepath.Join(tmpDir, "config.yaml")) + err = os.Symlink( + filepath.Clean(path+"/../../config-example.yaml"), + filepath.Join(tmpDir, "config.yaml"), + ) if err != nil { c.Fatal(err) } @@ -74,7 +76,10 @@ func (*Suite) TestDNSConfigLoading(c *check.C) { } // Symlink the example config file - err = os.Symlink(filepath.Clean(path+"/../../config-example.yaml"), filepath.Join(tmpDir, "config.yaml")) + err = os.Symlink( + filepath.Clean(path+"/../../config-example.yaml"), + filepath.Join(tmpDir, "config.yaml"), + ) if err != nil { c.Fatal(err) } @@ -94,7 +99,7 @@ func (*Suite) TestDNSConfigLoading(c *check.C) { func writeConfig(c *check.C, tmpDir string, configYaml []byte) { // Populate a custom config file configFile := filepath.Join(tmpDir, "config.yaml") - err := ioutil.WriteFile(configFile, configYaml, 0o644) + err := ioutil.WriteFile(configFile, configYaml, 0o600) if err != nil { c.Fatalf("Couldn't write file %s", configFile) } @@ -106,7 +111,6 @@ func (*Suite) TestTLSConfigValidation(c *check.C) { c.Fatal(err) } // defer os.RemoveAll(tmpDir) - fmt.Println(tmpDir) configYaml := []byte( "---\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"\"\ntls_cert_path: \"abc.pem\"", @@ -128,8 +132,11 @@ func (*Suite) TestTLSConfigValidation(c *check.C) { check.Matches, ".*Fatal config error: the only supported values for tls_letsencrypt_challenge_type are.*", ) - c.Assert(tmp, check.Matches, ".*Fatal config error: server_url must start with https:// or http://.*") - fmt.Println(tmp) + c.Assert( + tmp, + check.Matches, + ".*Fatal config error: server_url must start with https:// or http://.*", + ) // Check configuration validation errors (2) configYaml = []byte( diff --git a/db.go b/db.go index 42c5eee..5136325 100644 --- a/db.go +++ b/db.go @@ -9,7 +9,10 @@ import ( "gorm.io/gorm/logger" ) -const dbVersion = "1" +const ( + dbVersion = "1" + errValueNotFound = Error("not found") +) // KV is a key-value store in a psql table. For future use... type KV struct { @@ -24,7 +27,7 @@ func (h *Headscale) initDB() error { } h.db = db - if h.dbType == "postgres" { + if h.dbType == Postgres { db.Exec("create extension if not exists \"uuid-ossp\";") } err = db.AutoMigrate(&Machine{}) @@ -50,6 +53,7 @@ func (h *Headscale) initDB() error { } err = h.setValue("db_version", dbVersion) + return err } @@ -65,12 +69,12 @@ func (h *Headscale) openDB() (*gorm.DB, error) { } switch h.dbType { - case "sqlite3": + case Sqlite: db, err = gorm.Open(sqlite.Open(h.dbString), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, Logger: log, }) - case "postgres": + case Postgres: db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, Logger: log, @@ -84,28 +88,33 @@ func (h *Headscale) openDB() (*gorm.DB, error) { return db, nil } -// getValue returns the value for the given key in KV +// getValue returns the value for the given key in KV. func (h *Headscale) getValue(key string) (string, error) { var row KV - if result := h.db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) { - return "", errors.New("not found") + if result := h.db.First(&row, "key = ?", key); errors.Is( + result.Error, + gorm.ErrRecordNotFound, + ) { + return "", errValueNotFound } + return row.Value, nil } -// setValue sets value for the given key in KV +// setValue sets value for the given key in KV. func (h *Headscale) setValue(key string, value string) error { - kv := KV{ + keyValue := KV{ Key: key, Value: value, } - _, err := h.getValue(key) - if err == nil { - h.db.Model(&kv).Where("key = ?", key).Update("value", value) + if _, err := h.getValue(key); err == nil { + h.db.Model(&keyValue).Where("key = ?", key).Update("value", value) + return nil } - h.db.Create(kv) + h.db.Create(keyValue) + return nil } diff --git a/derp-example.yaml b/derp-example.yaml index bbf7cc8..a9901be 100644 --- a/derp-example.yaml +++ b/derp-example.yaml @@ -1,15 +1,15 @@ # If you plan to somehow use headscale, please deploy your own DERP infra: https://tailscale.com/kb/1118/custom-derp-servers/ -regions: +regions: 900: regionid: 900 regioncode: custom regionname: My Region nodes: - - name: 1a - regionid: 1 - hostname: myderp.mydomain.no - ipv4: 123.123.123.123 - ipv6: "2604:a880:400:d1::828:b001" - stunport: 0 - stunonly: false - derptestport: 0 + - name: 1a + regionid: 1 + hostname: myderp.mydomain.no + ipv4: 123.123.123.123 + ipv6: "2604:a880:400:d1::828:b001" + stunport: 0 + stunonly: false + derptestport: 0 diff --git a/derp.go b/derp.go index 7f65832..63e448d 100644 --- a/derp.go +++ b/derp.go @@ -1,6 +1,7 @@ package headscale import ( + "context" "encoding/json" "io" "io/ioutil" @@ -10,9 +11,7 @@ import ( "time" "github.com/rs/zerolog/log" - "gopkg.in/yaml.v2" - "tailscale.com/tailcfg" ) @@ -28,14 +27,24 @@ func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) { return nil, err } err = yaml.Unmarshal(b, &derpMap) + return &derpMap, err } func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) { - client := http.Client{ - Timeout: 10 * time.Second, + ctx, cancel := context.WithTimeout(context.Background(), HTTPReadTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", addr.String(), nil) + if err != nil { + return nil, err } - resp, err := client.Get(addr.String()) + + client := http.Client{ + Timeout: HTTPReadTimeout, + } + + resp, err := client.Do(req) if err != nil { return nil, err } @@ -48,6 +57,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) { var derpMap tailcfg.DERPMap err = json.Unmarshal(body, &derpMap) + return &derpMap, err } @@ -55,7 +65,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) { // DERPMap, it will _only_ look at the Regions, an integer. // If a region exists in two of the given DERPMaps, the region // form the _last_ DERPMap will be preserved. -// An empty DERPMap list will result in a DERPMap with no regions +// An empty DERPMap list will result in a DERPMap with no regions. func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap { result := tailcfg.DERPMap{ OmitDefaultRegions: false, @@ -86,6 +96,7 @@ func GetDERPMap(cfg DERPConfig) *tailcfg.DERPMap { Str("path", path). Err(err). Msg("Could not load DERP map from path") + break } @@ -104,6 +115,7 @@ func GetDERPMap(cfg DERPConfig) *tailcfg.DERPMap { Str("url", addr.String()). Err(err). Msg("Could not load DERP map from path") + break } diff --git a/dns.go b/dns.go index c7ca32a..af6f989 100644 --- a/dns.go +++ b/dns.go @@ -10,6 +10,10 @@ import ( "tailscale.com/util/dnsname" ) +const ( + ByteSize = 8 +) + // generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`. // This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS // server (listening in 100.100.100.100 udp/53) should be used for. @@ -30,7 +34,9 @@ import ( // From the netmask we can find out the wildcard bits (the bits that are not set in the netmask). // This allows us to then calculate the subnets included in the subsequent class block and generate the entries. -func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) ([]dnsname.FQDN, error) { +func generateMagicDNSRootDomains( + ipPrefix netaddr.IPPrefix, +) []dnsname.FQDN { // TODO(juanfont): we are not handing out IPv6 addresses yet // and in fact this is Tailscale.com's range (note the fd7a:115c:a1e0: range in the fc00::/7 network) ipv6base := dnsname.FQDN("0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.") @@ -41,15 +47,15 @@ func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) ( maskBits, _ := netRange.Mask.Size() // lastOctet is the last IP byte covered by the mask - lastOctet := maskBits / 8 + lastOctet := maskBits / ByteSize // wildcardBits is the number of bits not under the mask in the lastOctet - wildcardBits := 8 - maskBits%8 + wildcardBits := ByteSize - maskBits%ByteSize // min is the value in the lastOctet byte of the IP // max is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1 min := uint(netRange.IP[lastOctet]) - max := uint((min + 1< config/private.key cp config.yaml.[sqlite|postgres].example config/config.yaml - + cp derp-example.yaml config/derp.yaml ``` @@ -81,16 +85,19 @@ -p 127.0.0.1:8080:8080 \ headscale/headscale:x.x.x headscale serve ``` + ## Nodes configuration If you used tailscale.com before in your nodes, make sure you clear the tailscaled data folder - ```shell - systemctl stop tailscaled - rm -fr /var/lib/tailscale - systemctl start tailscaled - ``` +```shell +systemctl stop tailscaled +rm -fr /var/lib/tailscale +systemctl start tailscaled +``` + ### Adding node based on MACHINEKEY + 1. Add your first machine ```shell diff --git a/grpcv1.go b/grpcv1.go index 998f0c0..40419c3 100644 --- a/grpcv1.go +++ b/grpcv1.go @@ -155,7 +155,10 @@ func (api headscaleV1APIServer) RegisterMachine( ctx context.Context, request *v1.RegisterMachineRequest, ) (*v1.RegisterMachineResponse, error) { - log.Trace().Str("namespace", request.GetNamespace()).Str("machine_key", request.GetKey()).Msg("Registering machine") + log.Trace(). + Str("namespace", request.GetNamespace()). + Str("machine_key", request.GetKey()). + Msg("Registering machine") machine, err := api.h.RegisterMachine( request.GetKey(), request.GetNamespace(), @@ -208,7 +211,9 @@ func (api headscaleV1APIServer) ListMachines( return nil, err } - sharedMachines, err := api.h.ListSharedMachinesInNamespace(request.GetNamespace()) + sharedMachines, err := api.h.ListSharedMachinesInNamespace( + request.GetNamespace(), + ) if err != nil { return nil, err } @@ -333,12 +338,16 @@ func (api headscaleV1APIServer) DebugCreateMachine( return nil, err } - routes, err := stringToIpPrefix(request.GetRoutes()) + routes, err := stringToIPPrefix(request.GetRoutes()) if err != nil { return nil, err } - log.Trace().Caller().Interface("route-prefix", routes).Interface("route-str", request.GetRoutes()).Msg("") + log.Trace(). + Caller(). + Interface("route-prefix", routes). + Interface("route-str", request.GetRoutes()). + Msg("") hostinfo := tailcfg.Hostinfo{ RoutableIPs: routes, diff --git a/integration_cli_test.go b/integration_cli_test.go index e1b1672..898e2cd 100644 --- a/integration_cli_test.go +++ b/integration_cli_test.go @@ -88,6 +88,7 @@ func (s *IntegrationCLITestSuite) SetupTest() { if resp.StatusCode != http.StatusOK { return fmt.Errorf("status code not OK") } + return nil }); err != nil { // TODO(kradalby): If we cannot access headscale, or any other fatal error during @@ -109,7 +110,10 @@ func (s *IntegrationCLITestSuite) TearDownTest() { } } -func (s *IntegrationCLITestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) { +func (s *IntegrationCLITestSuite) HandleStats( + suiteName string, + stats *suite.SuiteInformation, +) { s.stats = stats } @@ -144,7 +148,6 @@ func (s *IntegrationCLITestSuite) TestNamespaceCommand() { namespaces := make([]*v1.Namespace, len(names)) for index, namespaceName := range names { - namespace, err := s.createNamespace(namespaceName) assert.Nil(s.T(), err) @@ -298,11 +301,26 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() { assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().After(time.Now())) assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().After(time.Now())) - assert.True(s.T(), listedPreAuthKeys[0].Expiration.AsTime().Before(time.Now().Add(time.Hour*26))) - assert.True(s.T(), listedPreAuthKeys[1].Expiration.AsTime().Before(time.Now().Add(time.Hour*26))) - assert.True(s.T(), listedPreAuthKeys[2].Expiration.AsTime().Before(time.Now().Add(time.Hour*26))) - assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().Before(time.Now().Add(time.Hour*26))) - assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26))) + assert.True( + s.T(), + listedPreAuthKeys[0].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)), + ) + assert.True( + s.T(), + listedPreAuthKeys[1].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)), + ) + assert.True( + s.T(), + listedPreAuthKeys[2].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)), + ) + assert.True( + s.T(), + listedPreAuthKeys[3].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)), + ) + assert.True( + s.T(), + listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)), + ) // Expire three keys for i := 0; i < 3; i++ { @@ -341,11 +359,26 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() { err = json.Unmarshal([]byte(listAfterExpireResult), &listedAfterExpirePreAuthKeys) assert.Nil(s.T(), err) - assert.True(s.T(), listedAfterExpirePreAuthKeys[0].Expiration.AsTime().Before(time.Now())) - assert.True(s.T(), listedAfterExpirePreAuthKeys[1].Expiration.AsTime().Before(time.Now())) - assert.True(s.T(), listedAfterExpirePreAuthKeys[2].Expiration.AsTime().Before(time.Now())) - assert.True(s.T(), listedAfterExpirePreAuthKeys[3].Expiration.AsTime().After(time.Now())) - assert.True(s.T(), listedAfterExpirePreAuthKeys[4].Expiration.AsTime().After(time.Now())) + assert.True( + s.T(), + listedAfterExpirePreAuthKeys[0].Expiration.AsTime().Before(time.Now()), + ) + assert.True( + s.T(), + listedAfterExpirePreAuthKeys[1].Expiration.AsTime().Before(time.Now()), + ) + assert.True( + s.T(), + listedAfterExpirePreAuthKeys[2].Expiration.AsTime().Before(time.Now()), + ) + assert.True( + s.T(), + listedAfterExpirePreAuthKeys[3].Expiration.AsTime().After(time.Now()), + ) + assert.True( + s.T(), + listedAfterExpirePreAuthKeys[4].Expiration.AsTime().After(time.Now()), + ) } func (s *IntegrationCLITestSuite) TestPreAuthKeyCommandWithoutExpiry() { @@ -689,7 +722,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() { assert.Nil(s.T(), err) var listOnlySharedMachineNamespace []v1.Machine - err = json.Unmarshal([]byte(listOnlySharedMachineNamespaceResult), &listOnlySharedMachineNamespace) + err = json.Unmarshal( + []byte(listOnlySharedMachineNamespaceResult), + &listOnlySharedMachineNamespace, + ) assert.Nil(s.T(), err) assert.Len(s.T(), listOnlySharedMachineNamespace, 2) @@ -738,7 +774,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() { assert.Nil(s.T(), err) var listOnlyMachineNamespaceAfterDelete []v1.Machine - err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterDeleteResult), &listOnlyMachineNamespaceAfterDelete) + err = json.Unmarshal( + []byte(listOnlyMachineNamespaceAfterDeleteResult), + &listOnlyMachineNamespaceAfterDelete, + ) assert.Nil(s.T(), err) assert.Len(s.T(), listOnlyMachineNamespaceAfterDelete, 4) @@ -789,7 +828,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() { assert.Nil(s.T(), err) var listOnlyMachineNamespaceAfterShare []v1.Machine - err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterShareResult), &listOnlyMachineNamespaceAfterShare) + err = json.Unmarshal( + []byte(listOnlyMachineNamespaceAfterShareResult), + &listOnlyMachineNamespaceAfterShare, + ) assert.Nil(s.T(), err) assert.Len(s.T(), listOnlyMachineNamespaceAfterShare, 5) @@ -846,7 +888,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() { assert.Nil(s.T(), err) var listOnlyMachineNamespaceAfterUnshare []v1.Machine - err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterUnshareResult), &listOnlyMachineNamespaceAfterUnshare) + err = json.Unmarshal( + []byte(listOnlyMachineNamespaceAfterUnshareResult), + &listOnlyMachineNamespaceAfterUnshare, + ) assert.Nil(s.T(), err) assert.Len(s.T(), listOnlyMachineNamespaceAfterUnshare, 4) @@ -1010,5 +1055,9 @@ func (s *IntegrationCLITestSuite) TestRouteCommand() { ) assert.Nil(s.T(), err) - assert.Contains(s.T(), string(failEnableNonAdvertisedRoute), "route (route-machine) is not available on node") + assert.Contains( + s.T(), + string(failEnableNonAdvertisedRoute), + "route (route-machine) is not available on node", + ) } diff --git a/integration_common_test.go b/integration_common_test.go index 71d4866..31bae51 100644 --- a/integration_common_test.go +++ b/integration_common_test.go @@ -12,12 +12,18 @@ import ( "github.com/ory/dockertest/v3/docker" ) -func ExecuteCommand(resource *dockertest.Resource, cmd []string, env []string) (string, error) { +const DOCKER_EXECUTE_TIMEOUT = 10 * time.Second + +func ExecuteCommand( + resource *dockertest.Resource, + cmd []string, + env []string, +) (string, error) { var stdout bytes.Buffer var stderr bytes.Buffer // TODO(kradalby): Make configurable - timeout := 10 * time.Second + timeout := DOCKER_EXECUTE_TIMEOUT type result struct { exitCode int @@ -51,11 +57,13 @@ func ExecuteCommand(resource *dockertest.Resource, cmd []string, env []string) ( fmt.Println("Command: ", cmd) fmt.Println("stdout: ", stdout.String()) fmt.Println("stderr: ", stderr.String()) + return "", fmt.Errorf("command failed with: %s", stderr.String()) } return stdout.String(), nil case <-time.After(timeout): + return "", fmt.Errorf("command timed out after %s", timeout) } } diff --git a/integration_test.go b/integration_test.go index c4f6bb4..a9125a0 100644 --- a/integration_test.go +++ b/integration_test.go @@ -23,10 +23,9 @@ import ( "github.com/ory/dockertest/v3/docker" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "inet.af/netaddr" "tailscale.com/client/tailscale/apitype" "tailscale.com/ipn/ipnstate" - - "inet.af/netaddr" ) var tailscaleVersions = []string{"1.16.2", "1.14.3", "1.12.3"} @@ -89,7 +88,10 @@ func TestIntegrationTestSuite(t *testing.T) { } } -func (s *IntegrationTestSuite) saveLog(resource *dockertest.Resource, basePath string) error { +func (s *IntegrationTestSuite) saveLog( + resource *dockertest.Resource, + basePath string, +) error { err := os.MkdirAll(basePath, os.ModePerm) if err != nil { return err @@ -118,12 +120,20 @@ func (s *IntegrationTestSuite) saveLog(resource *dockertest.Resource, basePath s fmt.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath) - err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stdout.log"), []byte(stdout.String()), 0o644) + err = ioutil.WriteFile( + path.Join(basePath, resource.Container.Name+".stdout.log"), + []byte(stdout.String()), + 0o644, + ) if err != nil { return err } - err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stderr.log"), []byte(stdout.String()), 0o644) + err = ioutil.WriteFile( + path.Join(basePath, resource.Container.Name+".stderr.log"), + []byte(stdout.String()), + 0o644, + ) if err != nil { return err } @@ -144,24 +154,38 @@ func (s *IntegrationTestSuite) tailscaleContainer( }, }, } - hostname := fmt.Sprintf("%s-tailscale-%s-%s", namespace, strings.Replace(version, ".", "-", -1), identifier) + hostname := fmt.Sprintf( + "%s-tailscale-%s-%s", + namespace, + strings.Replace(version, ".", "-", -1), + identifier, + ) tailscaleOptions := &dockertest.RunOptions{ Name: hostname, Networks: []*dockertest.Network{&s.network}, - Cmd: []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"}, + Cmd: []string{ + "tailscaled", + "--tun=userspace-networking", + "--socks5-server=localhost:1055", + }, } - pts, err := s.pool.BuildAndRunWithBuildOptions(tailscaleBuildOptions, tailscaleOptions, DockerRestartPolicy) + pts, err := s.pool.BuildAndRunWithBuildOptions( + tailscaleBuildOptions, + tailscaleOptions, + DockerRestartPolicy, + ) if err != nil { log.Fatalf("Could not start resource: %s", err) } fmt.Printf("Created %s container\n", hostname) + return hostname, pts } func (s *IntegrationTestSuite) SetupSuite() { var err error - h = Headscale{ + app = Headscale{ dbType: "sqlite3", dbString: "integration_test_db.sqlite3", } @@ -210,7 +234,11 @@ func (s *IntegrationTestSuite) SetupSuite() { for i := 0; i < scales.count; i++ { version := tailscaleVersions[i%len(tailscaleVersions)] - hostname, container := s.tailscaleContainer(namespace, fmt.Sprint(i), version) + hostname, container := s.tailscaleContainer( + namespace, + fmt.Sprint(i), + version, + ) scales.tailscales[hostname] = *container } } @@ -220,13 +248,16 @@ func (s *IntegrationTestSuite) SetupSuite() { if err := s.pool.Retry(func() error { url := fmt.Sprintf("http://%s/health", hostEndpoint) + resp, err := http.Get(url) if err != nil { return err } + if resp.StatusCode != http.StatusOK { return fmt.Errorf("status code not OK") } + return nil }); err != nil { // TODO(kradalby): If we cannot access headscale, or any other fatal error during @@ -273,7 +304,10 @@ func (s *IntegrationTestSuite) SetupSuite() { headscaleEndpoint := "http://headscale:8080" - fmt.Printf("Joining tailscale containers to headscale at %s\n", headscaleEndpoint) + fmt.Printf( + "Joining tailscale containers to headscale at %s\n", + headscaleEndpoint, + ) for hostname, tailscale := range scales.tailscales { command := []string{ "tailscale", @@ -307,7 +341,10 @@ func (s *IntegrationTestSuite) SetupSuite() { func (s *IntegrationTestSuite) TearDownSuite() { } -func (s *IntegrationTestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) { +func (s *IntegrationTestSuite) HandleStats( + suiteName string, + stats *suite.SuiteInformation, +) { s.stats = stats } @@ -427,7 +464,13 @@ func (s *IntegrationTestSuite) TestPingAllPeers() { ip.String(), } - fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip) + fmt.Printf( + "Pinging from %s (%s) to %s (%s)\n", + hostname, + ips[hostname], + peername, + ip, + ) result, err := ExecuteCommand( &tailscale, command, @@ -449,7 +492,15 @@ func (s *IntegrationTestSuite) TestSharedNodes() { result, err := ExecuteCommand( &s.headscale, - []string{"headscale", "nodes", "list", "--output", "json", "--namespace", "shared"}, + []string{ + "headscale", + "nodes", + "list", + "--output", + "json", + "--namespace", + "shared", + }, []string{}, ) assert.Nil(s.T(), err) @@ -459,7 +510,6 @@ func (s *IntegrationTestSuite) TestSharedNodes() { assert.Nil(s.T(), err) for _, machine := range machineList { - result, err := ExecuteCommand( &s.headscale, []string{ @@ -520,7 +570,13 @@ func (s *IntegrationTestSuite) TestSharedNodes() { ip.String(), } - fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, mainIps[hostname], peername, ip) + fmt.Printf( + "Pinging from %s (%s) to %s (%s)\n", + hostname, + mainIps[hostname], + peername, + ip, + ) result, err := ExecuteCommand( &tailscale, command, @@ -553,7 +609,6 @@ func (s *IntegrationTestSuite) TestTailDrop() { for peername, ip := range ips { s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) { if peername != hostname { - // Under normal circumstances, we should be able to send a file // using `tailscale file cp` - but not in userspace networking mode // So curl! @@ -578,9 +633,19 @@ func (s *IntegrationTestSuite) TestTailDrop() { "PUT", "--upload-file", fmt.Sprintf("/tmp/file_from_%s", hostname), - fmt.Sprintf("%s/v0/put/file_from_%s", peerAPI, hostname), + fmt.Sprintf( + "%s/v0/put/file_from_%s", + peerAPI, + hostname, + ), } - fmt.Printf("Sending file from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip) + fmt.Printf( + "Sending file from %s (%s) to %s (%s)\n", + hostname, + ips[hostname], + peername, + ip, + ) _, err = ExecuteCommand( &tailscale, command, @@ -621,7 +686,13 @@ func (s *IntegrationTestSuite) TestTailDrop() { "ls", fmt.Sprintf("/tmp/file_from_%s", peername), } - fmt.Printf("Checking file in %s (%s) from %s (%s)\n", hostname, ips[hostname], peername, ip) + fmt.Printf( + "Checking file in %s (%s) from %s (%s)\n", + hostname, + ips[hostname], + peername, + ip, + ) result, err := ExecuteCommand( &tailscale, command, @@ -629,7 +700,11 @@ func (s *IntegrationTestSuite) TestTailDrop() { ) assert.Nil(t, err) fmt.Printf("Result for %s: %s\n", peername, result) - assert.Equal(t, result, fmt.Sprintf("/tmp/file_from_%s\n", peername)) + assert.Equal( + t, + result, + fmt.Sprintf("/tmp/file_from_%s\n", peername), + ) } }) } @@ -696,10 +771,13 @@ func getIPs(tailscales map[string]dockertest.Resource) (map[string]netaddr.IP, e ips[hostname] = ip } + return ips, nil } -func getAPIURLs(tailscales map[string]dockertest.Resource) (map[netaddr.IP]string, error) { +func getAPIURLs( + tailscales map[string]dockertest.Resource, +) (map[netaddr.IP]string, error) { fts := make(map[netaddr.IP]string) for _, tailscale := range tailscales { command := []string{ @@ -733,5 +811,6 @@ func getAPIURLs(tailscales map[string]dockertest.Resource) (map[netaddr.IP]strin } } } + return fts, nil } diff --git a/k8s/README.md b/k8s/README.md index 45574b4..78e9ef2 100644 --- a/k8s/README.md +++ b/k8s/README.md @@ -24,6 +24,7 @@ Configure DERP servers by editing `base/site/derp.yaml` if needed. You'll somehow need to get `headscale:latest` into your cluster image registry. An easy way to do this with k3s: + - Reconfigure k3s to use docker instead of containerd (`k3s server --docker`) - `docker build -t headscale:latest ..` from here @@ -61,7 +62,7 @@ Use the wrapper script to remotely operate headscale to perform administrative tasks like creating namespaces, authkeys, etc. ``` -[c@nix-slate:~/Projects/headscale/k8s]$ ./headscale.bash +[c@nix-slate:~/Projects/headscale/k8s]$ ./headscale.bash headscale is an open source implementation of the Tailscale control server diff --git a/k8s/base/ingress.yaml b/k8s/base/ingress.yaml index a279bc1..51da342 100644 --- a/k8s/base/ingress.yaml +++ b/k8s/base/ingress.yaml @@ -6,13 +6,13 @@ metadata: kubernetes.io/ingress.class: traefik spec: rules: - - host: $(PUBLIC_HOSTNAME) - http: - paths: - - backend: - service: - name: headscale - port: - number: 8080 - path: / - pathType: Prefix + - host: $(PUBLIC_HOSTNAME) + http: + paths: + - backend: + service: + name: headscale + port: + number: 8080 + path: / + pathType: Prefix diff --git a/k8s/base/kustomization.yaml b/k8s/base/kustomization.yaml index 54d66e5..93278f7 100644 --- a/k8s/base/kustomization.yaml +++ b/k8s/base/kustomization.yaml @@ -1,42 +1,42 @@ namespace: headscale resources: -- configmap.yaml -- ingress.yaml -- service.yaml + - configmap.yaml + - ingress.yaml + - service.yaml generatorOptions: disableNameSuffixHash: true configMapGenerator: -- name: headscale-site - files: - - derp.yaml=site/derp.yaml - envs: - - site/public.env -- name: headscale-etc - literals: - - config.json={} + - name: headscale-site + files: + - derp.yaml=site/derp.yaml + envs: + - site/public.env + - name: headscale-etc + literals: + - config.json={} secretGenerator: -- name: headscale - files: - - secrets/private-key + - name: headscale + files: + - secrets/private-key vars: -- name: PUBLIC_PROTO - objRef: - kind: ConfigMap - name: headscale-site - apiVersion: v1 - fieldRef: - fieldPath: data.public-proto -- name: PUBLIC_HOSTNAME - objRef: - kind: ConfigMap - name: headscale-site - apiVersion: v1 - fieldRef: - fieldPath: data.public-hostname -- name: CONTACT_EMAIL - objRef: - kind: ConfigMap - name: headscale-site - apiVersion: v1 - fieldRef: - fieldPath: data.contact-email + - name: PUBLIC_PROTO + objRef: + kind: ConfigMap + name: headscale-site + apiVersion: v1 + fieldRef: + fieldPath: data.public-proto + - name: PUBLIC_HOSTNAME + objRef: + kind: ConfigMap + name: headscale-site + apiVersion: v1 + fieldRef: + fieldPath: data.public-hostname + - name: CONTACT_EMAIL + objRef: + kind: ConfigMap + name: headscale-site + apiVersion: v1 + fieldRef: + fieldPath: data.contact-email diff --git a/k8s/base/service.yaml b/k8s/base/service.yaml index 7fdf738..39e6725 100644 --- a/k8s/base/service.yaml +++ b/k8s/base/service.yaml @@ -8,6 +8,6 @@ spec: selector: app: headscale ports: - - name: http - targetPort: http - port: 8080 + - name: http + targetPort: http + port: 8080 diff --git a/k8s/postgres/deployment.yaml b/k8s/postgres/deployment.yaml index dd45d05..661d87e 100644 --- a/k8s/postgres/deployment.yaml +++ b/k8s/postgres/deployment.yaml @@ -13,66 +13,66 @@ spec: app: headscale spec: containers: - - name: headscale - image: "headscale:latest" - imagePullPolicy: IfNotPresent - command: ["/go/bin/headscale", "serve"] - env: - - name: SERVER_URL - value: $(PUBLIC_PROTO)://$(PUBLIC_HOSTNAME) - - name: LISTEN_ADDR - valueFrom: - configMapKeyRef: - name: headscale-config - key: listen_addr - - name: PRIVATE_KEY_PATH - value: /vol/secret/private-key - - name: DERP_MAP_PATH - value: /vol/config/derp.yaml - - name: EPHEMERAL_NODE_INACTIVITY_TIMEOUT - valueFrom: - configMapKeyRef: - name: headscale-config - key: ephemeral_node_inactivity_timeout - - name: DB_TYPE - value: postgres - - name: DB_HOST - value: postgres.headscale.svc.cluster.local - - name: DB_PORT - value: "5432" - - name: DB_USER - value: headscale - - name: DB_PASS - valueFrom: - secretKeyRef: - name: postgresql - key: password - - name: DB_NAME - value: headscale - ports: - - name: http - protocol: TCP - containerPort: 8080 - livenessProbe: - tcpSocket: - port: http - initialDelaySeconds: 30 - timeoutSeconds: 5 - periodSeconds: 15 - volumeMounts: - - name: config - mountPath: /vol/config - - name: secret - mountPath: /vol/secret - - name: etc - mountPath: /etc/headscale + - name: headscale + image: "headscale:latest" + imagePullPolicy: IfNotPresent + command: ["/go/bin/headscale", "serve"] + env: + - name: SERVER_URL + value: $(PUBLIC_PROTO)://$(PUBLIC_HOSTNAME) + - name: LISTEN_ADDR + valueFrom: + configMapKeyRef: + name: headscale-config + key: listen_addr + - name: PRIVATE_KEY_PATH + value: /vol/secret/private-key + - name: DERP_MAP_PATH + value: /vol/config/derp.yaml + - name: EPHEMERAL_NODE_INACTIVITY_TIMEOUT + valueFrom: + configMapKeyRef: + name: headscale-config + key: ephemeral_node_inactivity_timeout + - name: DB_TYPE + value: postgres + - name: DB_HOST + value: postgres.headscale.svc.cluster.local + - name: DB_PORT + value: "5432" + - name: DB_USER + value: headscale + - name: DB_PASS + valueFrom: + secretKeyRef: + name: postgresql + key: password + - name: DB_NAME + value: headscale + ports: + - name: http + protocol: TCP + containerPort: 8080 + livenessProbe: + tcpSocket: + port: http + initialDelaySeconds: 30 + timeoutSeconds: 5 + periodSeconds: 15 + volumeMounts: + - name: config + mountPath: /vol/config + - name: secret + mountPath: /vol/secret + - name: etc + mountPath: /etc/headscale volumes: - - name: config - configMap: - name: headscale-site - - name: etc - configMap: - name: headscale-etc - - name: secret - secret: - secretName: headscale + - name: config + configMap: + name: headscale-site + - name: etc + configMap: + name: headscale-etc + - name: secret + secret: + secretName: headscale diff --git a/k8s/postgres/kustomization.yaml b/k8s/postgres/kustomization.yaml index 8bd6c40..e732e3b 100644 --- a/k8s/postgres/kustomization.yaml +++ b/k8s/postgres/kustomization.yaml @@ -1,13 +1,13 @@ namespace: headscale bases: -- ../base + - ../base resources: -- deployment.yaml -- postgres-service.yaml -- postgres-statefulset.yaml + - deployment.yaml + - postgres-service.yaml + - postgres-statefulset.yaml generatorOptions: disableNameSuffixHash: true secretGenerator: -- name: postgresql - files: - - secrets/password + - name: postgresql + files: + - secrets/password diff --git a/k8s/postgres/postgres-service.yaml b/k8s/postgres/postgres-service.yaml index e2f486c..6252e7f 100644 --- a/k8s/postgres/postgres-service.yaml +++ b/k8s/postgres/postgres-service.yaml @@ -8,6 +8,6 @@ spec: selector: app: postgres ports: - - name: postgres - targetPort: postgres - port: 5432 + - name: postgres + targetPort: postgres + port: 5432 diff --git a/k8s/postgres/postgres-statefulset.yaml b/k8s/postgres/postgres-statefulset.yaml index 25285c5..b81c9bf 100644 --- a/k8s/postgres/postgres-statefulset.yaml +++ b/k8s/postgres/postgres-statefulset.yaml @@ -14,36 +14,36 @@ spec: app: postgres spec: containers: - - name: postgres - image: "postgres:13" - imagePullPolicy: IfNotPresent - env: - - name: POSTGRES_PASSWORD - valueFrom: - secretKeyRef: - name: postgresql - key: password - - name: POSTGRES_USER - value: headscale - ports: - name: postgres - protocol: TCP - containerPort: 5432 - livenessProbe: - tcpSocket: - port: 5432 - initialDelaySeconds: 30 - timeoutSeconds: 5 - periodSeconds: 15 - volumeMounts: - - name: pgdata - mountPath: /var/lib/postgresql/data + image: "postgres:13" + imagePullPolicy: IfNotPresent + env: + - name: POSTGRES_PASSWORD + valueFrom: + secretKeyRef: + name: postgresql + key: password + - name: POSTGRES_USER + value: headscale + ports: + - name: postgres + protocol: TCP + containerPort: 5432 + livenessProbe: + tcpSocket: + port: 5432 + initialDelaySeconds: 30 + timeoutSeconds: 5 + periodSeconds: 15 + volumeMounts: + - name: pgdata + mountPath: /var/lib/postgresql/data volumeClaimTemplates: - - metadata: - name: pgdata - spec: - storageClassName: local-path - accessModes: ["ReadWriteOnce"] - resources: - requests: - storage: 1Gi + - metadata: + name: pgdata + spec: + storageClassName: local-path + accessModes: ["ReadWriteOnce"] + resources: + requests: + storage: 1Gi diff --git a/k8s/production-tls/ingress-patch.yaml b/k8s/production-tls/ingress-patch.yaml index 387c736..9e6177f 100644 --- a/k8s/production-tls/ingress-patch.yaml +++ b/k8s/production-tls/ingress-patch.yaml @@ -6,6 +6,6 @@ metadata: traefik.ingress.kubernetes.io/router.tls: "true" spec: tls: - - hosts: - - $(PUBLIC_HOSTNAME) - secretName: production-cert + - hosts: + - $(PUBLIC_HOSTNAME) + secretName: production-cert diff --git a/k8s/production-tls/kustomization.yaml b/k8s/production-tls/kustomization.yaml index f57cb54..d3147f5 100644 --- a/k8s/production-tls/kustomization.yaml +++ b/k8s/production-tls/kustomization.yaml @@ -1,9 +1,9 @@ namespace: headscale bases: -- ../base + - ../base resources: -- production-issuer.yaml + - production-issuer.yaml patches: -- path: ingress-patch.yaml - target: - kind: Ingress + - path: ingress-patch.yaml + target: + kind: Ingress diff --git a/k8s/production-tls/production-issuer.yaml b/k8s/production-tls/production-issuer.yaml index 7ae9131..f436090 100644 --- a/k8s/production-tls/production-issuer.yaml +++ b/k8s/production-tls/production-issuer.yaml @@ -11,6 +11,6 @@ spec: # Secret resource used to store the account's private key. name: letsencrypt-production-acc-key solvers: - - http01: - ingress: - class: traefik + - http01: + ingress: + class: traefik diff --git a/k8s/sqlite/kustomization.yaml b/k8s/sqlite/kustomization.yaml index 5be451c..ca79941 100644 --- a/k8s/sqlite/kustomization.yaml +++ b/k8s/sqlite/kustomization.yaml @@ -1,5 +1,5 @@ namespace: headscale bases: -- ../base + - ../base resources: -- statefulset.yaml + - statefulset.yaml diff --git a/k8s/sqlite/statefulset.yaml b/k8s/sqlite/statefulset.yaml index 9075e00..71077da 100644 --- a/k8s/sqlite/statefulset.yaml +++ b/k8s/sqlite/statefulset.yaml @@ -14,66 +14,66 @@ spec: app: headscale spec: containers: - - name: headscale - image: "headscale:latest" - imagePullPolicy: IfNotPresent - command: ["/go/bin/headscale", "serve"] - env: - - name: SERVER_URL - value: $(PUBLIC_PROTO)://$(PUBLIC_HOSTNAME) - - name: LISTEN_ADDR - valueFrom: - configMapKeyRef: - name: headscale-config - key: listen_addr - - name: PRIVATE_KEY_PATH - value: /vol/secret/private-key - - name: DERP_MAP_PATH - value: /vol/config/derp.yaml - - name: EPHEMERAL_NODE_INACTIVITY_TIMEOUT - valueFrom: - configMapKeyRef: - name: headscale-config - key: ephemeral_node_inactivity_timeout - - name: DB_TYPE - value: sqlite3 - - name: DB_PATH - value: /vol/data/db.sqlite - ports: - - name: http - protocol: TCP - containerPort: 8080 - livenessProbe: - tcpSocket: - port: http - initialDelaySeconds: 30 - timeoutSeconds: 5 - periodSeconds: 15 - volumeMounts: - - name: config - mountPath: /vol/config - - name: data - mountPath: /vol/data - - name: secret - mountPath: /vol/secret - - name: etc - mountPath: /etc/headscale + - name: headscale + image: "headscale:latest" + imagePullPolicy: IfNotPresent + command: ["/go/bin/headscale", "serve"] + env: + - name: SERVER_URL + value: $(PUBLIC_PROTO)://$(PUBLIC_HOSTNAME) + - name: LISTEN_ADDR + valueFrom: + configMapKeyRef: + name: headscale-config + key: listen_addr + - name: PRIVATE_KEY_PATH + value: /vol/secret/private-key + - name: DERP_MAP_PATH + value: /vol/config/derp.yaml + - name: EPHEMERAL_NODE_INACTIVITY_TIMEOUT + valueFrom: + configMapKeyRef: + name: headscale-config + key: ephemeral_node_inactivity_timeout + - name: DB_TYPE + value: sqlite3 + - name: DB_PATH + value: /vol/data/db.sqlite + ports: + - name: http + protocol: TCP + containerPort: 8080 + livenessProbe: + tcpSocket: + port: http + initialDelaySeconds: 30 + timeoutSeconds: 5 + periodSeconds: 15 + volumeMounts: + - name: config + mountPath: /vol/config + - name: data + mountPath: /vol/data + - name: secret + mountPath: /vol/secret + - name: etc + mountPath: /etc/headscale volumes: - - name: config - configMap: - name: headscale-site - - name: etc - configMap: - name: headscale-etc - - name: secret - secret: - secretName: headscale + - name: config + configMap: + name: headscale-site + - name: etc + configMap: + name: headscale-etc + - name: secret + secret: + secretName: headscale volumeClaimTemplates: - - metadata: - name: data - spec: - storageClassName: local-path - accessModes: ["ReadWriteOnce"] - resources: - requests: - storage: 1Gi + - metadata: + name: data + spec: + storageClassName: local-path + accessModes: ["ReadWriteOnce"] + resources: + requests: + storage: 1Gi diff --git a/k8s/staging-tls/ingress-patch.yaml b/k8s/staging-tls/ingress-patch.yaml index f97974b..5a1daf0 100644 --- a/k8s/staging-tls/ingress-patch.yaml +++ b/k8s/staging-tls/ingress-patch.yaml @@ -6,6 +6,6 @@ metadata: traefik.ingress.kubernetes.io/router.tls: "true" spec: tls: - - hosts: - - $(PUBLIC_HOSTNAME) - secretName: staging-cert + - hosts: + - $(PUBLIC_HOSTNAME) + secretName: staging-cert diff --git a/k8s/staging-tls/kustomization.yaml b/k8s/staging-tls/kustomization.yaml index 931f27d..0900c58 100644 --- a/k8s/staging-tls/kustomization.yaml +++ b/k8s/staging-tls/kustomization.yaml @@ -1,9 +1,9 @@ namespace: headscale bases: -- ../base + - ../base resources: -- staging-issuer.yaml + - staging-issuer.yaml patches: -- path: ingress-patch.yaml - target: - kind: Ingress + - path: ingress-patch.yaml + target: + kind: Ingress diff --git a/k8s/staging-tls/staging-issuer.yaml b/k8s/staging-tls/staging-issuer.yaml index 95325f6..cf29041 100644 --- a/k8s/staging-tls/staging-issuer.yaml +++ b/k8s/staging-tls/staging-issuer.yaml @@ -11,6 +11,6 @@ spec: # Secret resource used to store the account's private key. name: letsencrypt-staging-acc-key solvers: - - http01: - ingress: - class: traefik + - http01: + ingress: + class: traefik diff --git a/machine.go b/machine.go index 557ab5b..dcd7970 100644 --- a/machine.go +++ b/machine.go @@ -10,10 +10,9 @@ import ( "time" "github.com/fatih/set" + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/rs/zerolog/log" "google.golang.org/protobuf/types/known/timestamppb" - - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "gorm.io/datatypes" "gorm.io/gorm" "inet.af/netaddr" @@ -21,7 +20,13 @@ import ( "tailscale.com/types/wgkey" ) -// Machine is a Headscale client +const ( + errMachineNotFound = Error("machine not found") + errMachineAlreadyRegistered = Error("machine already registered") + errMachineRouteIsNotAvailable = Error("route is not available on machine") +) + +// Machine is a Headscale client. type Machine struct { ID uint64 `gorm:"primary_key"` MachineKey string `gorm:"type:varchar(64);unique_index"` @@ -56,53 +61,58 @@ type ( MachinesP []*Machine ) -// For the time being this method is rather naive -func (m Machine) isAlreadyRegistered() bool { - return m.Registered +// For the time being this method is rather naive. +func (machine Machine) isAlreadyRegistered() bool { + return machine.Registered } -// isExpired returns whether the machine registration has expired -func (m Machine) isExpired() bool { - return time.Now().UTC().After(*m.Expiry) +// isExpired returns whether the machine registration has expired. +func (machine Machine) isExpired() bool { + return time.Now().UTC().After(*machine.Expiry) } // If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration, // or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause // a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the // expiry time. -func (h *Headscale) updateMachineExpiry(m *Machine) { - if m.isExpired() { +func (h *Headscale) updateMachineExpiry(machine *Machine) { + if machine.isExpired() { now := time.Now().UTC() - maxExpiry := now.Add(h.cfg.MaxMachineRegistrationDuration) // calculate the maximum expiry - defaultExpiry := now.Add(h.cfg.DefaultMachineRegistrationDuration) // calculate the default expiry + maxExpiry := now.Add( + h.cfg.MaxMachineRegistrationDuration, + ) // calculate the maximum expiry + defaultExpiry := now.Add( + h.cfg.DefaultMachineRegistrationDuration, + ) // calculate the default expiry // clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied - if maxExpiry.Before(*m.RequestedExpiry) { + if maxExpiry.Before(*machine.RequestedExpiry) { log.Debug(). Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration) - m.Expiry = &maxExpiry - } else if m.RequestedExpiry.IsZero() { + machine.Expiry = &maxExpiry + } else if machine.RequestedExpiry.IsZero() { log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration) - m.Expiry = &defaultExpiry + machine.Expiry = &defaultExpiry } else { - log.Debug().Msgf("Using requested machine registration expiry time: %v", m.RequestedExpiry) - m.Expiry = m.RequestedExpiry + log.Debug().Msgf("Using requested machine registration expiry time: %v", machine.RequestedExpiry) + machine.Expiry = machine.RequestedExpiry } - h.db.Save(&m) + h.db.Save(&machine) } } -func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { +func (h *Headscale) getDirectPeers(machine *Machine) (Machines, error) { log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Finding direct peers") machines := Machines{} if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered", - m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil { + machine.NamespaceID, machine.MachineKey).Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") + return Machines{}, err } @@ -110,21 +120,22 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Msgf("Found direct machines: %s", machines.String()) + return machines, nil } -// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for -func (h *Headscale) getShared(m *Machine) (Machines, error) { +// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for. +func (h *Headscale) getShared(machine *Machine) (Machines, error) { log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Finding shared peers") sharedMachines := []SharedMachine{} if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("namespace_id = ?", - m.NamespaceID).Find(&sharedMachines).Error; err != nil { + machine.NamespaceID).Find(&sharedMachines).Error; err != nil { return Machines{}, err } @@ -137,27 +148,30 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) { log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Msgf("Found shared peers: %s", peers.String()) + return peers, nil } -// getSharedTo fetches the machines of the namespaces this machine is shared in -func (h *Headscale) getSharedTo(m *Machine) (Machines, error) { +// getSharedTo fetches the machines of the namespaces this machine is shared in. +func (h *Headscale) getSharedTo(machine *Machine) (Machines, error) { log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Finding peers in namespaces this machine is shared with") sharedMachines := []SharedMachine{} if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("machine_id = ?", - m.ID).Find(&sharedMachines).Error; err != nil { + machine.ID).Find(&sharedMachines).Error; err != nil { return Machines{}, err } peers := make(Machines, 0) for _, sharedMachine := range sharedMachines { - namespaceMachines, err := h.ListMachinesInNamespace(sharedMachine.Namespace.Name) + namespaceMachines, err := h.ListMachinesInNamespace( + sharedMachine.Namespace.Name, + ) if err != nil { return Machines{}, err } @@ -168,36 +182,40 @@ func (h *Headscale) getSharedTo(m *Machine) (Machines, error) { log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Msgf("Found peers we are shared with: %s", peers.String()) + return peers, nil } -func (h *Headscale) getPeers(m *Machine) (Machines, error) { - direct, err := h.getDirectPeers(m) +func (h *Headscale) getPeers(machine *Machine) (Machines, error) { + direct, err := h.getDirectPeers(machine) if err != nil { log.Error(). Caller(). Err(err). Msg("Cannot fetch peers") + return Machines{}, err } - shared, err := h.getShared(m) + shared, err := h.getShared(machine) if err != nil { log.Error(). Caller(). Err(err). Msg("Cannot fetch peers") + return Machines{}, err } - sharedTo, err := h.getSharedTo(m) + sharedTo, err := h.getSharedTo(machine) if err != nil { log.Error(). Caller(). Err(err). Msg("Cannot fetch peers") + return Machines{}, err } @@ -208,7 +226,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) { log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Msgf("Found total peers: %s", peers.String()) return peers, nil @@ -219,10 +237,11 @@ func (h *Headscale) ListMachines() ([]Machine, error) { if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Find(&machines).Error; err != nil { return nil, err } + return machines, nil } -// GetMachine finds a Machine by name and namespace and returns the Machine struct +// GetMachine finds a Machine by name and namespace and returns the Machine struct. func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) { machines, err := h.ListMachinesInNamespace(namespace) if err != nil { @@ -234,73 +253,77 @@ func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) return &m, nil } } - return nil, fmt.Errorf("machine not found") + + return nil, errMachineNotFound } -// GetMachineByID finds a Machine by ID and returns the Machine struct +// GetMachineByID finds a Machine by ID and returns the Machine struct. func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) { m := Machine{} if result := h.db.Preload("Namespace").Find(&Machine{ID: id}).First(&m); result.Error != nil { return nil, result.Error } + return &m, nil } -// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct -func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) { +// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct. +func (h *Headscale) GetMachineByMachineKey(machineKey string) (*Machine, error) { m := Machine{} - if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey); result.Error != nil { + if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", machineKey); result.Error != nil { return nil, result.Error } + return &m, nil } // UpdateMachine takes a Machine struct pointer (typically already loaded from database // and updates it with the latest data from the database. -func (h *Headscale) UpdateMachine(m *Machine) error { - if result := h.db.Find(m).First(&m); result.Error != nil { +func (h *Headscale) UpdateMachine(machine *Machine) error { + if result := h.db.Find(machine).First(&machine); result.Error != nil { return result.Error } + return nil } -// DeleteMachine softs deletes a Machine from the database -func (h *Headscale) DeleteMachine(m *Machine) error { - err := h.RemoveSharedMachineFromAllNamespaces(m) - if err != nil && err != errorMachineNotShared { +// DeleteMachine softs deletes a Machine from the database. +func (h *Headscale) DeleteMachine(machine *Machine) error { + err := h.RemoveSharedMachineFromAllNamespaces(machine) + if err != nil && errors.Is(err, errMachineNotShared) { return err } - m.Registered = false - namespaceID := m.NamespaceID - h.db.Save(&m) // we mark it as unregistered, just in case - if err := h.db.Delete(&m).Error; err != nil { + machine.Registered = false + namespaceID := machine.NamespaceID + h.db.Save(&machine) // we mark it as unregistered, just in case + if err := h.db.Delete(&machine).Error; err != nil { return err } return h.RequestMapUpdates(namespaceID) } -// HardDeleteMachine hard deletes a Machine from the database -func (h *Headscale) HardDeleteMachine(m *Machine) error { - err := h.RemoveSharedMachineFromAllNamespaces(m) - if err != nil && err != errorMachineNotShared { +// HardDeleteMachine hard deletes a Machine from the database. +func (h *Headscale) HardDeleteMachine(machine *Machine) error { + err := h.RemoveSharedMachineFromAllNamespaces(machine) + if err != nil && errors.Is(err, errMachineNotShared) { return err } - namespaceID := m.NamespaceID - if err := h.db.Unscoped().Delete(&m).Error; err != nil { + namespaceID := machine.NamespaceID + if err := h.db.Unscoped().Delete(&machine).Error; err != nil { return err } return h.RequestMapUpdates(namespaceID) } -// GetHostInfo returns a Hostinfo struct for the machine -func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) { +// GetHostInfo returns a Hostinfo struct for the machine. +func (machine *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) { hostinfo := tailcfg.Hostinfo{} - if len(m.HostInfo) != 0 { - hi, err := m.HostInfo.MarshalJSON() + if len(machine.HostInfo) != 0 { + hi, err := machine.HostInfo.MarshalJSON() if err != nil { return nil, err } @@ -309,21 +332,21 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) { return nil, err } } + return &hostinfo, nil } -func (h *Headscale) isOutdated(m *Machine) bool { - err := h.UpdateMachine(m) - if err != nil { +func (h *Headscale) isOutdated(machine *Machine) bool { + if err := h.UpdateMachine(machine); err != nil { // It does not seem meaningful to propagate this error as the end result // will have to be that the machine has to be considered outdated. return true } - sharedMachines, _ := h.getShared(m) + sharedMachines, _ := h.getShared(machine) namespaceSet := set.New(set.ThreadSafe) - namespaceSet.Add(m.Namespace.Name) + namespaceSet.Add(machine.Namespace.Name) // Check if any of our shared namespaces has updates that we have // not propagated. @@ -333,27 +356,30 @@ func (h *Headscale) isOutdated(m *Machine) bool { namespaces := make([]string, namespaceSet.Size()) for index, namespace := range namespaceSet.List() { - namespaces[index] = namespace.(string) + if name, ok := namespace.(string); ok { + namespaces[index] = name + } } lastChange := h.getLastStateChange(namespaces...) log.Trace(). Caller(). - Str("machine", m.Name). - Time("last_successful_update", *m.LastSuccessfulUpdate). + Str("machine", machine.Name). + Time("last_successful_update", *machine.LastSuccessfulUpdate). Time("last_state_change", lastChange). - Msgf("Checking if %s is missing updates", m.Name) - return m.LastSuccessfulUpdate.Before(lastChange) + Msgf("Checking if %s is missing updates", machine.Name) + + return machine.LastSuccessfulUpdate.Before(lastChange) } -func (m Machine) String() string { - return m.Name +func (machine Machine) String() string { + return machine.Name } -func (ms Machines) String() string { - temp := make([]string, len(ms)) +func (machines Machines) String() string { + temp := make([]string, len(machines)) - for index, machine := range ms { + for index, machine := range machines { temp[index] = machine.Name } @@ -361,24 +387,24 @@ func (ms Machines) String() string { } // TODO(kradalby): Remove when we have generics... -func (ms MachinesP) String() string { - temp := make([]string, len(ms)) +func (machines MachinesP) String() string { + temp := make([]string, len(machines)) - for index, machine := range ms { + for index, machine := range machines { temp[index] = machine.Name } return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) } -func (ms Machines) toNodes( +func (machines Machines) toNodes( baseDomain string, dnsConfig *tailcfg.DNSConfig, includeRoutes bool, ) ([]*tailcfg.Node, error) { - nodes := make([]*tailcfg.Node, len(ms)) + nodes := make([]*tailcfg.Node, len(machines)) - for index, machine := range ms { + for index, machine := range machines { node, err := machine.toNode(baseDomain, dnsConfig, includeRoutes) if err != nil { return nil, err @@ -391,20 +417,25 @@ func (ms Machines) toNodes( } // toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes -// as per the expected behaviour in the official SaaS -func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, includeRoutes bool) (*tailcfg.Node, error) { - nKey, err := wgkey.ParseHex(m.NodeKey) +// as per the expected behaviour in the official SaaS. +func (machine Machine) toNode( + baseDomain string, + dnsConfig *tailcfg.DNSConfig, + includeRoutes bool, +) (*tailcfg.Node, error) { + nodeKey, err := wgkey.ParseHex(machine.NodeKey) if err != nil { return nil, err } - mKey, err := wgkey.ParseHex(m.MachineKey) + + machineKey, err := wgkey.ParseHex(machine.MachineKey) if err != nil { return nil, err } var discoKey tailcfg.DiscoKey - if m.DiscoKey != "" { - dKey, err := wgkey.ParseHex(m.DiscoKey) + if machine.DiscoKey != "" { + dKey, err := wgkey.ParseHex(machine.DiscoKey) if err != nil { return nil, err } @@ -414,23 +445,27 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include } addrs := []netaddr.IPPrefix{} - ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", m.IPAddress)) + ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", machine.IPAddress)) if err != nil { log.Trace(). Caller(). - Str("ip", m.IPAddress). - Msgf("Failed to parse IP Prefix from IP: %s", m.IPAddress) + Str("ip", machine.IPAddress). + Msgf("Failed to parse IP Prefix from IP: %s", machine.IPAddress) + return nil, err } addrs = append(addrs, ip) // missing the ipv6 ? allowedIPs := []netaddr.IPPrefix{} - allowedIPs = append(allowedIPs, ip) // we append the node own IP, as it is required by the clients + allowedIPs = append( + allowedIPs, + ip, + ) // we append the node own IP, as it is required by the clients if includeRoutes { routesStr := []string{} - if len(m.EnabledRoutes) != 0 { - allwIps, err := m.EnabledRoutes.MarshalJSON() + if len(machine.EnabledRoutes) != 0 { + allwIps, err := machine.EnabledRoutes.MarshalJSON() if err != nil { return nil, err } @@ -450,8 +485,8 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include } endpoints := []string{} - if len(m.Endpoints) != 0 { - be, err := m.Endpoints.MarshalJSON() + if len(machine.Endpoints) != 0 { + be, err := machine.Endpoints.MarshalJSON() if err != nil { return nil, err } @@ -462,8 +497,8 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include } hostinfo := tailcfg.Hostinfo{} - if len(m.HostInfo) != 0 { - hi, err := m.HostInfo.MarshalJSON() + if len(machine.HostInfo) != 0 { + hi, err := machine.HostInfo.MarshalJSON() if err != nil { return nil, err } @@ -481,29 +516,34 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include } var keyExpiry time.Time - if m.Expiry != nil { - keyExpiry = *m.Expiry + if machine.Expiry != nil { + keyExpiry = *machine.Expiry } else { keyExpiry = time.Time{} } var hostname string if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS - hostname = fmt.Sprintf("%s.%s.%s", m.Name, m.Namespace.Name, baseDomain) + hostname = fmt.Sprintf( + "%s.%s.%s", + machine.Name, + machine.Namespace.Name, + baseDomain, + ) } else { - hostname = m.Name + hostname = machine.Name } - n := tailcfg.Node{ - ID: tailcfg.NodeID(m.ID), // this is the actual ID + node := tailcfg.Node{ + ID: tailcfg.NodeID(machine.ID), // this is the actual ID StableID: tailcfg.StableNodeID( - strconv.FormatUint(m.ID, 10), + strconv.FormatUint(machine.ID, Base10), ), // in headscale, unlike tailcontrol server, IDs are permanent Name: hostname, - User: tailcfg.UserID(m.NamespaceID), - Key: tailcfg.NodeKey(nKey), + User: tailcfg.UserID(machine.NamespaceID), + Key: tailcfg.NodeKey(nodeKey), KeyExpiry: keyExpiry, - Machine: tailcfg.MachineKey(mKey), + Machine: tailcfg.MachineKey(machineKey), DiscoKey: discoKey, Addresses: addrs, AllowedIPs: allowedIPs, @@ -511,81 +551,90 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include DERP: derp, Hostinfo: hostinfo, - Created: m.CreatedAt, - LastSeen: m.LastSeen, + Created: machine.CreatedAt, + LastSeen: machine.LastSeen, KeepAlive: true, - MachineAuthorized: m.Registered, + MachineAuthorized: machine.Registered, Capabilities: []string{tailcfg.CapabilityFileSharing}, } - return &n, nil + + return &node, nil } -func (m *Machine) toProto() *v1.Machine { - machine := &v1.Machine{ - Id: m.ID, - MachineKey: m.MachineKey, +func (machine *Machine) toProto() *v1.Machine { + machineProto := &v1.Machine{ + Id: machine.ID, + MachineKey: machine.MachineKey, - NodeKey: m.NodeKey, - DiscoKey: m.DiscoKey, - IpAddress: m.IPAddress, - Name: m.Name, - Namespace: m.Namespace.toProto(), + NodeKey: machine.NodeKey, + DiscoKey: machine.DiscoKey, + IpAddress: machine.IPAddress, + Name: machine.Name, + Namespace: machine.Namespace.toProto(), - Registered: m.Registered, + Registered: machine.Registered, // TODO(kradalby): Implement register method enum converter // RegisterMethod: , - CreatedAt: timestamppb.New(m.CreatedAt), + CreatedAt: timestamppb.New(machine.CreatedAt), } - if m.AuthKey != nil { - machine.PreAuthKey = m.AuthKey.toProto() + if machine.AuthKey != nil { + machineProto.PreAuthKey = machine.AuthKey.toProto() } - if m.LastSeen != nil { - machine.LastSeen = timestamppb.New(*m.LastSeen) + if machine.LastSeen != nil { + machineProto.LastSeen = timestamppb.New(*machine.LastSeen) } - if m.LastSuccessfulUpdate != nil { - machine.LastSuccessfulUpdate = timestamppb.New(*m.LastSuccessfulUpdate) + if machine.LastSuccessfulUpdate != nil { + machineProto.LastSuccessfulUpdate = timestamppb.New( + *machine.LastSuccessfulUpdate, + ) } - if m.Expiry != nil { - machine.Expiry = timestamppb.New(*m.Expiry) + if machine.Expiry != nil { + machineProto.Expiry = timestamppb.New(*machine.Expiry) } - return machine + return machineProto } -// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey -func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) { - ns, err := h.GetNamespace(namespace) +// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. +func (h *Headscale) RegisterMachine( + key string, + namespaceName string, +) (*Machine, error) { + namespace, err := h.GetNamespace(namespaceName) if err != nil { return nil, err } - mKey, err := wgkey.ParseHex(key) + machineKey, err := wgkey.ParseHex(key) if err != nil { return nil, err } - m := Machine{} - if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, errors.New("Machine not found") + machine := Machine{} + if result := h.db.First(&machine, "machine_key = ?", machineKey.HexString()); errors.Is( + result.Error, + gorm.ErrRecordNotFound, + ) { + return nil, errMachineNotFound } log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Attempting to register machine") - if m.isAlreadyRegistered() { - err := errors.New("Machine already registered") + if machine.isAlreadyRegistered() { + err := errMachineAlreadyRegistered log.Error(). Caller(). Err(err). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Attempting to register machine") return nil, err @@ -596,42 +645,44 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err log.Error(). Caller(). Err(err). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Could not find IP for the new machine") + return nil, err } log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Str("ip", ip.String()). Msg("Found IP for host") - m.IPAddress = ip.String() - m.NamespaceID = ns.ID - m.Registered = true - m.RegisterMethod = "cli" - h.db.Save(&m) + machine.IPAddress = ip.String() + machine.NamespaceID = namespace.ID + machine.Registered = true + machine.RegisterMethod = "cli" + h.db.Save(&machine) log.Trace(). Caller(). - Str("machine", m.Name). + Str("machine", machine.Name). Str("ip", ip.String()). Msg("Machine registered with the database") - return &m, nil + return &machine, nil } -func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) { - hostInfo, err := m.GetHostInfo() +func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) { + hostInfo, err := machine.GetHostInfo() if err != nil { return nil, err } + return hostInfo.RoutableIPs, nil } -func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) { - data, err := m.EnabledRoutes.MarshalJSON() +func (machine *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) { + data, err := machine.EnabledRoutes.MarshalJSON() if err != nil { return nil, err } @@ -654,13 +705,13 @@ func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) { return routes, nil } -func (m *Machine) IsRoutesEnabled(routeStr string) bool { +func (machine *Machine) IsRoutesEnabled(routeStr string) bool { route, err := netaddr.ParseIPPrefix(routeStr) if err != nil { return false } - enabledRoutes, err := m.GetEnabledRoutes() + enabledRoutes, err := machine.GetEnabledRoutes() if err != nil { return false } @@ -670,12 +721,13 @@ func (m *Machine) IsRoutesEnabled(routeStr string) bool { return true } } + return false } // EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the // previous list of routes. -func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error { +func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error { newRoutes := make([]netaddr.IPPrefix, len(routeStrs)) for index, routeStr := range routeStrs { route, err := netaddr.ParseIPPrefix(routeStr) @@ -686,14 +738,18 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error { newRoutes[index] = route } - availableRoutes, err := m.GetAdvertisedRoutes() + availableRoutes, err := machine.GetAdvertisedRoutes() if err != nil { return err } for _, newRoute := range newRoutes { - if !containsIpPrefix(availableRoutes, newRoute) { - return fmt.Errorf("route (%s) is not available on node %s", m.Name, newRoute) + if !containsIPPrefix(availableRoutes, newRoute) { + return fmt.Errorf( + "route (%s) is not available on node %s: %w", + machine.Name, + newRoute, errMachineRouteIsNotAvailable, + ) } } @@ -702,10 +758,10 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error { return err } - m.EnabledRoutes = datatypes.JSON(routes) - h.db.Save(&m) + machine.EnabledRoutes = datatypes.JSON(routes) + h.db.Save(&machine) - err = h.RequestMapUpdates(m.NamespaceID) + err = h.RequestMapUpdates(machine.NamespaceID) if err != nil { return err } @@ -713,13 +769,13 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error { return nil } -func (m *Machine) RoutesToProto() (*v1.Routes, error) { - availableRoutes, err := m.GetAdvertisedRoutes() +func (machine *Machine) RoutesToProto() (*v1.Routes, error) { + availableRoutes, err := machine.GetAdvertisedRoutes() if err != nil { return nil, err } - enabledRoutes, err := m.GetEnabledRoutes() + enabledRoutes, err := machine.GetEnabledRoutes() if err != nil { return nil, err } diff --git a/machine_test.go b/machine_test.go index dfe84d3..cf36740 100644 --- a/machine_test.go +++ b/machine_test.go @@ -8,152 +8,159 @@ import ( ) func (s *Suite) TestGetMachine(c *check.C) { - n, err := h.CreateNamespace("test") + namespace, err := app.CreateNamespace("test") c.Assert(err, check.IsNil) - pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) c.Assert(err, check.IsNil) - _, err = h.GetMachine("test", "testmachine") + _, err = app.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) - m := &Machine{ + machine := &Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Name: "testmachine", - NamespaceID: n.ID, + NamespaceID: namespace.ID, Registered: true, RegisterMethod: "authKey", AuthKeyID: uint(pak.ID), } - h.db.Save(m) + app.db.Save(machine) - m1, err := h.GetMachine("test", "testmachine") + machineFromDB, err := app.GetMachine("test", "testmachine") c.Assert(err, check.IsNil) - _, err = m1.GetHostInfo() + _, err = machineFromDB.GetHostInfo() c.Assert(err, check.IsNil) } func (s *Suite) TestGetMachineByID(c *check.C) { - n, err := h.CreateNamespace("test") + namespace, err := app.CreateNamespace("test") c.Assert(err, check.IsNil) - pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) c.Assert(err, check.IsNil) - _, err = h.GetMachineByID(0) + _, err = app.GetMachineByID(0) c.Assert(err, check.NotNil) - m := Machine{ + machine := Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Name: "testmachine", - NamespaceID: n.ID, + NamespaceID: namespace.ID, Registered: true, RegisterMethod: "authKey", AuthKeyID: uint(pak.ID), } - h.db.Save(&m) + app.db.Save(&machine) - m1, err := h.GetMachineByID(0) + machineByID, err := app.GetMachineByID(0) c.Assert(err, check.IsNil) - _, err = m1.GetHostInfo() + _, err = machineByID.GetHostInfo() c.Assert(err, check.IsNil) } func (s *Suite) TestDeleteMachine(c *check.C) { - n, err := h.CreateNamespace("test") + namespace, err := app.CreateNamespace("test") c.Assert(err, check.IsNil) - m := Machine{ + machine := Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Name: "testmachine", - NamespaceID: n.ID, + NamespaceID: namespace.ID, Registered: true, RegisterMethod: "authKey", AuthKeyID: uint(1), } - h.db.Save(&m) - err = h.DeleteMachine(&m) + app.db.Save(&machine) + + err = app.DeleteMachine(&machine) c.Assert(err, check.IsNil) - v, err := h.getValue("namespaces_pending_updates") + + namespacesPendingUpdates, err := app.getValue("namespaces_pending_updates") c.Assert(err, check.IsNil) + names := []string{} - err = json.Unmarshal([]byte(v), &names) + err = json.Unmarshal([]byte(namespacesPendingUpdates), &names) c.Assert(err, check.IsNil) - c.Assert(names, check.DeepEquals, []string{n.Name}) - h.checkForNamespacesPendingUpdates() - v, _ = h.getValue("namespaces_pending_updates") - c.Assert(v, check.Equals, "") - _, err = h.GetMachine(n.Name, "testmachine") + c.Assert(names, check.DeepEquals, []string{namespace.Name}) + + app.checkForNamespacesPendingUpdates() + + namespacesPendingUpdates, _ = app.getValue("namespaces_pending_updates") + c.Assert(namespacesPendingUpdates, check.Equals, "") + _, err = app.GetMachine(namespace.Name, "testmachine") c.Assert(err, check.NotNil) } func (s *Suite) TestHardDeleteMachine(c *check.C) { - n, err := h.CreateNamespace("test") + namespace, err := app.CreateNamespace("test") c.Assert(err, check.IsNil) - m := Machine{ + machine := Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Name: "testmachine3", - NamespaceID: n.ID, + NamespaceID: namespace.ID, Registered: true, RegisterMethod: "authKey", AuthKeyID: uint(1), } - h.db.Save(&m) - err = h.HardDeleteMachine(&m) + app.db.Save(&machine) + + err = app.HardDeleteMachine(&machine) c.Assert(err, check.IsNil) - _, err = h.GetMachine(n.Name, "testmachine3") + + _, err = app.GetMachine(namespace.Name, "testmachine3") c.Assert(err, check.NotNil) } func (s *Suite) TestGetDirectPeers(c *check.C) { - n, err := h.CreateNamespace("test") + namespace, err := app.CreateNamespace("test") c.Assert(err, check.IsNil) - pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) c.Assert(err, check.IsNil) - _, err = h.GetMachineByID(0) + _, err = app.GetMachineByID(0) c.Assert(err, check.NotNil) - for i := 0; i <= 10; i++ { - m := Machine{ - ID: uint64(i), - MachineKey: "foo" + strconv.Itoa(i), - NodeKey: "bar" + strconv.Itoa(i), - DiscoKey: "faa" + strconv.Itoa(i), - Name: "testmachine" + strconv.Itoa(i), - NamespaceID: n.ID, + for index := 0; index <= 10; index++ { + machine := Machine{ + ID: uint64(index), + MachineKey: "foo" + strconv.Itoa(index), + NodeKey: "bar" + strconv.Itoa(index), + DiscoKey: "faa" + strconv.Itoa(index), + Name: "testmachine" + strconv.Itoa(index), + NamespaceID: namespace.ID, Registered: true, RegisterMethod: "authKey", AuthKeyID: uint(pak.ID), } - h.db.Save(&m) + app.db.Save(&machine) } - m1, err := h.GetMachineByID(0) + machine0ByID, err := app.GetMachineByID(0) c.Assert(err, check.IsNil) - _, err = m1.GetHostInfo() + _, err = machine0ByID.GetHostInfo() c.Assert(err, check.IsNil) - peers, err := h.getDirectPeers(m1) + peersOfMachine0, err := app.getDirectPeers(machine0ByID) c.Assert(err, check.IsNil) - c.Assert(len(peers), check.Equals, 9) - c.Assert(peers[0].Name, check.Equals, "testmachine2") - c.Assert(peers[5].Name, check.Equals, "testmachine7") - c.Assert(peers[8].Name, check.Equals, "testmachine10") + c.Assert(len(peersOfMachine0), check.Equals, 9) + c.Assert(peersOfMachine0[0].Name, check.Equals, "testmachine2") + c.Assert(peersOfMachine0[5].Name, check.Equals, "testmachine7") + c.Assert(peersOfMachine0[8].Name, check.Equals, "testmachine10") } diff --git a/metrics.go b/metrics.go index 0d3dca3..f0ce16e 100644 --- a/metrics.go +++ b/metrics.go @@ -32,7 +32,7 @@ var ( Name: "update_request_sent_to_node_total", Help: "The number of calls/messages issued on a specific nodes update channel", }, []string{"namespace", "machine", "status"}) - //TODO(kradalby): This is very debugging, we might want to remove it. + // TODO(kradalby): This is very debugging, we might want to remove it. updateRequestsReceivedOnChannel = promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: prometheusNamespace, Name: "update_request_received_on_channel_total", diff --git a/namespaces.go b/namespaces.go index e5d1783..e512068 100644 --- a/namespaces.go +++ b/namespaces.go @@ -15,9 +15,9 @@ import ( ) const ( - errorNamespaceExists = Error("Namespace already exists") - errorNamespaceNotFound = Error("Namespace not found") - errorNamespaceNotEmptyOfNodes = Error("Namespace not empty: node(s) found") + errNamespaceExists = Error("Namespace already exists") + errNamespaceNotFound = Error("Namespace not found") + errNamespaceNotEmptyOfNodes = Error("Namespace not empty: node(s) found") ) // Namespace is the way Headscale implements the concept of users in Tailscale @@ -30,51 +30,53 @@ type Namespace struct { } // CreateNamespace creates a new Namespace. Returns error if could not be created -// or another namespace already exists +// or another namespace already exists. func (h *Headscale) CreateNamespace(name string) (*Namespace, error) { - n := Namespace{} - if err := h.db.Where("name = ?", name).First(&n).Error; err == nil { - return nil, errorNamespaceExists + namespace := Namespace{} + if err := h.db.Where("name = ?", name).First(&namespace).Error; err == nil { + return nil, errNamespaceExists } - n.Name = name - if err := h.db.Create(&n).Error; err != nil { + namespace.Name = name + if err := h.db.Create(&namespace).Error; err != nil { log.Error(). Str("func", "CreateNamespace"). Err(err). Msg("Could not create row") + return nil, err } - return &n, nil + + return &namespace, nil } // DestroyNamespace destroys a Namespace. Returns error if the Namespace does // not exist or if there are machines associated with it. func (h *Headscale) DestroyNamespace(name string) error { - n, err := h.GetNamespace(name) + namespace, err := h.GetNamespace(name) if err != nil { - return errorNamespaceNotFound + return errNamespaceNotFound } - m, err := h.ListMachinesInNamespace(name) + machines, err := h.ListMachinesInNamespace(name) if err != nil { return err } - if len(m) > 0 { - return errorNamespaceNotEmptyOfNodes + if len(machines) > 0 { + return errNamespaceNotEmptyOfNodes } keys, err := h.ListPreAuthKeys(name) if err != nil { return err } - for _, p := range keys { - err = h.DestroyPreAuthKey(&p) + for _, key := range keys { + err = h.DestroyPreAuthKey(key) if err != nil { return err } } - if result := h.db.Unscoped().Delete(&n); result.Error != nil { + if result := h.db.Unscoped().Delete(&namespace); result.Error != nil { return result.Error } @@ -84,25 +86,25 @@ func (h *Headscale) DestroyNamespace(name string) error { // RenameNamespace renames a Namespace. Returns error if the Namespace does // not exist or if another Namespace exists with the new name. func (h *Headscale) RenameNamespace(oldName, newName string) error { - n, err := h.GetNamespace(oldName) + oldNamespace, err := h.GetNamespace(oldName) if err != nil { return err } _, err = h.GetNamespace(newName) if err == nil { - return errorNamespaceExists + return errNamespaceExists } - if !errors.Is(err, errorNamespaceNotFound) { + if !errors.Is(err, errNamespaceNotFound) { return err } - n.Name = newName + oldNamespace.Name = newName - if result := h.db.Save(&n); result.Error != nil { + if result := h.db.Save(&oldNamespace); result.Error != nil { return result.Error } - err = h.RequestMapUpdates(n.ID) + err = h.RequestMapUpdates(oldNamespace.ID) if err != nil { return err } @@ -110,39 +112,45 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error { return nil } -// GetNamespace fetches a namespace by name +// GetNamespace fetches a namespace by name. func (h *Headscale) GetNamespace(name string) (*Namespace, error) { - n := Namespace{} - if result := h.db.First(&n, "name = ?", name); errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, errorNamespaceNotFound + namespace := Namespace{} + if result := h.db.First(&namespace, "name = ?", name); errors.Is( + result.Error, + gorm.ErrRecordNotFound, + ) { + return nil, errNamespaceNotFound } - return &n, nil + + return &namespace, nil } -// ListNamespaces gets all the existing namespaces +// ListNamespaces gets all the existing namespaces. func (h *Headscale) ListNamespaces() ([]Namespace, error) { namespaces := []Namespace{} if err := h.db.Find(&namespaces).Error; err != nil { return nil, err } + return namespaces, nil } -// ListMachinesInNamespace gets all the nodes in a given namespace +// ListMachinesInNamespace gets all the nodes in a given namespace. func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) { - n, err := h.GetNamespace(name) + namespace, err := h.GetNamespace(name) if err != nil { return nil, err } machines := []Machine{} - if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil { + if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: namespace.ID}).Find(&machines).Error; err != nil { return nil, err } + return machines, nil } -// ListSharedMachinesInNamespace returns all the machines that are shared to the specified namespace +// ListSharedMachinesInNamespace returns all the machines that are shared to the specified namespace. func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error) { namespace, err := h.GetNamespace(name) if err != nil { @@ -155,48 +163,61 @@ func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error machines := []Machine{} for _, sharedMachine := range sharedMachines { - machine, err := h.GetMachineByID(sharedMachine.MachineID) // otherwise not everything comes filled + machine, err := h.GetMachineByID( + sharedMachine.MachineID, + ) // otherwise not everything comes filled if err != nil { return nil, err } machines = append(machines, *machine) } + return machines, nil } -// SetMachineNamespace assigns a Machine to a namespace -func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error { - n, err := h.GetNamespace(namespaceName) +// SetMachineNamespace assigns a Machine to a namespace. +func (h *Headscale) SetMachineNamespace(machine *Machine, namespaceName string) error { + namespace, err := h.GetNamespace(namespaceName) if err != nil { return err } - m.NamespaceID = n.ID - h.db.Save(&m) + machine.NamespaceID = namespace.ID + h.db.Save(&machine) + return nil } -// RequestMapUpdates signals the KV worker to update the maps for this namespace +// TODO(kradalby): Remove the need for this. +// RequestMapUpdates signals the KV worker to update the maps for this namespace. func (h *Headscale) RequestMapUpdates(namespaceID uint) error { namespace := Namespace{} if err := h.db.First(&namespace, namespaceID).Error; err != nil { return err } - v, err := h.getValue("namespaces_pending_updates") - if err != nil || v == "" { - err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name)) + namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates") + if err != nil || namespacesPendingUpdates == "" { + err = h.setValue( + "namespaces_pending_updates", + fmt.Sprintf(`["%s"]`, namespace.Name), + ) if err != nil { return err } + return nil } names := []string{} - err = json.Unmarshal([]byte(v), &names) + err = json.Unmarshal([]byte(namespacesPendingUpdates), &names) if err != nil { - err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name)) + err = h.setValue( + "namespaces_pending_updates", + fmt.Sprintf(`["%s"]`, namespace.Name), + ) if err != nil { return err } + return nil } @@ -207,22 +228,24 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error { Str("func", "RequestMapUpdates"). Err(err). Msg("Could not marshal namespaces_pending_updates") + return err } + return h.setValue("namespaces_pending_updates", string(data)) } func (h *Headscale) checkForNamespacesPendingUpdates() { - v, err := h.getValue("namespaces_pending_updates") + namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates") if err != nil { return } - if v == "" { + if namespacesPendingUpdates == "" { return } namespaces := []string{} - err = json.Unmarshal([]byte(v), &namespaces) + err = json.Unmarshal([]byte(namespacesPendingUpdates), &namespaces) if err != nil { return } @@ -233,24 +256,25 @@ func (h *Headscale) checkForNamespacesPendingUpdates() { Msg("Sending updates to nodes in namespacespace") h.setLastStateChangeToNow(namespace) } - newV, err := h.getValue("namespaces_pending_updates") + newPendingUpdateValue, err := h.getValue("namespaces_pending_updates") if err != nil { return } - if v == newV { // only clear when no changes, so we notified everybody + if namespacesPendingUpdates == newPendingUpdateValue { // only clear when no changes, so we notified everybody err = h.setValue("namespaces_pending_updates", "") if err != nil { log.Error(). Str("func", "checkForNamespacesPendingUpdates"). Err(err). Msg("Could not save to KV") + return } } } func (n *Namespace) toUser() *tailcfg.User { - u := tailcfg.User{ + user := tailcfg.User{ ID: tailcfg.UserID(n.ID), LoginName: n.Name, DisplayName: n.Name, @@ -259,25 +283,27 @@ func (n *Namespace) toUser() *tailcfg.User { Logins: []tailcfg.LoginID{}, Created: time.Time{}, } - return &u + + return &user } func (n *Namespace) toLogin() *tailcfg.Login { - l := tailcfg.Login{ + login := tailcfg.Login{ ID: tailcfg.LoginID(n.ID), LoginName: n.Name, DisplayName: n.Name, ProfilePicURL: "", Domain: "headscale.net", } - return &l + + return &login } -func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile { +func getMapResponseUserProfiles(machine Machine, peers Machines) []tailcfg.UserProfile { namespaceMap := make(map[string]Namespace) - namespaceMap[m.Namespace.Name] = m.Namespace - for _, p := range peers { - namespaceMap[p.Namespace.Name] = p.Namespace // not worth checking if already is there + namespaceMap[machine.Namespace.Name] = machine.Namespace + for _, peer := range peers { + namespaceMap[peer.Namespace.Name] = peer.Namespace // not worth checking if already is there } profiles := []tailcfg.UserProfile{} @@ -289,12 +315,13 @@ func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile DisplayName: namespace.Name, }) } + return profiles } func (n *Namespace) toProto() *v1.Namespace { return &v1.Namespace{ - Id: strconv.FormatUint(uint64(n.ID), 10), + Id: strconv.FormatUint(uint64(n.ID), Base10), Name: n.Name, CreatedAt: timestamppb.New(n.CreatedAt), } diff --git a/namespaces_test.go b/namespaces_test.go index 72193ee..bbae98f 100644 --- a/namespaces_test.go +++ b/namespaces_test.go @@ -7,207 +7,232 @@ import ( ) func (s *Suite) TestCreateAndDestroyNamespace(c *check.C) { - n, err := h.CreateNamespace("test") + namespace, err := app.CreateNamespace("test") c.Assert(err, check.IsNil) - c.Assert(n.Name, check.Equals, "test") + c.Assert(namespace.Name, check.Equals, "test") - ns, err := h.ListNamespaces() + namespaces, err := app.ListNamespaces() c.Assert(err, check.IsNil) - c.Assert(len(ns), check.Equals, 1) + c.Assert(len(namespaces), check.Equals, 1) - err = h.DestroyNamespace("test") + err = app.DestroyNamespace("test") c.Assert(err, check.IsNil) - _, err = h.GetNamespace("test") + _, err = app.GetNamespace("test") c.Assert(err, check.NotNil) } func (s *Suite) TestDestroyNamespaceErrors(c *check.C) { - err := h.DestroyNamespace("test") - c.Assert(err, check.Equals, errorNamespaceNotFound) + err := app.DestroyNamespace("test") + c.Assert(err, check.Equals, errNamespaceNotFound) - n, err := h.CreateNamespace("test") + namespace, err := app.CreateNamespace("test") c.Assert(err, check.IsNil) - pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) c.Assert(err, check.IsNil) - err = h.DestroyNamespace("test") + err = app.DestroyNamespace("test") c.Assert(err, check.IsNil) - result := h.db.Preload("Namespace").First(&pak, "key = ?", pak.Key) + result := app.db.Preload("Namespace").First(&pak, "key = ?", pak.Key) // destroying a namespace also deletes all associated preauthkeys c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound) - n, err = h.CreateNamespace("test") + namespace, err = app.CreateNamespace("test") c.Assert(err, check.IsNil) - pak, err = h.CreatePreAuthKey(n.Name, false, false, nil) + pak, err = app.CreatePreAuthKey(namespace.Name, false, false, nil) c.Assert(err, check.IsNil) - m := Machine{ + machine := Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Name: "testmachine", - NamespaceID: n.ID, + NamespaceID: namespace.ID, Registered: true, RegisterMethod: "authKey", AuthKeyID: uint(pak.ID), } - h.db.Save(&m) + app.db.Save(&machine) - err = h.DestroyNamespace("test") - c.Assert(err, check.Equals, errorNamespaceNotEmptyOfNodes) + err = app.DestroyNamespace("test") + c.Assert(err, check.Equals, errNamespaceNotEmptyOfNodes) } func (s *Suite) TestRenameNamespace(c *check.C) { - n, err := h.CreateNamespace("test") + namespaceTest, err := app.CreateNamespace("test") c.Assert(err, check.IsNil) - c.Assert(n.Name, check.Equals, "test") + c.Assert(namespaceTest.Name, check.Equals, "test") - ns, err := h.ListNamespaces() + namespaces, err := app.ListNamespaces() c.Assert(err, check.IsNil) - c.Assert(len(ns), check.Equals, 1) + c.Assert(len(namespaces), check.Equals, 1) - err = h.RenameNamespace("test", "test_renamed") + err = app.RenameNamespace("test", "test_renamed") c.Assert(err, check.IsNil) - _, err = h.GetNamespace("test") - c.Assert(err, check.Equals, errorNamespaceNotFound) + _, err = app.GetNamespace("test") + c.Assert(err, check.Equals, errNamespaceNotFound) - _, err = h.GetNamespace("test_renamed") + _, err = app.GetNamespace("test_renamed") c.Assert(err, check.IsNil) - err = h.RenameNamespace("test_does_not_exit", "test") - c.Assert(err, check.Equals, errorNamespaceNotFound) + err = app.RenameNamespace("test_does_not_exit", "test") + c.Assert(err, check.Equals, errNamespaceNotFound) - n2, err := h.CreateNamespace("test2") + namespaceTest2, err := app.CreateNamespace("test2") c.Assert(err, check.IsNil) - c.Assert(n2.Name, check.Equals, "test2") + c.Assert(namespaceTest2.Name, check.Equals, "test2") - err = h.RenameNamespace("test2", "test_renamed") - c.Assert(err, check.Equals, errorNamespaceExists) + err = app.RenameNamespace("test2", "test_renamed") + c.Assert(err, check.Equals, errNamespaceExists) } func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { - n1, err := h.CreateNamespace("shared1") + namespaceShared1, err := app.CreateNamespace("shared1") c.Assert(err, check.IsNil) - n2, err := h.CreateNamespace("shared2") + namespaceShared2, err := app.CreateNamespace("shared2") c.Assert(err, check.IsNil) - n3, err := h.CreateNamespace("shared3") + namespaceShared3, err := app.CreateNamespace("shared3") c.Assert(err, check.IsNil) - pak1n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil) + preAuthKeyShared1, err := app.CreatePreAuthKey( + namespaceShared1.Name, + false, + false, + nil, + ) c.Assert(err, check.IsNil) - pak2n2, err := h.CreatePreAuthKey(n2.Name, false, false, nil) + preAuthKeyShared2, err := app.CreatePreAuthKey( + namespaceShared2.Name, + false, + false, + nil, + ) c.Assert(err, check.IsNil) - pak3n3, err := h.CreatePreAuthKey(n3.Name, false, false, nil) + preAuthKeyShared3, err := app.CreatePreAuthKey( + namespaceShared3.Name, + false, + false, + nil, + ) c.Assert(err, check.IsNil) - pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil) + preAuthKey2Shared1, err := app.CreatePreAuthKey( + namespaceShared1.Name, + false, + false, + nil, + ) c.Assert(err, check.IsNil) - _, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1") + _, err = app.GetMachine(namespaceShared1.Name, "test_get_shared_nodes_1") c.Assert(err, check.NotNil) - m1 := &Machine{ + machineInShared1 := &Machine{ ID: 1, MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", Name: "test_get_shared_nodes_1", - NamespaceID: n1.ID, - Namespace: *n1, + NamespaceID: namespaceShared1.ID, + Namespace: *namespaceShared1, Registered: true, RegisterMethod: "authKey", IPAddress: "100.64.0.1", - AuthKeyID: uint(pak1n1.ID), + AuthKeyID: uint(preAuthKeyShared1.ID), } - h.db.Save(m1) + app.db.Save(machineInShared1) - _, err = h.GetMachine(n1.Name, m1.Name) + _, err = app.GetMachine(namespaceShared1.Name, machineInShared1.Name) c.Assert(err, check.IsNil) - m2 := &Machine{ + machineInShared2 := &Machine{ ID: 2, MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", Name: "test_get_shared_nodes_2", - NamespaceID: n2.ID, - Namespace: *n2, + NamespaceID: namespaceShared2.ID, + Namespace: *namespaceShared2, Registered: true, RegisterMethod: "authKey", IPAddress: "100.64.0.2", - AuthKeyID: uint(pak2n2.ID), + AuthKeyID: uint(preAuthKeyShared2.ID), } - h.db.Save(m2) + app.db.Save(machineInShared2) - _, err = h.GetMachine(n2.Name, m2.Name) + _, err = app.GetMachine(namespaceShared2.Name, machineInShared2.Name) c.Assert(err, check.IsNil) - m3 := &Machine{ + machineInShared3 := &Machine{ ID: 3, MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", Name: "test_get_shared_nodes_3", - NamespaceID: n3.ID, - Namespace: *n3, + NamespaceID: namespaceShared3.ID, + Namespace: *namespaceShared3, Registered: true, RegisterMethod: "authKey", IPAddress: "100.64.0.3", - AuthKeyID: uint(pak3n3.ID), + AuthKeyID: uint(preAuthKeyShared3.ID), } - h.db.Save(m3) + app.db.Save(machineInShared3) - _, err = h.GetMachine(n3.Name, m3.Name) + _, err = app.GetMachine(namespaceShared3.Name, machineInShared3.Name) c.Assert(err, check.IsNil) - m4 := &Machine{ + machine2InShared1 := &Machine{ ID: 4, MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", Name: "test_get_shared_nodes_4", - NamespaceID: n1.ID, - Namespace: *n1, + NamespaceID: namespaceShared1.ID, + Namespace: *namespaceShared1, Registered: true, RegisterMethod: "authKey", IPAddress: "100.64.0.4", - AuthKeyID: uint(pak4n1.ID), + AuthKeyID: uint(preAuthKey2Shared1.ID), } - h.db.Save(m4) + app.db.Save(machine2InShared1) - err = h.AddSharedMachineToNamespace(m2, n1) + err = app.AddSharedMachineToNamespace(machineInShared2, namespaceShared1) c.Assert(err, check.IsNil) - m1peers, err := h.getPeers(m1) + peersOfMachine1InShared1, err := app.getPeers(machineInShared1) c.Assert(err, check.IsNil) - userProfiles := getMapResponseUserProfiles(*m1, m1peers) + userProfiles := getMapResponseUserProfiles( + *machineInShared1, + peersOfMachine1InShared1, + ) log.Trace().Msgf("userProfiles %#v", userProfiles) c.Assert(len(userProfiles), check.Equals, 2) found := false - for _, up := range userProfiles { - if up.DisplayName == n1.Name { + for _, userProfiles := range userProfiles { + if userProfiles.DisplayName == namespaceShared1.Name { found = true + break } } c.Assert(found, check.Equals, true) found = false - for _, up := range userProfiles { - if up.DisplayName == n2.Name { + for _, userProfile := range userProfiles { + if userProfile.DisplayName == namespaceShared2.Name { found = true + break } } diff --git a/oidc.go b/oidc.go index 51c443d..07561e8 100644 --- a/oidc.go +++ b/oidc.go @@ -17,6 +17,12 @@ import ( "golang.org/x/oauth2" ) +const ( + oidcStateCacheExpiration = time.Minute * 5 + oidcStateCacheCleanupInterval = time.Minute * 10 + randomByteSize = 16 +) + type IDTokenClaims struct { Name string `json:"name,omitempty"` Groups []string `json:"groups,omitempty"` @@ -32,6 +38,7 @@ func (h *Headscale) initOIDC() error { if err != nil { log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error()) + return err } @@ -39,14 +46,20 @@ func (h *Headscale) initOIDC() error { ClientID: h.cfg.OIDC.ClientID, ClientSecret: h.cfg.OIDC.ClientSecret, Endpoint: h.oidcProvider.Endpoint(), - RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")), - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, + RedirectURL: fmt.Sprintf( + "%s/oidc/callback", + strings.TrimSuffix(h.cfg.ServerURL, "/"), + ), + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, } } // init the state cache if it hasn't been already if h.oidcStateCache == nil { - h.oidcStateCache = cache.New(time.Minute*5, time.Minute*10) + h.oidcStateCache = cache.New( + oidcStateCacheExpiration, + oidcStateCacheCleanupInterval, + ) } return nil @@ -54,50 +67,53 @@ func (h *Headscale) initOIDC() error { // RegisterOIDC redirects to the OIDC provider for authentication // Puts machine key in cache so the callback can retrieve it using the oidc state param -// Listens in /oidc/register/:mKey -func (h *Headscale) RegisterOIDC(c *gin.Context) { - mKeyStr := c.Param("mkey") +// Listens in /oidc/register/:mKey. +func (h *Headscale) RegisterOIDC(ctx *gin.Context) { + mKeyStr := ctx.Param("mkey") if mKeyStr == "" { - c.String(http.StatusBadRequest, "Wrong params") + ctx.String(http.StatusBadRequest, "Wrong params") + return } - b := make([]byte, 16) - _, err := rand.Read(b) - if err != nil { + randomBlob := make([]byte, randomByteSize) + if _, err := rand.Read(randomBlob); err != nil { log.Error().Msg("could not read 16 bytes from rand") - c.String(http.StatusInternalServerError, "could not read 16 bytes from rand") + ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand") + return } - stateStr := hex.EncodeToString(b)[:32] + stateStr := hex.EncodeToString(randomBlob)[:32] // place the machine key into the state cache, so it can be retrieved later - h.oidcStateCache.Set(stateStr, mKeyStr, time.Minute*5) + h.oidcStateCache.Set(stateStr, mKeyStr, oidcStateCacheExpiration) - authUrl := h.oauth2Config.AuthCodeURL(stateStr) - log.Debug().Msgf("Redirecting to %s for authentication", authUrl) + authURL := h.oauth2Config.AuthCodeURL(stateStr) + log.Debug().Msgf("Redirecting to %s for authentication", authURL) - c.Redirect(http.StatusFound, authUrl) + ctx.Redirect(http.StatusFound, authURL) } // OIDCCallback handles the callback from the OIDC endpoint // Retrieves the mkey from the state cache and adds the machine to the users email namespace // TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities // TODO: Add groups information from OIDC tokens into machine HostInfo -// Listens in /oidc/callback -func (h *Headscale) OIDCCallback(c *gin.Context) { - code := c.Query("code") - state := c.Query("state") +// Listens in /oidc/callback. +func (h *Headscale) OIDCCallback(ctx *gin.Context) { + code := ctx.Query("code") + state := ctx.Query("state") if code == "" || state == "" { - c.String(http.StatusBadRequest, "Wrong params") + ctx.String(http.StatusBadRequest, "Wrong params") + return } oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code) if err != nil { - c.String(http.StatusBadRequest, "Could not exchange code for token") + ctx.String(http.StatusBadRequest, "Could not exchange code for token") + return } @@ -105,7 +121,8 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string) if !rawIDTokenOK { - c.String(http.StatusBadRequest, "Could not extract ID Token") + ctx.String(http.StatusBadRequest, "Could not extract ID Token") + return } @@ -113,21 +130,26 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { idToken, err := verifier.Verify(context.Background(), rawIDToken) if err != nil { - c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error()) + ctx.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error()) + return } // TODO: we can use userinfo at some point to grab additional information about the user (groups membership, etc) - //userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token)) - //if err != nil { - // c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo: %s", err)) - // return - //} + // userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token)) + // if err != nil { + // c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo: %s", err)) + // return + // } // Extract custom claims var claims IDTokenClaims if err = idToken.Claims(&claims); err != nil { - c.String(http.StatusBadRequest, fmt.Sprintf("Failed to decode id token claims: %s", err)) + ctx.String( + http.StatusBadRequest, + fmt.Sprintf("Failed to decode id token claims: %s", err), + ) + return } @@ -135,62 +157,80 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { mKeyIf, mKeyFound := h.oidcStateCache.Get(state) if !mKeyFound { - log.Error().Msg("requested machine state key expired before authorisation completed") - c.String(http.StatusBadRequest, "state has expired") + log.Error(). + Msg("requested machine state key expired before authorisation completed") + ctx.String(http.StatusBadRequest, "state has expired") + return } mKeyStr, mKeyOK := mKeyIf.(string) if !mKeyOK { log.Error().Msg("could not get machine key from cache") - c.String(http.StatusInternalServerError, "could not get machine key from cache") + ctx.String( + http.StatusInternalServerError, + "could not get machine key from cache", + ) + return } // retrieve machine information - m, err := h.GetMachineByMachineKey(mKeyStr) + machine, err := h.GetMachineByMachineKey(mKeyStr) if err != nil { log.Error().Msg("machine key not found in database") - c.String(http.StatusInternalServerError, "could not get machine info from database") + ctx.String( + http.StatusInternalServerError, + "could not get machine info from database", + ) + return } now := time.Now().UTC() - if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok { + if namespaceName, ok := h.getNamespaceFromEmail(claims.Email); ok { // register the machine if it's new - if !m.Registered { - + if !machine.Registered { log.Debug().Msg("Registering new machine after successful callback") - ns, err := h.GetNamespace(nsName) + namespace, err := h.GetNamespace(namespaceName) if err != nil { - ns, err = h.CreateNamespace(nsName) + namespace, err = h.CreateNamespace(namespaceName) if err != nil { - log.Error().Msgf("could not create new namespace '%s'", claims.Email) - c.String(http.StatusInternalServerError, "could not create new namespace") + log.Error(). + Msgf("could not create new namespace '%s'", claims.Email) + ctx.String( + http.StatusInternalServerError, + "could not create new namespace", + ) + return } } ip, err := h.getAvailableIP() if err != nil { - c.String(http.StatusInternalServerError, "could not get an IP from the pool") + ctx.String( + http.StatusInternalServerError, + "could not get an IP from the pool", + ) + return } - m.IPAddress = ip.String() - m.NamespaceID = ns.ID - m.Registered = true - m.RegisterMethod = "oidc" - m.LastSuccessfulUpdate = &now - h.db.Save(&m) + machine.IPAddress = ip.String() + machine.NamespaceID = namespace.ID + machine.Registered = true + machine.RegisterMethod = "oidc" + machine.LastSuccessfulUpdate = &now + h.db.Save(&machine) } - h.updateMachineExpiry(m) + h.updateMachineExpiry(machine) - c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(` + ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`

headscale

@@ -201,15 +241,17 @@ func (h *Headscale) OIDCCallback(c *gin.Context) { `, claims.Email))) - } log.Error(). Str("email", claims.Email). Str("username", claims.Username). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Email could not be mapped to a namespace") - c.String(http.StatusBadRequest, "email from claim could not be mapped to a namespace") + ctx.String( + http.StatusBadRequest, + "email from claim could not be mapped to a namespace", + ) } // getNamespaceFromEmail passes the users email through a list of "matchers" diff --git a/oidc_test.go b/oidc_test.go index c7a29ce..21a4357 100644 --- a/oidc_test.go +++ b/oidc_test.go @@ -145,29 +145,37 @@ func TestHeadscale_getNamespaceFromEmail(t *testing.T) { }, } //nolint - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - h := &Headscale{ - cfg: tt.fields.cfg, - db: tt.fields.db, - dbString: tt.fields.dbString, - dbType: tt.fields.dbType, - dbDebug: tt.fields.dbDebug, - publicKey: tt.fields.publicKey, - privateKey: tt.fields.privateKey, - aclPolicy: tt.fields.aclPolicy, - aclRules: tt.fields.aclRules, - lastStateChange: tt.fields.lastStateChange, - oidcProvider: tt.fields.oidcProvider, - oauth2Config: tt.fields.oauth2Config, - oidcStateCache: tt.fields.oidcStateCache, + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + app := &Headscale{ + cfg: test.fields.cfg, + db: test.fields.db, + dbString: test.fields.dbString, + dbType: test.fields.dbType, + dbDebug: test.fields.dbDebug, + publicKey: test.fields.publicKey, + privateKey: test.fields.privateKey, + aclPolicy: test.fields.aclPolicy, + aclRules: test.fields.aclRules, + lastStateChange: test.fields.lastStateChange, + oidcProvider: test.fields.oidcProvider, + oauth2Config: test.fields.oauth2Config, + oidcStateCache: test.fields.oidcStateCache, } - got, got1 := h.getNamespaceFromEmail(tt.args.email) - if got != tt.want { - t.Errorf("Headscale.getNamespaceFromEmail() got = %v, want %v", got, tt.want) + got, got1 := app.getNamespaceFromEmail(test.args.email) + if got != test.want { + t.Errorf( + "Headscale.getNamespaceFromEmail() got = %v, want %v", + got, + test.want, + ) } - if got1 != tt.want1 { - t.Errorf("Headscale.getNamespaceFromEmail() got1 = %v, want %v", got1, tt.want1) + if got1 != test.want1 { + t.Errorf( + "Headscale.getNamespaceFromEmail() got1 = %v, want %v", + got1, + test.want1, + ) } }) } diff --git a/poll.go b/poll.go index 6a65280..9cf14e7 100644 --- a/poll.go +++ b/poll.go @@ -15,6 +15,11 @@ import ( "tailscale.com/types/wgkey" ) +const ( + keepAliveInterval = 60 * time.Second + updateCheckInterval = 10 * time.Second +) + // PollNetMapHandler takes care of /machine/:id/map // // This is the busiest endpoint, as it keeps the HTTP long poll that updates @@ -24,20 +29,21 @@ import ( // only after their first request (marked with the ReadOnly field). // // At this moment the updates are sent in a quite horrendous way, but they kinda work. -func (h *Headscale) PollNetMapHandler(c *gin.Context) { +func (h *Headscale) PollNetMapHandler(ctx *gin.Context) { log.Trace(). Str("handler", "PollNetMap"). - Str("id", c.Param("id")). + Str("id", ctx.Param("id")). Msg("PollNetMapHandler called") - body, _ := io.ReadAll(c.Request.Body) - mKeyStr := c.Param("id") + body, _ := io.ReadAll(ctx.Request.Body) + mKeyStr := ctx.Param("id") mKey, err := wgkey.ParseHex(mKeyStr) if err != nil { log.Error(). Str("handler", "PollNetMap"). Err(err). Msg("Cannot parse client key") - c.String(http.StatusBadRequest, "") + ctx.String(http.StatusBadRequest, "") + return } req := tailcfg.MapRequest{} @@ -47,34 +53,36 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { Str("handler", "PollNetMap"). Err(err). Msg("Cannot decode message") - c.String(http.StatusBadRequest, "") + ctx.String(http.StatusBadRequest, "") + return } - m, err := h.GetMachineByMachineKey(mKey.HexString()) + machine, err := h.GetMachineByMachineKey(mKey.HexString()) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { log.Warn(). Str("handler", "PollNetMap"). Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString()) - c.String(http.StatusUnauthorized, "") + ctx.String(http.StatusUnauthorized, "") + return } log.Error(). Str("handler", "PollNetMap"). Msgf("Failed to fetch machine from the database with Machine key: %s", mKey.HexString()) - c.String(http.StatusInternalServerError, "") + ctx.String(http.StatusInternalServerError, "") } log.Trace(). Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). + Str("id", ctx.Param("id")). + Str("machine", machine.Name). Msg("Found machine in database") hostinfo, _ := json.Marshal(req.Hostinfo) - m.Name = req.Hostinfo.Hostname - m.HostInfo = datatypes.JSON(hostinfo) - m.DiscoKey = wgkey.Key(req.DiscoKey).HexString() + machine.Name = req.Hostinfo.Hostname + machine.HostInfo = datatypes.JSON(hostinfo) + machine.DiscoKey = wgkey.Key(req.DiscoKey).HexString() now := time.Now().UTC() // From Tailscale client: @@ -87,20 +95,21 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { // before their first real endpoint update. if !req.ReadOnly { endpoints, _ := json.Marshal(req.Endpoints) - m.Endpoints = datatypes.JSON(endpoints) - m.LastSeen = &now + machine.Endpoints = datatypes.JSON(endpoints) + machine.LastSeen = &now } - h.db.Save(&m) + h.db.Save(&machine) - data, err := h.getMapResponse(mKey, req, m) + data, err := h.getMapResponse(mKey, req, machine) if err != nil { log.Error(). Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). + Str("id", ctx.Param("id")). + Str("machine", machine.Name). Err(err). Msg("Failed to get Map response") - c.String(http.StatusInternalServerError, ":(") + ctx.String(http.StatusInternalServerError, ":(") + return } @@ -111,8 +120,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { // Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696 log.Debug(). Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). + Str("id", ctx.Param("id")). + Str("machine", machine.Name). Bool("readOnly", req.ReadOnly). Bool("omitPeers", req.OmitPeers). Bool("stream", req.Stream). @@ -121,15 +130,16 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { if req.ReadOnly { log.Info(). Str("handler", "PollNetMap"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Client is starting up. Probably interested in a DERP map") - c.Data(200, "application/json; charset=utf-8", data) + ctx.Data(http.StatusOK, "application/json; charset=utf-8", data) + return } // There has been an update to _any_ of the nodes that the other nodes would // need to know about - h.setLastStateChangeToNow(m.Namespace.Name) + h.setLastStateChangeToNow(machine.Namespace.Name) // The request is not ReadOnly, so we need to set up channels for updating // peers via longpoll @@ -137,8 +147,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { // Only create update channel if it has not been created log.Trace(). Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). + Str("id", ctx.Param("id")). + Str("machine", machine.Name). Msg("Loading or creating update channel") updateChan := make(chan struct{}) @@ -152,46 +162,59 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { if req.OmitPeers && !req.Stream { log.Info(). Str("handler", "PollNetMap"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Client sent endpoint update and is ok with a response without peer list") - c.Data(200, "application/json; charset=utf-8", data) + ctx.Data(http.StatusOK, "application/json; charset=utf-8", data) // It sounds like we should update the nodes when we have received a endpoint update // even tho the comments in the tailscale code dont explicitly say so. - updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "endpoint-update").Inc() + updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "endpoint-update"). + Inc() go func() { updateChan <- struct{}{} }() + return } else if req.OmitPeers && req.Stream { log.Warn(). Str("handler", "PollNetMap"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Ignoring request, don't know how to handle it") - c.String(http.StatusBadRequest, "") + ctx.String(http.StatusBadRequest, "") + return } log.Info(). Str("handler", "PollNetMap"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Client is ready to access the tailnet") log.Info(). Str("handler", "PollNetMap"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Sending initial map") go func() { pollDataChan <- data }() log.Info(). Str("handler", "PollNetMap"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Notifying peers") - updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "full-update").Inc() + updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "full-update"). + Inc() go func() { updateChan <- struct{}{} }() - h.PollNetMapStream(c, m, req, mKey, pollDataChan, keepAliveChan, updateChan, cancelKeepAlive) + h.PollNetMapStream( + ctx, + machine, + req, + mKey, + pollDataChan, + keepAliveChan, + updateChan, + cancelKeepAlive, + ) log.Trace(). Str("handler", "PollNetMap"). - Str("id", c.Param("id")). - Str("machine", m.Name). + Str("id", ctx.Param("id")). + Str("machine", machine.Name). Msg("Finished stream, closing PollNetMap session") } @@ -199,165 +222,181 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) { // stream logic, ensuring we communicate updates and data // to the connected clients. func (h *Headscale) PollNetMapStream( - c *gin.Context, - m *Machine, - req tailcfg.MapRequest, - mKey wgkey.Key, + ctx *gin.Context, + machine *Machine, + mapRequest tailcfg.MapRequest, + machineKey wgkey.Key, pollDataChan chan []byte, keepAliveChan chan []byte, updateChan chan struct{}, cancelKeepAlive chan struct{}, ) { - go h.scheduledPollWorker(cancelKeepAlive, updateChan, keepAliveChan, mKey, req, m) + go h.scheduledPollWorker( + cancelKeepAlive, + updateChan, + keepAliveChan, + machineKey, + mapRequest, + machine, + ) - c.Stream(func(w io.Writer) bool { + ctx.Stream(func(writer io.Writer) bool { log.Trace(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Waiting for data to stream...") log.Trace(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan) select { case data := <-pollDataChan: log.Trace(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "pollData"). Int("bytes", len(data)). Msg("Sending data received via pollData channel") - _, err := w.Write(data) + _, err := writer.Write(data) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "pollData"). Err(err). Msg("Cannot write data") + return false } log.Trace(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "pollData"). Int("bytes", len(data)). Msg("Data from pollData channel written successfully") // TODO(kradalby): Abstract away all the database calls, this can cause race conditions // when an outdated machine object is kept alive, e.g. db is update from // command line, but then overwritten. - err = h.UpdateMachine(m) + err = h.UpdateMachine(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "pollData"). Err(err). Msg("Cannot update machine from database") } now := time.Now().UTC() - m.LastSeen = &now + machine.LastSeen = &now - lastStateUpdate.WithLabelValues(m.Namespace.Name, m.Name).Set(float64(now.Unix())) - m.LastSuccessfulUpdate = &now + lastStateUpdate.WithLabelValues(machine.Namespace.Name, machine.Name). + Set(float64(now.Unix())) + machine.LastSuccessfulUpdate = &now - h.db.Save(&m) + h.db.Save(&machine) log.Trace(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "pollData"). Int("bytes", len(data)). Msg("Machine entry in database updated successfully after sending pollData") + return true case data := <-keepAliveChan: log.Trace(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "keepAlive"). Int("bytes", len(data)). Msg("Sending keep alive message") - _, err := w.Write(data) + _, err := writer.Write(data) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "keepAlive"). Err(err). Msg("Cannot write keep alive message") + return false } log.Trace(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "keepAlive"). Int("bytes", len(data)). Msg("Keep alive sent successfully") // TODO(kradalby): Abstract away all the database calls, this can cause race conditions // when an outdated machine object is kept alive, e.g. db is update from // command line, but then overwritten. - err = h.UpdateMachine(m) + err = h.UpdateMachine(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "keepAlive"). Err(err). Msg("Cannot update machine from database") } now := time.Now().UTC() - m.LastSeen = &now - h.db.Save(&m) + machine.LastSeen = &now + h.db.Save(&machine) log.Trace(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "keepAlive"). Int("bytes", len(data)). Msg("Machine updated successfully after sending keep alive") + return true case <-updateChan: log.Trace(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "update"). Msg("Received a request for update") - updateRequestsReceivedOnChannel.WithLabelValues(m.Name, m.Namespace.Name).Inc() - if h.isOutdated(m) { + updateRequestsReceivedOnChannel.WithLabelValues(machine.Name, machine.Namespace.Name). + Inc() + if h.isOutdated(machine) { log.Debug(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). - Time("last_successful_update", *m.LastSuccessfulUpdate). - Time("last_state_change", h.getLastStateChange(m.Namespace.Name)). - Msgf("There has been updates since the last successful update to %s", m.Name) - data, err := h.getMapResponse(mKey, req, m) + Str("machine", machine.Name). + Time("last_successful_update", *machine.LastSuccessfulUpdate). + Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)). + Msgf("There has been updates since the last successful update to %s", machine.Name) + data, err := h.getMapResponse(machineKey, mapRequest, machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "update"). Err(err). Msg("Could not get the map update") } - _, err = w.Write(data) + _, err = writer.Write(data) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "update"). Err(err). Msg("Could not write the map response") - updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "failed").Inc() + updateRequestsSentToNode.WithLabelValues(machine.Name, machine.Namespace.Name, "failed"). + Inc() + return false } log.Trace(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "update"). Msg("Updated Map has been sent") - updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "success").Inc() + updateRequestsSentToNode.WithLabelValues(machine.Name, machine.Namespace.Name, "success"). + Inc() // Keep track of the last successful update, // we sometimes end in a state were the update @@ -366,77 +405,79 @@ func (h *Headscale) PollNetMapStream( // TODO(kradalby): Abstract away all the database calls, this can cause race conditions // when an outdated machine object is kept alive, e.g. db is update from // command line, but then overwritten. - err = h.UpdateMachine(m) + err = h.UpdateMachine(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "update"). Err(err). Msg("Cannot update machine from database") } now := time.Now().UTC() - lastStateUpdate.WithLabelValues(m.Namespace.Name, m.Name).Set(float64(now.Unix())) - m.LastSuccessfulUpdate = &now + lastStateUpdate.WithLabelValues(machine.Namespace.Name, machine.Name). + Set(float64(now.Unix())) + machine.LastSuccessfulUpdate = &now - h.db.Save(&m) + h.db.Save(&machine) } else { log.Trace(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). - Time("last_successful_update", *m.LastSuccessfulUpdate). - Time("last_state_change", h.getLastStateChange(m.Namespace.Name)). - Msgf("%s is up to date", m.Name) + Str("machine", machine.Name). + Time("last_successful_update", *machine.LastSuccessfulUpdate). + Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)). + Msgf("%s is up to date", machine.Name) } + return true - case <-c.Request.Context().Done(): + case <-ctx.Request.Context().Done(): log.Info(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("The client has closed the connection") // TODO: Abstract away all the database calls, this can cause race conditions // when an outdated machine object is kept alive, e.g. db is update from // command line, but then overwritten. - err := h.UpdateMachine(m) + err := h.UpdateMachine(machine) if err != nil { log.Error(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "Done"). Err(err). Msg("Cannot update machine from database") } now := time.Now().UTC() - m.LastSeen = &now - h.db.Save(&m) + machine.LastSeen = &now + h.db.Save(&machine) log.Trace(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "Done"). Msg("Cancelling keepAlive channel") cancelKeepAlive <- struct{}{} log.Trace(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "Done"). Msg("Closing update channel") - //h.closeUpdateChannel(m) + // h.closeUpdateChannel(m) close(updateChan) log.Trace(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "Done"). Msg("Closing pollData channel") close(pollDataChan) log.Trace(). Str("handler", "PollNetMapStream"). - Str("machine", m.Name). + Str("machine", machine.Name). Str("channel", "Done"). Msg("Closing keepAliveChan channel") close(keepAliveChan) @@ -450,12 +491,12 @@ func (h *Headscale) scheduledPollWorker( cancelChan <-chan struct{}, updateChan chan<- struct{}, keepAliveChan chan<- []byte, - mKey wgkey.Key, - req tailcfg.MapRequest, - m *Machine, + machineKey wgkey.Key, + mapRequest tailcfg.MapRequest, + machine *Machine, ) { - keepAliveTicker := time.NewTicker(60 * time.Second) - updateCheckerTicker := time.NewTicker(10 * time.Second) + keepAliveTicker := time.NewTicker(keepAliveInterval) + updateCheckerTicker := time.NewTicker(updateCheckInterval) for { select { @@ -463,27 +504,29 @@ func (h *Headscale) scheduledPollWorker( return case <-keepAliveTicker.C: - data, err := h.getMapKeepAliveResponse(mKey, req, m) + data, err := h.getMapKeepAliveResponse(machineKey, mapRequest) if err != nil { log.Error(). Str("func", "keepAlive"). Err(err). Msg("Error generating the keep alive msg") + return } log.Debug(). Str("func", "keepAlive"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Sending keepalive") keepAliveChan <- data case <-updateCheckerTicker.C: log.Debug(). Str("func", "scheduledPollWorker"). - Str("machine", m.Name). + Str("machine", machine.Name). Msg("Sending update request") - updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "scheduled-update").Inc() + updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "scheduled-update"). + Inc() updateChan <- struct{}{} } } diff --git a/preauth_keys.go b/preauth_keys.go index 41b10e3..50bc474 100644 --- a/preauth_keys.go +++ b/preauth_keys.go @@ -7,19 +7,19 @@ import ( "strconv" "time" + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "google.golang.org/protobuf/types/known/timestamppb" "gorm.io/gorm" - - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" ) const ( - errorAuthKeyNotFound = Error("AuthKey not found") - errorAuthKeyExpired = Error("AuthKey expired") + errPreAuthKeyNotFound = Error("AuthKey not found") + errPreAuthKeyExpired = Error("AuthKey expired") errSingleUseAuthKeyHasBeenUsed = Error("AuthKey has already been used") + errNamespaceMismatch = Error("namespace mismatch") ) -// PreAuthKey describes a pre-authorization key usable in a particular namespace +// PreAuthKey describes a pre-authorization key usable in a particular namespace. type PreAuthKey struct { ID uint64 `gorm:"primary_key"` Key string @@ -33,14 +33,14 @@ type PreAuthKey struct { Expiration *time.Time } -// CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it +// CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it. func (h *Headscale) CreatePreAuthKey( namespaceName string, reusable bool, ephemeral bool, expiration *time.Time, ) (*PreAuthKey, error) { - n, err := h.GetNamespace(namespaceName) + namespace, err := h.GetNamespace(namespaceName) if err != nil { return nil, err } @@ -51,35 +51,36 @@ func (h *Headscale) CreatePreAuthKey( return nil, err } - k := PreAuthKey{ + key := PreAuthKey{ Key: kstr, - NamespaceID: n.ID, - Namespace: *n, + NamespaceID: namespace.ID, + Namespace: *namespace, Reusable: reusable, Ephemeral: ephemeral, CreatedAt: &now, Expiration: expiration, } - h.db.Save(&k) + h.db.Save(&key) - return &k, nil + return &key, nil } -// ListPreAuthKeys returns the list of PreAuthKeys for a namespace +// ListPreAuthKeys returns the list of PreAuthKeys for a namespace. func (h *Headscale) ListPreAuthKeys(namespaceName string) ([]PreAuthKey, error) { - n, err := h.GetNamespace(namespaceName) + namespace, err := h.GetNamespace(namespaceName) if err != nil { return nil, err } keys := []PreAuthKey{} - if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: n.ID}).Find(&keys).Error; err != nil { + if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: namespace.ID}).Find(&keys).Error; err != nil { return nil, err } + return keys, nil } -// GetPreAuthKey returns a PreAuthKey for a given key +// GetPreAuthKey returns a PreAuthKey for a given key. func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, error) { pak, err := h.checkKeyValidity(key) if err != nil { @@ -87,7 +88,7 @@ func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, er } if pak.Namespace.Name != namespace { - return nil, errors.New("Namespace mismatch") + return nil, errNamespaceMismatch } return pak, nil @@ -95,32 +96,36 @@ func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, er // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // does not exist. -func (h *Headscale) DestroyPreAuthKey(pak *PreAuthKey) error { - if result := h.db.Unscoped().Delete(&pak); result.Error != nil { +func (h *Headscale) DestroyPreAuthKey(pak PreAuthKey) error { + if result := h.db.Unscoped().Delete(pak); result.Error != nil { return result.Error } return nil } -// MarkExpirePreAuthKey marks a PreAuthKey as expired +// MarkExpirePreAuthKey marks a PreAuthKey as expired. func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error { if err := h.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { return err } + return nil } // checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node -// If returns no error and a PreAuthKey, it can be used +// If returns no error and a PreAuthKey, it can be used. func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { pak := PreAuthKey{} - if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, errorAuthKeyNotFound + if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is( + result.Error, + gorm.ErrRecordNotFound, + ) { + return nil, errPreAuthKeyNotFound } if pak.Expiration != nil && pak.Expiration.Before(time.Now()) { - return nil, errorAuthKeyExpired + return nil, errPreAuthKeyExpired } if pak.Reusable || pak.Ephemeral { // we don't need to check if has been used before @@ -145,13 +150,14 @@ func (h *Headscale) generateKey() (string, error) { if _, err := rand.Read(bytes); err != nil { return "", err } + return hex.EncodeToString(bytes), nil } func (key *PreAuthKey) toProto() *v1.PreAuthKey { protoKey := v1.PreAuthKey{ Namespace: key.Namespace.Name, - Id: strconv.FormatUint(key.ID, 10), + Id: strconv.FormatUint(key.ID, Base10), Key: key.Key, Ephemeral: key.Ephemeral, Reusable: key.Reusable, diff --git a/preauth_keys_test.go b/preauth_keys_test.go index dceec00..fd0feb0 100644 --- a/preauth_keys_test.go +++ b/preauth_keys_test.go @@ -7,189 +7,189 @@ import ( ) func (*Suite) TestCreatePreAuthKey(c *check.C) { - _, err := h.CreatePreAuthKey("bogus", true, false, nil) + _, err := app.CreatePreAuthKey("bogus", true, false, nil) c.Assert(err, check.NotNil) - n, err := h.CreateNamespace("test") + namespace, err := app.CreateNamespace("test") c.Assert(err, check.IsNil) - k, err := h.CreatePreAuthKey(n.Name, true, false, nil) + key, err := app.CreatePreAuthKey(namespace.Name, true, false, nil) c.Assert(err, check.IsNil) // Did we get a valid key? - c.Assert(k.Key, check.NotNil) - c.Assert(len(k.Key), check.Equals, 48) + c.Assert(key.Key, check.NotNil) + c.Assert(len(key.Key), check.Equals, 48) // Make sure the Namespace association is populated - c.Assert(k.Namespace.Name, check.Equals, n.Name) + c.Assert(key.Namespace.Name, check.Equals, namespace.Name) - _, err = h.ListPreAuthKeys("bogus") + _, err = app.ListPreAuthKeys("bogus") c.Assert(err, check.NotNil) - keys, err := h.ListPreAuthKeys(n.Name) + keys, err := app.ListPreAuthKeys(namespace.Name) c.Assert(err, check.IsNil) c.Assert(len(keys), check.Equals, 1) // Make sure the Namespace association is populated - c.Assert((keys)[0].Namespace.Name, check.Equals, n.Name) + c.Assert((keys)[0].Namespace.Name, check.Equals, namespace.Name) } func (*Suite) TestExpiredPreAuthKey(c *check.C) { - n, err := h.CreateNamespace("test2") + namespace, err := app.CreateNamespace("test2") c.Assert(err, check.IsNil) now := time.Now() - pak, err := h.CreatePreAuthKey(n.Name, true, false, &now) + pak, err := app.CreatePreAuthKey(namespace.Name, true, false, &now) c.Assert(err, check.IsNil) - p, err := h.checkKeyValidity(pak.Key) - c.Assert(err, check.Equals, errorAuthKeyExpired) - c.Assert(p, check.IsNil) + key, err := app.checkKeyValidity(pak.Key) + c.Assert(err, check.Equals, errPreAuthKeyExpired) + c.Assert(key, check.IsNil) } func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) { - p, err := h.checkKeyValidity("potatoKey") - c.Assert(err, check.Equals, errorAuthKeyNotFound) - c.Assert(p, check.IsNil) + key, err := app.checkKeyValidity("potatoKey") + c.Assert(err, check.Equals, errPreAuthKeyNotFound) + c.Assert(key, check.IsNil) } func (*Suite) TestValidateKeyOk(c *check.C) { - n, err := h.CreateNamespace("test3") + namespace, err := app.CreateNamespace("test3") c.Assert(err, check.IsNil) - pak, err := h.CreatePreAuthKey(n.Name, true, false, nil) + pak, err := app.CreatePreAuthKey(namespace.Name, true, false, nil) c.Assert(err, check.IsNil) - p, err := h.checkKeyValidity(pak.Key) + key, err := app.checkKeyValidity(pak.Key) c.Assert(err, check.IsNil) - c.Assert(p.ID, check.Equals, pak.ID) + c.Assert(key.ID, check.Equals, pak.ID) } func (*Suite) TestAlreadyUsedKey(c *check.C) { - n, err := h.CreateNamespace("test4") + namespace, err := app.CreateNamespace("test4") c.Assert(err, check.IsNil) - pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) c.Assert(err, check.IsNil) - m := Machine{ + machine := Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Name: "testest", - NamespaceID: n.ID, + NamespaceID: namespace.ID, Registered: true, RegisterMethod: "authKey", AuthKeyID: uint(pak.ID), } - h.db.Save(&m) + app.db.Save(&machine) - p, err := h.checkKeyValidity(pak.Key) + key, err := app.checkKeyValidity(pak.Key) c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed) - c.Assert(p, check.IsNil) + c.Assert(key, check.IsNil) } func (*Suite) TestReusableBeingUsedKey(c *check.C) { - n, err := h.CreateNamespace("test5") + namespace, err := app.CreateNamespace("test5") c.Assert(err, check.IsNil) - pak, err := h.CreatePreAuthKey(n.Name, true, false, nil) + pak, err := app.CreatePreAuthKey(namespace.Name, true, false, nil) c.Assert(err, check.IsNil) - m := Machine{ + machine := Machine{ ID: 1, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Name: "testest", - NamespaceID: n.ID, + NamespaceID: namespace.ID, Registered: true, RegisterMethod: "authKey", AuthKeyID: uint(pak.ID), } - h.db.Save(&m) + app.db.Save(&machine) - p, err := h.checkKeyValidity(pak.Key) + key, err := app.checkKeyValidity(pak.Key) c.Assert(err, check.IsNil) - c.Assert(p.ID, check.Equals, pak.ID) + c.Assert(key.ID, check.Equals, pak.ID) } func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { - n, err := h.CreateNamespace("test6") + namespace, err := app.CreateNamespace("test6") c.Assert(err, check.IsNil) - pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) c.Assert(err, check.IsNil) - p, err := h.checkKeyValidity(pak.Key) + key, err := app.checkKeyValidity(pak.Key) c.Assert(err, check.IsNil) - c.Assert(p.ID, check.Equals, pak.ID) + c.Assert(key.ID, check.Equals, pak.ID) } func (*Suite) TestEphemeralKey(c *check.C) { - n, err := h.CreateNamespace("test7") + namespace, err := app.CreateNamespace("test7") c.Assert(err, check.IsNil) - pak, err := h.CreatePreAuthKey(n.Name, false, true, nil) + pak, err := app.CreatePreAuthKey(namespace.Name, false, true, nil) c.Assert(err, check.IsNil) now := time.Now() - m := Machine{ + machine := Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Name: "testest", - NamespaceID: n.ID, + NamespaceID: namespace.ID, Registered: true, RegisterMethod: "authKey", LastSeen: &now, AuthKeyID: uint(pak.ID), } - h.db.Save(&m) + app.db.Save(&machine) - _, err = h.checkKeyValidity(pak.Key) + _, err = app.checkKeyValidity(pak.Key) // Ephemeral keys are by definition reusable c.Assert(err, check.IsNil) - _, err = h.GetMachine("test7", "testest") + _, err = app.GetMachine("test7", "testest") c.Assert(err, check.IsNil) - h.expireEphemeralNodesWorker() + app.expireEphemeralNodesWorker() // The machine record should have been deleted - _, err = h.GetMachine("test7", "testest") + _, err = app.GetMachine("test7", "testest") c.Assert(err, check.NotNil) } func (*Suite) TestExpirePreauthKey(c *check.C) { - n, err := h.CreateNamespace("test3") + namespace, err := app.CreateNamespace("test3") c.Assert(err, check.IsNil) - pak, err := h.CreatePreAuthKey(n.Name, true, false, nil) + pak, err := app.CreatePreAuthKey(namespace.Name, true, false, nil) c.Assert(err, check.IsNil) c.Assert(pak.Expiration, check.IsNil) - err = h.ExpirePreAuthKey(pak) + err = app.ExpirePreAuthKey(pak) c.Assert(err, check.IsNil) c.Assert(pak.Expiration, check.NotNil) - p, err := h.checkKeyValidity(pak.Key) - c.Assert(err, check.Equals, errorAuthKeyExpired) - c.Assert(p, check.IsNil) + key, err := app.checkKeyValidity(pak.Key) + c.Assert(err, check.Equals, errPreAuthKeyExpired) + c.Assert(key, check.IsNil) } func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { - n, err := h.CreateNamespace("test6") + namespace, err := app.CreateNamespace("test6") c.Assert(err, check.IsNil) - pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) c.Assert(err, check.IsNil) pak.Used = true - h.db.Save(&pak) + app.db.Save(&pak) - _, err = h.checkKeyValidity(pak.Key) + _, err = app.checkKeyValidity(pak.Key) c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed) } diff --git a/proto/headscale/v1/headscale.proto b/proto/headscale/v1/headscale.proto index 26fe2f9..e7a63fc 100644 --- a/proto/headscale/v1/headscale.proto +++ b/proto/headscale/v1/headscale.proto @@ -12,115 +12,115 @@ import "headscale/v1/routes.proto"; service HeadscaleService { // --- Namespace start --- - rpc GetNamespace(GetNamespaceRequest) returns(GetNamespaceResponse) { - option(google.api.http) = { - get : "/api/v1/namespace/{name}" + rpc GetNamespace(GetNamespaceRequest) returns (GetNamespaceResponse) { + option (google.api.http) = { + get: "/api/v1/namespace/{name}" }; } - rpc CreateNamespace(CreateNamespaceRequest) returns(CreateNamespaceResponse) { - option(google.api.http) = { - post : "/api/v1/namespace" - body : "*" + rpc CreateNamespace(CreateNamespaceRequest) returns (CreateNamespaceResponse) { + option (google.api.http) = { + post: "/api/v1/namespace" + body: "*" }; } - rpc RenameNamespace(RenameNamespaceRequest) returns(RenameNamespaceResponse) { - option(google.api.http) = { - post : "/api/v1/namespace/{old_name}/rename/{new_name}" + rpc RenameNamespace(RenameNamespaceRequest) returns (RenameNamespaceResponse) { + option (google.api.http) = { + post: "/api/v1/namespace/{old_name}/rename/{new_name}" }; } - rpc DeleteNamespace(DeleteNamespaceRequest) returns(DeleteNamespaceResponse) { - option(google.api.http) = { - delete : "/api/v1/namespace/{name}" + rpc DeleteNamespace(DeleteNamespaceRequest) returns (DeleteNamespaceResponse) { + option (google.api.http) = { + delete: "/api/v1/namespace/{name}" }; } - rpc ListNamespaces(ListNamespacesRequest) returns(ListNamespacesResponse) { - option(google.api.http) = { - get : "/api/v1/namespace" + rpc ListNamespaces(ListNamespacesRequest) returns (ListNamespacesResponse) { + option (google.api.http) = { + get: "/api/v1/namespace" }; } // --- Namespace end --- // --- PreAuthKeys start --- - rpc CreatePreAuthKey(CreatePreAuthKeyRequest) returns(CreatePreAuthKeyResponse) { - option(google.api.http) = { - post : "/api/v1/preauthkey" - body : "*" + rpc CreatePreAuthKey(CreatePreAuthKeyRequest) returns (CreatePreAuthKeyResponse) { + option (google.api.http) = { + post: "/api/v1/preauthkey" + body: "*" }; } - rpc ExpirePreAuthKey(ExpirePreAuthKeyRequest) returns(ExpirePreAuthKeyResponse) { - option(google.api.http) = { - post : "/api/v1/preauthkey/expire" - body : "*" + rpc ExpirePreAuthKey(ExpirePreAuthKeyRequest) returns (ExpirePreAuthKeyResponse) { + option (google.api.http) = { + post: "/api/v1/preauthkey/expire" + body: "*" }; } - rpc ListPreAuthKeys(ListPreAuthKeysRequest) returns(ListPreAuthKeysResponse) { - option(google.api.http) = { - get : "/api/v1/preauthkey" + rpc ListPreAuthKeys(ListPreAuthKeysRequest) returns (ListPreAuthKeysResponse) { + option (google.api.http) = { + get: "/api/v1/preauthkey" }; } // --- PreAuthKeys end --- // --- Machine start --- - rpc DebugCreateMachine(DebugCreateMachineRequest) returns(DebugCreateMachineResponse) { - option(google.api.http) = { - post : "/api/v1/debug/machine" - body : "*" + rpc DebugCreateMachine(DebugCreateMachineRequest) returns (DebugCreateMachineResponse) { + option (google.api.http) = { + post: "/api/v1/debug/machine" + body: "*" }; } - rpc GetMachine(GetMachineRequest) returns(GetMachineResponse) { - option(google.api.http) = { - get : "/api/v1/machine/{machine_id}" + rpc GetMachine(GetMachineRequest) returns (GetMachineResponse) { + option (google.api.http) = { + get: "/api/v1/machine/{machine_id}" }; } - rpc RegisterMachine(RegisterMachineRequest) returns(RegisterMachineResponse) { - option(google.api.http) = { - post : "/api/v1/machine/register" + rpc RegisterMachine(RegisterMachineRequest) returns (RegisterMachineResponse) { + option (google.api.http) = { + post: "/api/v1/machine/register" }; } - rpc DeleteMachine(DeleteMachineRequest) returns(DeleteMachineResponse) { - option(google.api.http) = { - delete : "/api/v1/machine/{machine_id}" + rpc DeleteMachine(DeleteMachineRequest) returns (DeleteMachineResponse) { + option (google.api.http) = { + delete: "/api/v1/machine/{machine_id}" }; } - rpc ListMachines(ListMachinesRequest) returns(ListMachinesResponse) { - option(google.api.http) = { - get : "/api/v1/machine" + rpc ListMachines(ListMachinesRequest) returns (ListMachinesResponse) { + option (google.api.http) = { + get: "/api/v1/machine" }; } - rpc ShareMachine(ShareMachineRequest) returns(ShareMachineResponse) { - option(google.api.http) = { - post : "/api/v1/machine/{machine_id}/share/{namespace}" + rpc ShareMachine(ShareMachineRequest) returns (ShareMachineResponse) { + option (google.api.http) = { + post: "/api/v1/machine/{machine_id}/share/{namespace}" }; } - rpc UnshareMachine(UnshareMachineRequest) returns(UnshareMachineResponse) { - option(google.api.http) = { - post : "/api/v1/machine/{machine_id}/unshare/{namespace}" + rpc UnshareMachine(UnshareMachineRequest) returns (UnshareMachineResponse) { + option (google.api.http) = { + post: "/api/v1/machine/{machine_id}/unshare/{namespace}" }; } // --- Machine end --- // --- Route start --- - rpc GetMachineRoute(GetMachineRouteRequest) returns(GetMachineRouteResponse) { - option(google.api.http) = { - get : "/api/v1/machine/{machine_id}/routes" + rpc GetMachineRoute(GetMachineRouteRequest) returns (GetMachineRouteResponse) { + option (google.api.http) = { + get: "/api/v1/machine/{machine_id}/routes" }; } - rpc EnableMachineRoutes(EnableMachineRoutesRequest) returns(EnableMachineRoutesResponse) { - option(google.api.http) = { - post : "/api/v1/machine/{machine_id}/routes" + rpc EnableMachineRoutes(EnableMachineRoutesRequest) returns (EnableMachineRoutesResponse) { + option (google.api.http) = { + post: "/api/v1/machine/{machine_id}/routes" }; } // --- Route end --- diff --git a/routes.go b/routes.go index f07b709..448095a 100644 --- a/routes.go +++ b/routes.go @@ -2,38 +2,48 @@ package headscale import ( "encoding/json" - "fmt" "gorm.io/datatypes" "inet.af/netaddr" ) +const ( + errRouteIsNotAvailable = Error("route is not available") +) + // Deprecated: use machine function instead // GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by -// namespace and node name) -func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) { - m, err := h.GetMachine(namespace, nodeName) +// namespace and node name). +func (h *Headscale) GetAdvertisedNodeRoutes( + namespace string, + nodeName string, +) (*[]netaddr.IPPrefix, error) { + machine, err := h.GetMachine(namespace, nodeName) if err != nil { return nil, err } - hostInfo, err := m.GetHostInfo() + hostInfo, err := machine.GetHostInfo() if err != nil { return nil, err } + return &hostInfo.RoutableIPs, nil } // Deprecated: use machine function instead // GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by -// namespace and node name) -func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]netaddr.IPPrefix, error) { - m, err := h.GetMachine(namespace, nodeName) +// namespace and node name). +func (h *Headscale) GetEnabledNodeRoutes( + namespace string, + nodeName string, +) ([]netaddr.IPPrefix, error) { + machine, err := h.GetMachine(namespace, nodeName) if err != nil { return nil, err } - data, err := m.EnabledRoutes.MarshalJSON() + data, err := machine.EnabledRoutes.MarshalJSON() if err != nil { return nil, err } @@ -57,8 +67,12 @@ func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]n } // Deprecated: use machine function instead -// IsNodeRouteEnabled checks if a certain route has been enabled -func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeStr string) bool { +// IsNodeRouteEnabled checks if a certain route has been enabled. +func (h *Headscale) IsNodeRouteEnabled( + namespace string, + nodeName string, + routeStr string, +) bool { route, err := netaddr.ParseIPPrefix(routeStr) if err != nil { return false @@ -74,14 +88,19 @@ func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeS return true } } + return false } // Deprecated: use EnableRoute in machine.go // EnableNodeRoute enables a subnet route advertised by a node (identified by -// namespace and node name) -func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error { - m, err := h.GetMachine(namespace, nodeName) +// namespace and node name). +func (h *Headscale) EnableNodeRoute( + namespace string, + nodeName string, + routeStr string, +) error { + machine, err := h.GetMachine(namespace, nodeName) if err != nil { return err } @@ -113,7 +132,7 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr } if !available { - return fmt.Errorf("route (%s) is not available on node %s", nodeName, routeStr) + return errRouteIsNotAvailable } routes, err := json.Marshal(enabledRoutes) @@ -121,10 +140,10 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr return err } - m.EnabledRoutes = datatypes.JSON(routes) - h.db.Save(&m) + machine.EnabledRoutes = datatypes.JSON(routes) + h.db.Save(&machine) - err = h.RequestMapUpdates(m.NamespaceID) + err = h.RequestMapUpdates(machine.NamespaceID) if err != nil { return err } diff --git a/routes_test.go b/routes_test.go index ad16d21..18cb0ce 100644 --- a/routes_test.go +++ b/routes_test.go @@ -10,57 +10,60 @@ import ( ) func (s *Suite) TestGetRoutes(c *check.C) { - n, err := h.CreateNamespace("test") + namespace, err := app.CreateNamespace("test") c.Assert(err, check.IsNil) - pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) c.Assert(err, check.IsNil) - _, err = h.GetMachine("test", "test_get_route_machine") + _, err = app.GetMachine("test", "test_get_route_machine") c.Assert(err, check.NotNil) route, err := netaddr.ParseIPPrefix("10.0.0.0/24") c.Assert(err, check.IsNil) - hi := tailcfg.Hostinfo{ + hostInfo := tailcfg.Hostinfo{ RoutableIPs: []netaddr.IPPrefix{route}, } - hostinfo, err := json.Marshal(hi) + hostinfo, err := json.Marshal(hostInfo) c.Assert(err, check.IsNil) - m := Machine{ + machine := Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Name: "test_get_route_machine", - NamespaceID: n.ID, + NamespaceID: namespace.ID, Registered: true, RegisterMethod: "authKey", AuthKeyID: uint(pak.ID), HostInfo: datatypes.JSON(hostinfo), } - h.db.Save(&m) + app.db.Save(&machine) - r, err := h.GetAdvertisedNodeRoutes("test", "test_get_route_machine") + advertisedRoutes, err := app.GetAdvertisedNodeRoutes( + "test", + "test_get_route_machine", + ) c.Assert(err, check.IsNil) - c.Assert(len(*r), check.Equals, 1) + c.Assert(len(*advertisedRoutes), check.Equals, 1) - err = h.EnableNodeRoute("test", "test_get_route_machine", "192.168.0.0/24") + err = app.EnableNodeRoute("test", "test_get_route_machine", "192.168.0.0/24") c.Assert(err, check.NotNil) - err = h.EnableNodeRoute("test", "test_get_route_machine", "10.0.0.0/24") + err = app.EnableNodeRoute("test", "test_get_route_machine", "10.0.0.0/24") c.Assert(err, check.IsNil) } func (s *Suite) TestGetEnableRoutes(c *check.C) { - n, err := h.CreateNamespace("test") + namespace, err := app.CreateNamespace("test") c.Assert(err, check.IsNil) - pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) c.Assert(err, check.IsNil) - _, err = h.GetMachine("test", "test_enable_route_machine") + _, err = app.GetMachine("test", "test_enable_route_machine") c.Assert(err, check.NotNil) route, err := netaddr.ParseIPPrefix( @@ -73,56 +76,68 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { ) c.Assert(err, check.IsNil) - hi := tailcfg.Hostinfo{ + hostInfo := tailcfg.Hostinfo{ RoutableIPs: []netaddr.IPPrefix{route, route2}, } - hostinfo, err := json.Marshal(hi) + hostinfo, err := json.Marshal(hostInfo) c.Assert(err, check.IsNil) - m := Machine{ + machine := Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Name: "test_enable_route_machine", - NamespaceID: n.ID, + NamespaceID: namespace.ID, Registered: true, RegisterMethod: "authKey", AuthKeyID: uint(pak.ID), HostInfo: datatypes.JSON(hostinfo), } - h.db.Save(&m) + app.db.Save(&machine) - availableRoutes, err := h.GetAdvertisedNodeRoutes("test", "test_enable_route_machine") + availableRoutes, err := app.GetAdvertisedNodeRoutes( + "test", + "test_enable_route_machine", + ) c.Assert(err, check.IsNil) c.Assert(len(*availableRoutes), check.Equals, 2) - enabledRoutes, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine") + noEnabledRoutes, err := app.GetEnabledNodeRoutes( + "test", + "test_enable_route_machine", + ) c.Assert(err, check.IsNil) - c.Assert(len(enabledRoutes), check.Equals, 0) + c.Assert(len(noEnabledRoutes), check.Equals, 0) - err = h.EnableNodeRoute("test", "test_enable_route_machine", "192.168.0.0/24") + err = app.EnableNodeRoute("test", "test_enable_route_machine", "192.168.0.0/24") c.Assert(err, check.NotNil) - err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24") + err = app.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24") c.Assert(err, check.IsNil) - enabledRoutes1, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine") + enabledRoutes, err := app.GetEnabledNodeRoutes("test", "test_enable_route_machine") c.Assert(err, check.IsNil) - c.Assert(len(enabledRoutes1), check.Equals, 1) + c.Assert(len(enabledRoutes), check.Equals, 1) // Adding it twice will just let it pass through - err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24") + err = app.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24") c.Assert(err, check.IsNil) - enabledRoutes2, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine") + enableRoutesAfterDoubleApply, err := app.GetEnabledNodeRoutes( + "test", + "test_enable_route_machine", + ) c.Assert(err, check.IsNil) - c.Assert(len(enabledRoutes2), check.Equals, 1) + c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1) - err = h.EnableNodeRoute("test", "test_enable_route_machine", "150.0.10.0/25") + err = app.EnableNodeRoute("test", "test_enable_route_machine", "150.0.10.0/25") c.Assert(err, check.IsNil) - enabledRoutes3, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine") + enabledRoutesWithAdditionalRoute, err := app.GetEnabledNodeRoutes( + "test", + "test_enable_route_machine", + ) c.Assert(err, check.IsNil) - c.Assert(len(enabledRoutes3), check.Equals, 2) + c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2) } diff --git a/sharing.go b/sharing.go index 5f6a8f4..be1689d 100644 --- a/sharing.go +++ b/sharing.go @@ -2,11 +2,13 @@ package headscale import "gorm.io/gorm" -const errorSameNamespace = Error("Destination namespace same as origin") -const errorMachineAlreadyShared = Error("Node already shared to this namespace") -const errorMachineNotShared = Error("Machine not shared to this namespace") +const ( + errSameNamespace = Error("Destination namespace same as origin") + errMachineAlreadyShared = Error("Node already shared to this namespace") + errMachineNotShared = Error("Machine not shared to this namespace") +) -// SharedMachine is a join table to support sharing nodes between namespaces +// SharedMachine is a join table to support sharing nodes between namespaces. type SharedMachine struct { gorm.Model MachineID uint64 @@ -15,49 +17,57 @@ type SharedMachine struct { Namespace Namespace } -// AddSharedMachineToNamespace adds a machine as a shared node to a namespace -func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error { - if m.NamespaceID == ns.ID { - return errorSameNamespace +// AddSharedMachineToNamespace adds a machine as a shared node to a namespace. +func (h *Headscale) AddSharedMachineToNamespace( + machine *Machine, + namespace *Namespace, +) error { + if machine.NamespaceID == namespace.ID { + return errSameNamespace } sharedMachines := []SharedMachine{} - if err := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Find(&sharedMachines).Error; err != nil { + if err := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).Find(&sharedMachines).Error; err != nil { return err } if len(sharedMachines) > 0 { - return errorMachineAlreadyShared + return errMachineAlreadyShared } sharedMachine := SharedMachine{ - MachineID: m.ID, - Machine: *m, - NamespaceID: ns.ID, - Namespace: *ns, + MachineID: machine.ID, + Machine: *machine, + NamespaceID: namespace.ID, + Namespace: *namespace, } h.db.Save(&sharedMachine) return nil } -// RemoveSharedMachineFromNamespace removes a shared machine from a namespace -func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace) error { - if m.NamespaceID == ns.ID { +// RemoveSharedMachineFromNamespace removes a shared machine from a namespace. +func (h *Headscale) RemoveSharedMachineFromNamespace( + machine *Machine, + namespace *Namespace, +) error { + if machine.NamespaceID == namespace.ID { // Can't unshare from primary namespace - return errorMachineNotShared + return errMachineNotShared } sharedMachine := SharedMachine{} - result := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Unscoped().Delete(&sharedMachine) + result := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID). + Unscoped(). + Delete(&sharedMachine) if result.Error != nil { return result.Error } if result.RowsAffected == 0 { - return errorMachineNotShared + return errMachineNotShared } - err := h.RequestMapUpdates(ns.ID) + err := h.RequestMapUpdates(namespace.ID) if err != nil { return err } @@ -65,10 +75,10 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace) return nil } -// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces -func (h *Headscale) RemoveSharedMachineFromAllNamespaces(m *Machine) error { +// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces. +func (h *Headscale) RemoveSharedMachineFromAllNamespaces(machine *Machine) error { sharedMachine := SharedMachine{} - if result := h.db.Where("machine_id = ?", m.ID).Unscoped().Delete(&sharedMachine); result.Error != nil { + if result := h.db.Where("machine_id = ?", machine.ID).Unscoped().Delete(&sharedMachine); result.Error != nil { return result.Error } diff --git a/sharing_test.go b/sharing_test.go index 4d9e409..7ec1b0e 100644 --- a/sharing_test.go +++ b/sharing_test.go @@ -4,45 +4,48 @@ import ( "gopkg.in/check.v1" ) -func CreateNodeNamespace(c *check.C, namespace, node, key, IP string) (*Namespace, *Machine) { - n1, err := h.CreateNamespace(namespace) +func CreateNodeNamespace( + c *check.C, + namespaceName, node, key, ip string, +) (*Namespace, *Machine) { + namespace, err := app.CreateNamespace(namespaceName) c.Assert(err, check.IsNil) - pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil) + pak1, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) c.Assert(err, check.IsNil) - _, err = h.GetMachine(n1.Name, node) + _, err = app.GetMachine(namespace.Name, node) c.Assert(err, check.NotNil) - m1 := &Machine{ + machine := &Machine{ ID: 0, MachineKey: key, NodeKey: key, DiscoKey: key, Name: node, - NamespaceID: n1.ID, + NamespaceID: namespace.ID, Registered: true, RegisterMethod: "authKey", - IPAddress: IP, + IPAddress: ip, AuthKeyID: uint(pak1.ID), } - h.db.Save(m1) + app.db.Save(machine) - _, err = h.GetMachine(n1.Name, m1.Name) + _, err = app.GetMachine(namespace.Name, machine.Name) c.Assert(err, check.IsNil) - return n1, m1 + return namespace, machine } func (s *Suite) TestBasicSharedNodesInNamespace(c *check.C) { - n1, m1 := CreateNodeNamespace( + namespace1, machine1 := CreateNodeNamespace( c, "shared1", "test_get_shared_nodes_1", "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", "100.64.0.1", ) - _, m2 := CreateNodeNamespace( + _, machine2 := CreateNodeNamespace( c, "shared2", "test_get_shared_nodes_2", @@ -50,21 +53,21 @@ func (s *Suite) TestBasicSharedNodesInNamespace(c *check.C) { "100.64.0.2", ) - p1s, err := h.getPeers(m1) + peersOfMachine1BeforeShared, err := app.getPeers(machine1) c.Assert(err, check.IsNil) - c.Assert(len(p1s), check.Equals, 0) + c.Assert(len(peersOfMachine1BeforeShared), check.Equals, 0) - err = h.AddSharedMachineToNamespace(m2, n1) + err = app.AddSharedMachineToNamespace(machine2, namespace1) c.Assert(err, check.IsNil) - p1sAfter, err := h.getPeers(m1) + peersOfMachine1AfterShared, err := app.getPeers(machine1) c.Assert(err, check.IsNil) - c.Assert(len(p1sAfter), check.Equals, 1) - c.Assert(p1sAfter[0].ID, check.Equals, m2.ID) + c.Assert(len(peersOfMachine1AfterShared), check.Equals, 1) + c.Assert(peersOfMachine1AfterShared[0].ID, check.Equals, machine2.ID) } func (s *Suite) TestSameNamespace(c *check.C) { - n1, m1 := CreateNodeNamespace( + namespace1, machine1 := CreateNodeNamespace( c, "shared1", "test_get_shared_nodes_1", @@ -72,23 +75,23 @@ func (s *Suite) TestSameNamespace(c *check.C) { "100.64.0.1", ) - p1s, err := h.getPeers(m1) + peersOfMachine1BeforeShare, err := app.getPeers(machine1) c.Assert(err, check.IsNil) - c.Assert(len(p1s), check.Equals, 0) + c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0) - err = h.AddSharedMachineToNamespace(m1, n1) - c.Assert(err, check.Equals, errorSameNamespace) + err = app.AddSharedMachineToNamespace(machine1, namespace1) + c.Assert(err, check.Equals, errSameNamespace) } func (s *Suite) TestUnshare(c *check.C) { - n1, m1 := CreateNodeNamespace( + namespace1, machine1 := CreateNodeNamespace( c, "shared1", "test_unshare_1", "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", "100.64.0.1", ) - _, m2 := CreateNodeNamespace( + _, machine2 := CreateNodeNamespace( c, "shared2", "test_unshare_2", @@ -96,40 +99,40 @@ func (s *Suite) TestUnshare(c *check.C) { "100.64.0.2", ) - p1s, err := h.getPeers(m1) + peersOfMachine1BeforeShare, err := app.getPeers(machine1) c.Assert(err, check.IsNil) - c.Assert(len(p1s), check.Equals, 0) + c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0) - err = h.AddSharedMachineToNamespace(m2, n1) + err = app.AddSharedMachineToNamespace(machine2, namespace1) c.Assert(err, check.IsNil) - p1s, err = h.getShared(m1) + peersOfMachine1BeforeShare, err = app.getShared(machine1) c.Assert(err, check.IsNil) - c.Assert(len(p1s), check.Equals, 1) + c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 1) - err = h.RemoveSharedMachineFromNamespace(m2, n1) + err = app.RemoveSharedMachineFromNamespace(machine2, namespace1) c.Assert(err, check.IsNil) - p1s, err = h.getShared(m1) + peersOfMachine1BeforeShare, err = app.getShared(machine1) c.Assert(err, check.IsNil) - c.Assert(len(p1s), check.Equals, 0) + c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0) - err = h.RemoveSharedMachineFromNamespace(m2, n1) - c.Assert(err, check.Equals, errorMachineNotShared) + err = app.RemoveSharedMachineFromNamespace(machine2, namespace1) + c.Assert(err, check.Equals, errMachineNotShared) - err = h.RemoveSharedMachineFromNamespace(m1, n1) - c.Assert(err, check.Equals, errorMachineNotShared) + err = app.RemoveSharedMachineFromNamespace(machine1, namespace1) + c.Assert(err, check.Equals, errMachineNotShared) } func (s *Suite) TestAlreadyShared(c *check.C) { - n1, m1 := CreateNodeNamespace( + namespace1, machine1 := CreateNodeNamespace( c, "shared1", "test_get_shared_nodes_1", "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", "100.64.0.1", ) - _, m2 := CreateNodeNamespace( + _, machine2 := CreateNodeNamespace( c, "shared2", "test_get_shared_nodes_2", @@ -137,25 +140,25 @@ func (s *Suite) TestAlreadyShared(c *check.C) { "100.64.0.2", ) - p1s, err := h.getPeers(m1) + peersOfMachine1BeforeShare, err := app.getPeers(machine1) c.Assert(err, check.IsNil) - c.Assert(len(p1s), check.Equals, 0) + c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0) - err = h.AddSharedMachineToNamespace(m2, n1) + err = app.AddSharedMachineToNamespace(machine2, namespace1) c.Assert(err, check.IsNil) - err = h.AddSharedMachineToNamespace(m2, n1) - c.Assert(err, check.Equals, errorMachineAlreadyShared) + err = app.AddSharedMachineToNamespace(machine2, namespace1) + c.Assert(err, check.Equals, errMachineAlreadyShared) } func (s *Suite) TestDoNotIncludeRoutesOnShared(c *check.C) { - n1, m1 := CreateNodeNamespace( + namespace1, machine1 := CreateNodeNamespace( c, "shared1", "test_get_shared_nodes_1", "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", "100.64.0.1", ) - _, m2 := CreateNodeNamespace( + _, machine2 := CreateNodeNamespace( c, "shared2", "test_get_shared_nodes_2", @@ -163,35 +166,35 @@ func (s *Suite) TestDoNotIncludeRoutesOnShared(c *check.C) { "100.64.0.2", ) - p1s, err := h.getPeers(m1) + peersOfMachine1BeforeShare, err := app.getPeers(machine1) c.Assert(err, check.IsNil) - c.Assert(len(p1s), check.Equals, 0) + c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0) - err = h.AddSharedMachineToNamespace(m2, n1) + err = app.AddSharedMachineToNamespace(machine2, namespace1) c.Assert(err, check.IsNil) - p1sAfter, err := h.getPeers(m1) + peersOfMachine1AfterShare, err := app.getPeers(machine1) c.Assert(err, check.IsNil) - c.Assert(len(p1sAfter), check.Equals, 1) - c.Assert(p1sAfter[0].Name, check.Equals, "test_get_shared_nodes_2") + c.Assert(len(peersOfMachine1AfterShare), check.Equals, 1) + c.Assert(peersOfMachine1AfterShare[0].Name, check.Equals, "test_get_shared_nodes_2") } func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) { - n1, m1 := CreateNodeNamespace( + namespace1, machine1 := CreateNodeNamespace( c, "shared1", "test_get_shared_nodes_1", "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", "100.64.0.1", ) - _, m2 := CreateNodeNamespace( + _, machine2 := CreateNodeNamespace( c, "shared2", "test_get_shared_nodes_2", "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", "100.64.0.2", ) - _, m3 := CreateNodeNamespace( + _, machine3 := CreateNodeNamespace( c, "shared3", "test_get_shared_nodes_3", @@ -199,76 +202,80 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) { "100.64.0.3", ) - pak4, err := h.CreatePreAuthKey(n1.Name, false, false, nil) + pak4, err := app.CreatePreAuthKey(namespace1.Name, false, false, nil) c.Assert(err, check.IsNil) - m4 := &Machine{ + machine4 := &Machine{ ID: 4, MachineKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f", NodeKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f", DiscoKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f", Name: "test_get_shared_nodes_4", - NamespaceID: n1.ID, + NamespaceID: namespace1.ID, Registered: true, RegisterMethod: "authKey", IPAddress: "100.64.0.4", AuthKeyID: uint(pak4.ID), } - h.db.Save(m4) + app.db.Save(machine4) - _, err = h.GetMachine(n1.Name, m4.Name) + _, err = app.GetMachine(namespace1.Name, machine4.Name) c.Assert(err, check.IsNil) - p1s, err := h.getPeers(m1) + peersOfMachine1BeforeShare, err := app.getPeers(machine1) c.Assert(err, check.IsNil) - c.Assert(len(p1s), check.Equals, 1) // node1 can see node4 - c.Assert(p1s[0].Name, check.Equals, m4.Name) + c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 1) // node1 can see node4 + c.Assert(peersOfMachine1BeforeShare[0].Name, check.Equals, machine4.Name) - err = h.AddSharedMachineToNamespace(m2, n1) + err = app.AddSharedMachineToNamespace(machine2, namespace1) c.Assert(err, check.IsNil) - p1sAfter, err := h.getPeers(m1) - c.Assert(err, check.IsNil) - c.Assert(len(p1sAfter), check.Equals, 2) // node1 can see node2 (shared) and node4 (same namespace) - c.Assert(p1sAfter[0].Name, check.Equals, m2.Name) - c.Assert(p1sAfter[1].Name, check.Equals, m4.Name) - - node1shared, err := h.getShared(m1) - c.Assert(err, check.IsNil) - c.Assert(len(node1shared), check.Equals, 1) // node1 can see node2 as shared - c.Assert(node1shared[0].Name, check.Equals, m2.Name) - - pAlone, err := h.getPeers(m3) - c.Assert(err, check.IsNil) - c.Assert(len(pAlone), check.Equals, 0) // node3 is alone - - pSharedTo, err := h.getPeers(m2) + peersOfMachine1AfterShare, err := app.getPeers(machine1) c.Assert(err, check.IsNil) c.Assert( - len(pSharedTo), + len(peersOfMachine1AfterShare), + check.Equals, + 2, + ) // node1 can see node2 (shared) and node4 (same namespace) + c.Assert(peersOfMachine1AfterShare[0].Name, check.Equals, machine2.Name) + c.Assert(peersOfMachine1AfterShare[1].Name, check.Equals, machine4.Name) + + sharedOfMachine1, err := app.getShared(machine1) + c.Assert(err, check.IsNil) + c.Assert(len(sharedOfMachine1), check.Equals, 1) // node1 can see node2 as shared + c.Assert(sharedOfMachine1[0].Name, check.Equals, machine2.Name) + + peersOfMachine3, err := app.getPeers(machine3) + c.Assert(err, check.IsNil) + c.Assert(len(peersOfMachine3), check.Equals, 0) // node3 is alone + + peersOfMachine2, err := app.getPeers(machine2) + c.Assert(err, check.IsNil) + c.Assert( + len(peersOfMachine2), check.Equals, 2, ) // node2 should see node1 (sharedTo) and node4 (sharedTo), as is shared in namespace1 - c.Assert(pSharedTo[0].Name, check.Equals, m1.Name) - c.Assert(pSharedTo[1].Name, check.Equals, m4.Name) + c.Assert(peersOfMachine2[0].Name, check.Equals, machine1.Name) + c.Assert(peersOfMachine2[1].Name, check.Equals, machine4.Name) } func (s *Suite) TestDeleteSharedMachine(c *check.C) { - n1, m1 := CreateNodeNamespace( + namespace1, machine1 := CreateNodeNamespace( c, "shared1", "test_get_shared_nodes_1", "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", "100.64.0.1", ) - _, m2 := CreateNodeNamespace( + _, machine2 := CreateNodeNamespace( c, "shared2", "test_get_shared_nodes_2", "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", "100.64.0.2", ) - _, m3 := CreateNodeNamespace( + _, machine3 := CreateNodeNamespace( c, "shared3", "test_get_shared_nodes_3", @@ -276,56 +283,58 @@ func (s *Suite) TestDeleteSharedMachine(c *check.C) { "100.64.0.3", ) - pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil) + pak4n1, err := app.CreatePreAuthKey(namespace1.Name, false, false, nil) c.Assert(err, check.IsNil) - m4 := &Machine{ + machine4 := &Machine{ ID: 4, MachineKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f", NodeKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f", DiscoKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f", Name: "test_get_shared_nodes_4", - NamespaceID: n1.ID, + NamespaceID: namespace1.ID, Registered: true, RegisterMethod: "authKey", IPAddress: "100.64.0.4", AuthKeyID: uint(pak4n1.ID), } - h.db.Save(m4) + app.db.Save(machine4) - _, err = h.GetMachine(n1.Name, m4.Name) + _, err = app.GetMachine(namespace1.Name, machine4.Name) c.Assert(err, check.IsNil) - p1s, err := h.getPeers(m1) + peersOfMachine1BeforeShare, err := app.getPeers(machine1) c.Assert(err, check.IsNil) - c.Assert(len(p1s), check.Equals, 1) // nodes 1 and 4 - c.Assert(p1s[0].Name, check.Equals, m4.Name) + c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 1) // nodes 1 and 4 + c.Assert(peersOfMachine1BeforeShare[0].Name, check.Equals, machine4.Name) - err = h.AddSharedMachineToNamespace(m2, n1) + err = app.AddSharedMachineToNamespace(machine2, namespace1) c.Assert(err, check.IsNil) - p1sAfter, err := h.getPeers(m1) + peersOfMachine1AfterShare, err := app.getPeers(machine1) c.Assert(err, check.IsNil) - c.Assert(len(p1sAfter), check.Equals, 2) // nodes 1, 2, 4 - c.Assert(p1sAfter[0].Name, check.Equals, m2.Name) - c.Assert(p1sAfter[1].Name, check.Equals, m4.Name) + c.Assert(len(peersOfMachine1AfterShare), check.Equals, 2) // nodes 1, 2, 4 + c.Assert(peersOfMachine1AfterShare[0].Name, check.Equals, machine2.Name) + c.Assert(peersOfMachine1AfterShare[1].Name, check.Equals, machine4.Name) - node1shared, err := h.getShared(m1) + sharedOfMachine1, err := app.getShared(machine1) c.Assert(err, check.IsNil) - c.Assert(len(node1shared), check.Equals, 1) // nodes 1, 2, 4 - c.Assert(node1shared[0].Name, check.Equals, m2.Name) + c.Assert(len(sharedOfMachine1), check.Equals, 1) // nodes 1, 2, 4 + c.Assert(sharedOfMachine1[0].Name, check.Equals, machine2.Name) - pAlone, err := h.getPeers(m3) + peersOfMachine3, err := app.getPeers(machine3) c.Assert(err, check.IsNil) - c.Assert(len(pAlone), check.Equals, 0) // node 3 is alone + c.Assert(len(peersOfMachine3), check.Equals, 0) // node 3 is alone - sharedMachines, err := h.ListSharedMachinesInNamespace(n1.Name) + sharedMachinesInNamespace1, err := app.ListSharedMachinesInNamespace( + namespace1.Name, + ) c.Assert(err, check.IsNil) - c.Assert(len(sharedMachines), check.Equals, 1) + c.Assert(len(sharedMachinesInNamespace1), check.Equals, 1) - err = h.DeleteMachine(m2) + err = app.DeleteMachine(machine2) c.Assert(err, check.IsNil) - sharedMachines, err = h.ListSharedMachinesInNamespace(n1.Name) + sharedMachinesInNamespace1, err = app.ListSharedMachinesInNamespace(namespace1.Name) c.Assert(err, check.IsNil) - c.Assert(len(sharedMachines), check.Equals, 0) + c.Assert(len(sharedMachinesInNamespace1), check.Equals, 0) } diff --git a/swagger.go b/swagger.go index 17f5769..9e62d39 100644 --- a/swagger.go +++ b/swagger.go @@ -6,16 +6,15 @@ import ( "net/http" "text/template" - "github.com/rs/zerolog/log" - "github.com/gin-gonic/gin" + "github.com/rs/zerolog/log" ) //go:embed gen/openapiv2/headscale/v1/headscale.swagger.json var apiV1JSON []byte -func SwaggerUI(c *gin.Context) { - t := template.Must(template.New("swagger").Parse(` +func SwaggerUI(ctx *gin.Context) { + swaggerTemplate := template.Must(template.New("swagger").Parse(` @@ -48,18 +47,23 @@ func SwaggerUI(c *gin.Context) { `)) var payload bytes.Buffer - if err := t.Execute(&payload, struct{}{}); err != nil { + if err := swaggerTemplate.Execute(&payload, struct{}{}); err != nil { log.Error(). Caller(). Err(err). Msg("Could not render Swagger") - c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Swagger")) + ctx.Data( + http.StatusInternalServerError, + "text/html; charset=utf-8", + []byte("Could not render Swagger"), + ) + return } - c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes()) + ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes()) } -func SwaggerAPIv1(c *gin.Context) { - c.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON) +func SwaggerAPIv1(ctx *gin.Context) { + ctx.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON) } diff --git a/utils.go b/utils.go index ad0a72d..9f7849e 100644 --- a/utils.go +++ b/utils.go @@ -20,31 +20,48 @@ import ( "tailscale.com/types/wgkey" ) +const ( + errCannotDecryptReponse = Error("cannot decrypt response") + errResponseMissingNonce = Error("response missing nonce") + errCouldNotAllocateIP = Error("could not find any suitable IP") +) + // Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors type Error string func (e Error) Error() string { return string(e) } -func decode(msg []byte, v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) error { +func decode( + msg []byte, + v interface{}, + pubKey *wgkey.Key, + privKey *wgkey.Private, +) error { return decodeMsg(msg, v, pubKey, privKey) } -func decodeMsg(msg []byte, v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) error { +func decodeMsg( + msg []byte, + output interface{}, + pubKey *wgkey.Key, + privKey *wgkey.Private, +) error { decrypted, err := decryptMsg(msg, pubKey, privKey) if err != nil { return err } // fmt.Println(string(decrypted)) - if err := json.Unmarshal(decrypted, v); err != nil { - return fmt.Errorf("response: %v", err) + if err := json.Unmarshal(decrypted, output); err != nil { + return err } + return nil } func decryptMsg(msg []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) { var nonce [24]byte if len(msg) < len(nonce)+1 { - return nil, fmt.Errorf("response missing nonce, len=%d", len(msg)) + return nil, errResponseMissingNonce } copy(nonce[:], msg) msg = msg[len(nonce):] @@ -52,8 +69,9 @@ func decryptMsg(msg []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey) decrypted, ok := box.Open(nil, msg, &nonce, pub, pri) if !ok { - return nil, fmt.Errorf("cannot decrypt response") + return nil, errCannotDecryptReponse } + return decrypted, nil } @@ -66,13 +84,18 @@ func encode(v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, e return encodeMsg(b, pubKey, privKey) } -func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) { +func encodeMsg( + payload []byte, + pubKey *wgkey.Key, + privKey *wgkey.Private, +) ([]byte, error) { var nonce [24]byte if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil { panic(err) } pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey) - msg := box.Seal(nonce[:], b, &nonce, pub, pri) + msg := box.Seal(nonce[:], payload, &nonce, pub, pri) + return msg, nil } @@ -89,7 +112,7 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) { for { if !ipPrefix.Contains(ip) { - return nil, fmt.Errorf("could not find any suitable IP in %s", ipPrefix) + return nil, errCouldNotAllocateIP } // Some OS (including Linux) does not like when IPs ends with 0 or 255, which @@ -98,13 +121,14 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) { ipRaw := ip.As4() if ipRaw[3] == 0 || ipRaw[3] == 255 { ip = ip.Next() + continue } if ip.IsZero() && ip.IsLoopback() { - ip = ip.Next() + continue } @@ -125,7 +149,7 @@ func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) { if addr != "" { ip, err := netaddr.ParseIP(addr) if err != nil { - return nil, fmt.Errorf("failed to parse ip from database, %w", err) + return nil, fmt.Errorf("failed to parse ip from database: %w", err) } ips[index] = ip @@ -156,11 +180,16 @@ func tailNodesToString(nodes []*tailcfg.Node) string { } func tailMapResponseToString(resp tailcfg.MapResponse) string { - return fmt.Sprintf("{ Node: %s, Peers: %s }", resp.Node.Name, tailNodesToString(resp.Peers)) + return fmt.Sprintf( + "{ Node: %s, Peers: %s }", + resp.Node.Name, + tailNodesToString(resp.Peers), + ) } func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) { var d net.Dialer + return d.DialContext(ctx, "unix", addr) } @@ -174,7 +203,7 @@ func ipPrefixToString(prefixes []netaddr.IPPrefix) []string { return result } -func stringToIpPrefix(prefixes []string) ([]netaddr.IPPrefix, error) { +func stringToIPPrefix(prefixes []string) ([]netaddr.IPPrefix, error) { result := make([]netaddr.IPPrefix, len(prefixes)) for index, prefixStr := range prefixes { @@ -189,7 +218,7 @@ func stringToIpPrefix(prefixes []string) ([]netaddr.IPPrefix, error) { return result, nil } -func containsIpPrefix(prefixes []netaddr.IPPrefix, prefix netaddr.IPPrefix) bool { +func containsIPPrefix(prefixes []netaddr.IPPrefix, prefix netaddr.IPPrefix) bool { for _, p := range prefixes { if prefix == p { return true diff --git a/utils_test.go b/utils_test.go index f50cd11..dcda613 100644 --- a/utils_test.go +++ b/utils_test.go @@ -6,7 +6,7 @@ import ( ) func (s *Suite) TestGetAvailableIp(c *check.C) { - ip, err := h.getAvailableIP() + ip, err := app.getAvailableIP() c.Assert(err, check.IsNil) @@ -16,33 +16,33 @@ func (s *Suite) TestGetAvailableIp(c *check.C) { } func (s *Suite) TestGetUsedIps(c *check.C) { - ip, err := h.getAvailableIP() + ip, err := app.getAvailableIP() c.Assert(err, check.IsNil) - n, err := h.CreateNamespace("test_ip") + namespace, err := app.CreateNamespace("test_ip") c.Assert(err, check.IsNil) - pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) c.Assert(err, check.IsNil) - _, err = h.GetMachine("test", "testmachine") + _, err = app.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) - m := Machine{ + machine := Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Name: "testmachine", - NamespaceID: n.ID, + NamespaceID: namespace.ID, Registered: true, RegisterMethod: "authKey", AuthKeyID: uint(pak.ID), IPAddress: ip.String(), } - h.db.Save(&m) + app.db.Save(&machine) - ips, err := h.getUsedIPs() + ips, err := app.getUsedIPs() c.Assert(err, check.IsNil) @@ -50,42 +50,42 @@ func (s *Suite) TestGetUsedIps(c *check.C) { c.Assert(ips[0], check.Equals, expected) - m1, err := h.GetMachineByID(0) + machine1, err := app.GetMachineByID(0) c.Assert(err, check.IsNil) - c.Assert(m1.IPAddress, check.Equals, expected.String()) + c.Assert(machine1.IPAddress, check.Equals, expected.String()) } func (s *Suite) TestGetMultiIp(c *check.C) { - n, err := h.CreateNamespace("test-ip-multi") + namespace, err := app.CreateNamespace("test-ip-multi") c.Assert(err, check.IsNil) - for i := 1; i <= 350; i++ { - ip, err := h.getAvailableIP() + for index := 1; index <= 350; index++ { + ip, err := app.getAvailableIP() c.Assert(err, check.IsNil) - pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) c.Assert(err, check.IsNil) - _, err = h.GetMachine("test", "testmachine") + _, err = app.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) - m := Machine{ - ID: uint64(i), + machine := Machine{ + ID: uint64(index), MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Name: "testmachine", - NamespaceID: n.ID, + NamespaceID: namespace.ID, Registered: true, RegisterMethod: "authKey", AuthKeyID: uint(pak.ID), IPAddress: ip.String(), } - h.db.Save(&m) + app.db.Save(&machine) } - ips, err := h.getUsedIPs() + ips, err := app.getUsedIPs() c.Assert(err, check.IsNil) @@ -96,59 +96,67 @@ func (s *Suite) TestGetMultiIp(c *check.C) { c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.47")) // Check that we can read back the IPs - m1, err := h.GetMachineByID(1) + machine1, err := app.GetMachineByID(1) c.Assert(err, check.IsNil) - c.Assert(m1.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.1").String()) + c.Assert( + machine1.IPAddress, + check.Equals, + netaddr.MustParseIP("10.27.0.1").String(), + ) - m50, err := h.GetMachineByID(50) + machine50, err := app.GetMachineByID(50) c.Assert(err, check.IsNil) - c.Assert(m50.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.50").String()) + c.Assert( + machine50.IPAddress, + check.Equals, + netaddr.MustParseIP("10.27.0.50").String(), + ) expectedNextIP := netaddr.MustParseIP("10.27.1.97") - nextIP, err := h.getAvailableIP() + nextIP, err := app.getAvailableIP() c.Assert(err, check.IsNil) c.Assert(nextIP.String(), check.Equals, expectedNextIP.String()) // If we call get Available again, we should receive // the same IP, as it has not been reserved. - nextIP2, err := h.getAvailableIP() + nextIP2, err := app.getAvailableIP() c.Assert(err, check.IsNil) c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String()) } func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { - ip, err := h.getAvailableIP() + ip, err := app.getAvailableIP() c.Assert(err, check.IsNil) expected := netaddr.MustParseIP("10.27.0.1") c.Assert(ip.String(), check.Equals, expected.String()) - n, err := h.CreateNamespace("test_ip") + namespace, err := app.CreateNamespace("test_ip") c.Assert(err, check.IsNil) - pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil) c.Assert(err, check.IsNil) - _, err = h.GetMachine("test", "testmachine") + _, err = app.GetMachine("test", "testmachine") c.Assert(err, check.NotNil) - m := Machine{ + machine := Machine{ ID: 0, MachineKey: "foo", NodeKey: "bar", DiscoKey: "faa", Name: "testmachine", - NamespaceID: n.ID, + NamespaceID: namespace.ID, Registered: true, RegisterMethod: "authKey", AuthKeyID: uint(pak.ID), } - h.db.Save(&m) + app.db.Save(&machine) - ip2, err := h.getAvailableIP() + ip2, err := app.getAvailableIP() c.Assert(err, check.IsNil) c.Assert(ip2.String(), check.Equals, expected.String())