diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index ca4d4cf..963663a 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -18,12 +18,11 @@ jobs:
- name: Setup Go
uses: actions/setup-go@v2
with:
- go-version: "1.16.3"
+ go-version: "1.17.3"
- name: Install dependencies
run: |
go version
- go install golang.org/x/lint/golint@latest
sudo apt update
sudo apt install -y make
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 6b561d2..b3c6400 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -1,20 +1,37 @@
+---
name: CI
on: [push, pull_request]
jobs:
- # The "build" workflow
- lint:
- # The type of runner that the job will run on
+ golangci-lint:
runs-on: ubuntu-latest
-
- # Steps represent a sequence of tasks that will be executed as part of the job
steps:
- # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2
- # Install and run golangci-lint as a separate step, it's much faster this
- # way because this action has caching. It'll get run again in `make lint`
- # below, but it's still much faster in the end than installing
- # golangci-lint manually in the `Run lint` step.
- - uses: golangci/golangci-lint-action@v2
+ - name: golangci-lint
+ uses: golangci/golangci-lint-action@v2
+ with:
+ version: latest
+
+ prettier-lint:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: Prettify code
+ uses: creyD/prettier_action@v4.0
+ with:
+ prettier_options: >-
+ --check **/*.{ts,js,md,yaml,yml,sass,css,scss,html}
+ only_changed: false
+ dry: true
+
+ proto-lint:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ - uses: bufbuild/buf-setup-action@v0.7.0
+ - uses: bufbuild/buf-lint-action@v1
+ with:
+ input: "proto"
diff --git a/.golangci.yaml b/.golangci.yaml
index a97c2bb..2defdc8 100644
--- a/.golangci.yaml
+++ b/.golangci.yaml
@@ -1,7 +1,53 @@
---
run:
- timeout: 5m
+ timeout: 10m
issues:
skip-dirs:
- gen
+linters:
+ enable-all: true
+ disable:
+ - exhaustivestruct
+ - revive
+ - lll
+ - interfacer
+ - scopelint
+ - maligned
+ - golint
+ - gofmt
+ - gochecknoglobals
+ - gochecknoinits
+ - gocognit
+ - funlen
+ - exhaustivestruct
+ - tagliatelle
+ - godox
+ - ireturn
+
+ # In progress
+ - gocritic
+
+ # We should strive to enable these:
+ - wrapcheck
+ - dupl
+ - makezero
+
+ # We might want to enable this, but it might be a lot of work
+ - cyclop
+ - nestif
+ - wsl # might be incompatible with gofumpt
+ - testpackage
+ - paralleltest
+
+linters-settings:
+ varnamelen:
+ ignore-type-assert-ok: true
+ ignore-map-index-ok: true
+ ignore-names:
+ - err
+ - db
+ - id
+ - ip
+ - ok
+ - c
diff --git a/Makefile b/Makefile
index 92beaef..060d3b9 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,14 @@
# Calculate version
version = $(shell ./scripts/version-at-commit.sh)
+rwildcard=$(foreach d,$(wildcard $1*),$(call rwildcard,$d/,$2) $(filter $(subst *,%,$2),$d))
+
+# GO_SOURCES = $(wildcard *.go)
+# PROTO_SOURCES = $(wildcard **/*.proto)
+GO_SOURCES = $(call rwildcard,,*.go)
+PROTO_SOURCES = $(call rwildcard,,*.proto)
+
+
build:
go build -ldflags "-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=$(version)" cmd/headscale/headscale.go
@@ -19,7 +27,12 @@ coverprofile_html:
go tool cover -html=coverage.out
lint:
- golangci-lint run --fix
+ golangci-lint run --fix --timeout 10m
+
+fmt:
+ prettier --write '**/**.{ts,js,md,yaml,yml,sass,css,scss,html}'
+ golines --max-len=88 --base-formatter=gofumpt -w $(GO_SOURCES)
+ clang-format -style="{BasedOnStyle: Google, IndentWidth: 4, AlignConsecutiveDeclarations: true, AlignConsecutiveAssignments: true, ColumnLimit: 0}" -i $(PROTO_SOURCES)
proto-lint:
cd proto/ && buf lint
diff --git a/README.md b/README.md
index e9f7990..41f388e 100644
--- a/README.md
+++ b/README.md
@@ -54,7 +54,6 @@ Suggestions/PRs welcomed!
Please have a look at the documentation under [`docs/`](docs/).
-
## Disclaimer
1. We have nothing to do with Tailscale, or Tailscale Inc.
@@ -64,13 +63,30 @@ Please have a look at the documentation under [`docs/`](docs/).
To contribute to Headscale you would need the lastest version of [Go](https://golang.org) and [Buf](https://buf.build)(Protobuf generator).
+### Code style
+
+To ensure we have some consistency with a growing number of contributes, this project has adopted linting and style/formatting rules:
+
+The **Go** code is linted with [`golangci-lint`](https://golangci-lint.run) and
+formatted with [`golines`](https://github.com/segmentio/golines) (width 88) and
+[`gofumpt`](https://github.com/mvdan/gofumpt).
+Please configure your editor to run the tools while developing and make sure to
+run `make lint` and `make fmt` before committing any code.
+
+The **Proto** code is linted with [`buf`](https://docs.buf.build/lint/overview) and
+formatted with [`clang-format`](https://clang.llvm.org/docs/ClangFormat.html).
+
+The **rest** (markdown, yaml, etc) is formatted with [`prettier`](https://prettier.io).
+
+Check out the `.golangci.yaml` and `Makefile` to see the specific configuration.
+
### Install development tools
- Go
- Buf
- Protobuf tools:
-```shell
+```shell
make install-protobuf-plugins
```
@@ -81,6 +97,7 @@ Some parts of the project requires the generation of Go code from Protobuf (if c
```shell
make generate
```
+
**Note**: Please check in changes from `gen/` in a separate commit to make it easier to review.
To run the tests:
@@ -261,5 +278,3 @@ make build
-
-
diff --git a/acls.go b/acls.go
index cb02e04..1550c34 100644
--- a/acls.go
+++ b/acls.go
@@ -9,23 +9,30 @@ import (
"strings"
"github.com/rs/zerolog/log"
-
"github.com/tailscale/hujson"
"inet.af/netaddr"
"tailscale.com/tailcfg"
)
const (
- errorEmptyPolicy = Error("empty policy")
- errorInvalidAction = Error("invalid action")
- errorInvalidUserSection = Error("invalid user section")
- errorInvalidGroup = Error("invalid group")
- errorInvalidTag = Error("invalid tag")
- errorInvalidNamespace = Error("invalid namespace")
- errorInvalidPortFormat = Error("invalid port format")
+ errEmptyPolicy = Error("empty policy")
+ errInvalidAction = Error("invalid action")
+ errInvalidUserSection = Error("invalid user section")
+ errInvalidGroup = Error("invalid group")
+ errInvalidTag = Error("invalid tag")
+ errInvalidNamespace = Error("invalid namespace")
+ errInvalidPortFormat = Error("invalid port format")
)
-// LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules
+const (
+ Base10 = 10
+ BitSize16 = 16
+ portRangeBegin = 0
+ portRangeEnd = 65535
+ expectedTokenItems = 2
+)
+
+// LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules.
func (h *Headscale) LoadACLPolicy(path string) error {
policyFile, err := os.Open(path)
if err != nil {
@@ -34,23 +41,23 @@ func (h *Headscale) LoadACLPolicy(path string) error {
defer policyFile.Close()
var policy ACLPolicy
- b, err := io.ReadAll(policyFile)
+ policyBytes, err := io.ReadAll(policyFile)
if err != nil {
return err
}
- ast, err := hujson.Parse(b)
+ ast, err := hujson.Parse(policyBytes)
if err != nil {
return err
}
ast.Standardize()
- b = ast.Pack()
- err = json.Unmarshal(b, &policy)
+ policyBytes = ast.Pack()
+ err = json.Unmarshal(policyBytes, &policy)
if err != nil {
return err
}
if policy.IsZero() {
- return errorEmptyPolicy
+ return errEmptyPolicy
}
h.aclPolicy = &policy
@@ -59,37 +66,40 @@ func (h *Headscale) LoadACLPolicy(path string) error {
return err
}
h.aclRules = rules
+
return nil
}
func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{}
- for i, a := range h.aclPolicy.ACLs {
- if a.Action != "accept" {
- return nil, errorInvalidAction
+ for index, acl := range h.aclPolicy.ACLs {
+ if acl.Action != "accept" {
+ return nil, errInvalidAction
}
- r := tailcfg.FilterRule{}
+ filterRule := tailcfg.FilterRule{}
srcIPs := []string{}
- for j, u := range a.Users {
- srcs, err := h.generateACLPolicySrcIP(u)
+ for innerIndex, user := range acl.Users {
+ srcs, err := h.generateACLPolicySrcIP(user)
if err != nil {
log.Error().
- Msgf("Error parsing ACL %d, User %d", i, j)
+ Msgf("Error parsing ACL %d, User %d", index, innerIndex)
+
return nil, err
}
srcIPs = append(srcIPs, srcs...)
}
- r.SrcIPs = srcIPs
+ filterRule.SrcIPs = srcIPs
destPorts := []tailcfg.NetPortRange{}
- for j, d := range a.Ports {
- dests, err := h.generateACLPolicyDestPorts(d)
+ for innerIndex, ports := range acl.Ports {
+ dests, err := h.generateACLPolicyDestPorts(ports)
if err != nil {
log.Error().
- Msgf("Error parsing ACL %d, Port %d", i, j)
+ Msgf("Error parsing ACL %d, Port %d", index, innerIndex)
+
return nil, err
}
destPorts = append(destPorts, dests...)
@@ -108,10 +118,12 @@ func (h *Headscale) generateACLPolicySrcIP(u string) ([]string, error) {
return h.expandAlias(u)
}
-func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange, error) {
+func (h *Headscale) generateACLPolicyDestPorts(
+ d string,
+) ([]tailcfg.NetPortRange, error) {
tokens := strings.Split(d, ":")
- if len(tokens) < 2 || len(tokens) > 3 {
- return nil, errorInvalidPortFormat
+ if len(tokens) < expectedTokenItems || len(tokens) > 3 {
+ return nil, errInvalidPortFormat
}
var alias string
@@ -121,7 +133,7 @@ func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange
// tag:montreal-webserver:80,443
// tag:api-server:443
// example-host-1:*
- if len(tokens) == 2 {
+ if len(tokens) == expectedTokenItems {
alias = tokens[0]
} else {
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
@@ -146,34 +158,36 @@ func (h *Headscale) generateACLPolicyDestPorts(d string) ([]tailcfg.NetPortRange
dests = append(dests, pr)
}
}
+
return dests, nil
}
-func (h *Headscale) expandAlias(s string) ([]string, error) {
- if s == "*" {
+func (h *Headscale) expandAlias(alias string) ([]string, error) {
+ if alias == "*" {
return []string{"*"}, nil
}
- if strings.HasPrefix(s, "group:") {
- if _, ok := h.aclPolicy.Groups[s]; !ok {
- return nil, errorInvalidGroup
+ if strings.HasPrefix(alias, "group:") {
+ if _, ok := h.aclPolicy.Groups[alias]; !ok {
+ return nil, errInvalidGroup
}
ips := []string{}
- for _, n := range h.aclPolicy.Groups[s] {
+ for _, n := range h.aclPolicy.Groups[alias] {
nodes, err := h.ListMachinesInNamespace(n)
if err != nil {
- return nil, errorInvalidNamespace
+ return nil, errInvalidNamespace
}
for _, node := range nodes {
ips = append(ips, node.IPAddress)
}
}
+
return ips, nil
}
- if strings.HasPrefix(s, "tag:") {
- if _, ok := h.aclPolicy.TagOwners[s]; !ok {
- return nil, errorInvalidTag
+ if strings.HasPrefix(alias, "tag:") {
+ if _, ok := h.aclPolicy.TagOwners[alias]; !ok {
+ return nil, errInvalidTag
}
// This will have HORRIBLE performance.
@@ -183,10 +197,10 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
return nil, err
}
ips := []string{}
- for _, m := range machines {
+ for _, machine := range machines {
hostinfo := tailcfg.Hostinfo{}
- if len(m.HostInfo) != 0 {
- hi, err := m.HostInfo.MarshalJSON()
+ if len(machine.HostInfo) != 0 {
+ hi, err := machine.HostInfo.MarshalJSON()
if err != nil {
return nil, err
}
@@ -197,17 +211,19 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
// FIXME: Check TagOwners allows this
for _, t := range hostinfo.RequestTags {
- if s[4:] == t {
- ips = append(ips, m.IPAddress)
+ if alias[4:] == t {
+ ips = append(ips, machine.IPAddress)
+
break
}
}
}
}
+
return ips, nil
}
- n, err := h.GetNamespace(s)
+ n, err := h.GetNamespace(alias)
if err == nil {
nodes, err := h.ListMachinesInNamespace(n.Name)
if err != nil {
@@ -217,49 +233,54 @@ func (h *Headscale) expandAlias(s string) ([]string, error) {
for _, n := range nodes {
ips = append(ips, n.IPAddress)
}
+
return ips, nil
}
- if h, ok := h.aclPolicy.Hosts[s]; ok {
+ if h, ok := h.aclPolicy.Hosts[alias]; ok {
return []string{h.String()}, nil
}
- ip, err := netaddr.ParseIP(s)
+ ip, err := netaddr.ParseIP(alias)
if err == nil {
return []string{ip.String()}, nil
}
- cidr, err := netaddr.ParseIPPrefix(s)
+ cidr, err := netaddr.ParseIPPrefix(alias)
if err == nil {
return []string{cidr.String()}, nil
}
- return nil, errorInvalidUserSection
+ return nil, errInvalidUserSection
}
-func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) {
- if s == "*" {
- return &[]tailcfg.PortRange{{First: 0, Last: 65535}}, nil
+func (h *Headscale) expandPorts(portsStr string) (*[]tailcfg.PortRange, error) {
+ if portsStr == "*" {
+ return &[]tailcfg.PortRange{
+ {First: portRangeBegin, Last: portRangeEnd},
+ }, nil
}
ports := []tailcfg.PortRange{}
- for _, p := range strings.Split(s, ",") {
- rang := strings.Split(p, "-")
- if len(rang) == 1 {
- pi, err := strconv.ParseUint(rang[0], 10, 16)
+ for _, portStr := range strings.Split(portsStr, ",") {
+ rang := strings.Split(portStr, "-")
+ switch len(rang) {
+ case 1:
+ port, err := strconv.ParseUint(rang[0], Base10, BitSize16)
if err != nil {
return nil, err
}
ports = append(ports, tailcfg.PortRange{
- First: uint16(pi),
- Last: uint16(pi),
+ First: uint16(port),
+ Last: uint16(port),
})
- } else if len(rang) == 2 {
- start, err := strconv.ParseUint(rang[0], 10, 16)
+
+ case expectedTokenItems:
+ start, err := strconv.ParseUint(rang[0], Base10, BitSize16)
if err != nil {
return nil, err
}
- last, err := strconv.ParseUint(rang[1], 10, 16)
+ last, err := strconv.ParseUint(rang[1], Base10, BitSize16)
if err != nil {
return nil, err
}
@@ -267,9 +288,11 @@ func (h *Headscale) expandPorts(s string) (*[]tailcfg.PortRange, error) {
First: uint16(start),
Last: uint16(last),
})
- } else {
- return nil, errorInvalidPortFormat
+
+ default:
+ return nil, errInvalidPortFormat
}
}
+
return &ports, nil
}
diff --git a/acls_test.go b/acls_test.go
index da7f3ec..3e051f5 100644
--- a/acls_test.go
+++ b/acls_test.go
@@ -5,54 +5,58 @@ import (
)
func (s *Suite) TestWrongPath(c *check.C) {
- err := h.LoadACLPolicy("asdfg")
+ err := app.LoadACLPolicy("asdfg")
c.Assert(err, check.NotNil)
}
func (s *Suite) TestBrokenHuJson(c *check.C) {
- err := h.LoadACLPolicy("./tests/acls/broken.hujson")
+ err := app.LoadACLPolicy("./tests/acls/broken.hujson")
c.Assert(err, check.NotNil)
}
func (s *Suite) TestInvalidPolicyHuson(c *check.C) {
- err := h.LoadACLPolicy("./tests/acls/invalid.hujson")
+ err := app.LoadACLPolicy("./tests/acls/invalid.hujson")
c.Assert(err, check.NotNil)
- c.Assert(err, check.Equals, errorEmptyPolicy)
+ c.Assert(err, check.Equals, errEmptyPolicy)
}
func (s *Suite) TestParseHosts(c *check.C) {
- var hs Hosts
- err := hs.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100","example-host-2": "100.100.101.100/24"}`))
- c.Assert(hs, check.NotNil)
+ var hosts Hosts
+ err := hosts.UnmarshalJSON(
+ []byte(
+ `{"example-host-1": "100.100.100.100","example-host-2": "100.100.101.100/24"}`,
+ ),
+ )
+ c.Assert(hosts, check.NotNil)
c.Assert(err, check.IsNil)
}
func (s *Suite) TestParseInvalidCIDR(c *check.C) {
- var hs Hosts
- err := hs.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100/42"}`))
- c.Assert(hs, check.IsNil)
+ var hosts Hosts
+ err := hosts.UnmarshalJSON([]byte(`{"example-host-1": "100.100.100.100/42"}`))
+ c.Assert(hosts, check.IsNil)
c.Assert(err, check.NotNil)
}
func (s *Suite) TestRuleInvalidGeneration(c *check.C) {
- err := h.LoadACLPolicy("./tests/acls/acl_policy_invalid.hujson")
+ err := app.LoadACLPolicy("./tests/acls/acl_policy_invalid.hujson")
c.Assert(err, check.NotNil)
}
func (s *Suite) TestBasicRule(c *check.C) {
- err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson")
+ err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_1.hujson")
c.Assert(err, check.IsNil)
- rules, err := h.generateACLRules()
+ rules, err := app.generateACLRules()
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
}
func (s *Suite) TestPortRange(c *check.C) {
- err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson")
+ err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_range.hujson")
c.Assert(err, check.IsNil)
- rules, err := h.generateACLRules()
+ rules, err := app.generateACLRules()
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
@@ -63,10 +67,10 @@ func (s *Suite) TestPortRange(c *check.C) {
}
func (s *Suite) TestPortWildcard(c *check.C) {
- err := h.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson")
+ err := app.LoadACLPolicy("./tests/acls/acl_policy_basic_wildcards.hujson")
c.Assert(err, check.IsNil)
- rules, err := h.generateACLRules()
+ rules, err := app.generateACLRules()
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
@@ -79,33 +83,35 @@ func (s *Suite) TestPortWildcard(c *check.C) {
}
func (s *Suite) TestPortNamespace(c *check.C) {
- n, err := h.CreateNamespace("testnamespace")
+ namespace, err := app.CreateNamespace("testnamespace")
c.Assert(err, check.IsNil)
- pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
+ pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
- _, err = h.GetMachine("testnamespace", "testmachine")
+ _, err = app.GetMachine("testnamespace", "testmachine")
c.Assert(err, check.NotNil)
- ip, _ := h.getAvailableIP()
- m := Machine{
+ ip, _ := app.getAvailableIP()
+ machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
- NamespaceID: n.ID,
+ NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: ip.String(),
AuthKeyID: uint(pak.ID),
}
- h.db.Save(&m)
+ app.db.Save(&machine)
- err = h.LoadACLPolicy("./tests/acls/acl_policy_basic_namespace_as_user.hujson")
+ err = app.LoadACLPolicy(
+ "./tests/acls/acl_policy_basic_namespace_as_user.hujson",
+ )
c.Assert(err, check.IsNil)
- rules, err := h.generateACLRules()
+ rules, err := app.generateACLRules()
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
@@ -119,33 +125,33 @@ func (s *Suite) TestPortNamespace(c *check.C) {
}
func (s *Suite) TestPortGroup(c *check.C) {
- n, err := h.CreateNamespace("testnamespace")
+ namespace, err := app.CreateNamespace("testnamespace")
c.Assert(err, check.IsNil)
- pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
+ pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
- _, err = h.GetMachine("testnamespace", "testmachine")
+ _, err = app.GetMachine("testnamespace", "testmachine")
c.Assert(err, check.NotNil)
- ip, _ := h.getAvailableIP()
- m := Machine{
+ ip, _ := app.getAvailableIP()
+ machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
- NamespaceID: n.ID,
+ NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: ip.String(),
AuthKeyID: uint(pak.ID),
}
- h.db.Save(&m)
+ app.db.Save(&machine)
- err = h.LoadACLPolicy("./tests/acls/acl_policy_basic_groups.hujson")
+ err = app.LoadACLPolicy("./tests/acls/acl_policy_basic_groups.hujson")
c.Assert(err, check.IsNil)
- rules, err := h.generateACLRules()
+ rules, err := app.generateACLRules()
c.Assert(err, check.IsNil)
c.Assert(rules, check.NotNil)
diff --git a/acls_types.go b/acls_types.go
index 67b74e7..08e650f 100644
--- a/acls_types.go
+++ b/acls_types.go
@@ -8,7 +8,7 @@ import (
"inet.af/netaddr"
)
-// ACLPolicy represents a Tailscale ACL Policy
+// ACLPolicy represents a Tailscale ACL Policy.
type ACLPolicy struct {
Groups Groups `json:"Groups"`
Hosts Hosts `json:"Hosts"`
@@ -17,61 +17,63 @@ type ACLPolicy struct {
Tests []ACLTest `json:"Tests"`
}
-// ACL is a basic rule for the ACL Policy
+// ACL is a basic rule for the ACL Policy.
type ACL struct {
Action string `json:"Action"`
Users []string `json:"Users"`
Ports []string `json:"Ports"`
}
-// Groups references a series of alias in the ACL rules
+// Groups references a series of alias in the ACL rules.
type Groups map[string][]string
-// Hosts are alias for IP addresses or subnets
+// Hosts are alias for IP addresses or subnets.
type Hosts map[string]netaddr.IPPrefix
-// TagOwners specify what users (namespaces?) are allow to use certain tags
+// TagOwners specify what users (namespaces?) are allow to use certain tags.
type TagOwners map[string][]string
-// ACLTest is not implemented, but should be use to check if a certain rule is allowed
+// ACLTest is not implemented, but should be use to check if a certain rule is allowed.
type ACLTest struct {
User string `json:"User"`
Allow []string `json:"Allow"`
Deny []string `json:"Deny,omitempty"`
}
-// UnmarshalJSON allows to parse the Hosts directly into netaddr objects
-func (h *Hosts) UnmarshalJSON(data []byte) error {
- hosts := Hosts{}
- hs := make(map[string]string)
+// UnmarshalJSON allows to parse the Hosts directly into netaddr objects.
+func (hosts *Hosts) UnmarshalJSON(data []byte) error {
+ newHosts := Hosts{}
+ hostIPPrefixMap := make(map[string]string)
ast, err := hujson.Parse(data)
if err != nil {
return err
}
ast.Standardize()
data = ast.Pack()
- err = json.Unmarshal(data, &hs)
+ err = json.Unmarshal(data, &hostIPPrefixMap)
if err != nil {
return err
}
- for k, v := range hs {
- if !strings.Contains(v, "/") {
- v = v + "/32"
+ for host, prefixStr := range hostIPPrefixMap {
+ if !strings.Contains(prefixStr, "/") {
+ prefixStr += "/32"
}
- prefix, err := netaddr.ParseIPPrefix(v)
+ prefix, err := netaddr.ParseIPPrefix(prefixStr)
if err != nil {
return err
}
- hosts[k] = prefix
+ newHosts[host] = prefix
}
- *h = hosts
+ *hosts = newHosts
+
return nil
}
-// IsZero is perhaps a bit naive here
-func (p ACLPolicy) IsZero() bool {
- if len(p.Groups) == 0 && len(p.Hosts) == 0 && len(p.ACLs) == 0 {
+// IsZero is perhaps a bit naive here.
+func (policy ACLPolicy) IsZero() bool {
+ if len(policy.Groups) == 0 && len(policy.Hosts) == 0 && len(policy.ACLs) == 0 {
return true
}
+
return false
}
diff --git a/api.go b/api.go
index 490ce25..85c28e3 100644
--- a/api.go
+++ b/api.go
@@ -10,31 +10,37 @@ import (
"strings"
"time"
- "github.com/rs/zerolog/log"
-
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
+ "github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/wgkey"
)
+const reservedResponseHeaderSize = 4
+
// KeyHandler provides the Headscale pub key
-// Listens in /key
-func (h *Headscale) KeyHandler(c *gin.Context) {
- c.Data(200, "text/plain; charset=utf-8", []byte(h.publicKey.HexString()))
+// Listens in /key.
+func (h *Headscale) KeyHandler(ctx *gin.Context) {
+ ctx.Data(
+ http.StatusOK,
+ "text/plain; charset=utf-8",
+ []byte(h.publicKey.HexString()),
+ )
}
// RegisterWebAPI shows a simple message in the browser to point to the CLI
-// Listens in /register
-func (h *Headscale) RegisterWebAPI(c *gin.Context) {
- mKeyStr := c.Query("key")
- if mKeyStr == "" {
- c.String(http.StatusBadRequest, "Wrong params")
+// Listens in /register.
+func (h *Headscale) RegisterWebAPI(ctx *gin.Context) {
+ machineKeyStr := ctx.Query("key")
+ if machineKeyStr == "" {
+ ctx.String(http.StatusBadRequest, "Wrong params")
+
return
}
- c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
+ ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
headscale
@@ -51,43 +57,45 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) {
- `, mKeyStr)))
+ `, machineKeyStr)))
}
// RegistrationHandler handles the actual registration process of a machine
-// Endpoint /machine/:id
-func (h *Headscale) RegistrationHandler(c *gin.Context) {
- body, _ := io.ReadAll(c.Request.Body)
- mKeyStr := c.Param("id")
- mKey, err := wgkey.ParseHex(mKeyStr)
+// Endpoint /machine/:id.
+func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
+ body, _ := io.ReadAll(ctx.Request.Body)
+ machineKeyStr := ctx.Param("id")
+ machineKey, err := wgkey.ParseHex(machineKeyStr)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot parse machine key")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
- c.String(http.StatusInternalServerError, "Sad!")
+ ctx.String(http.StatusInternalServerError, "Sad!")
+
return
}
req := tailcfg.RegisterRequest{}
- err = decode(body, &req, &mKey, h.privateKey)
+ err = decode(body, &req, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot decode message")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
- c.String(http.StatusInternalServerError, "Very sad!")
+ ctx.String(http.StatusInternalServerError, "Very sad!")
+
return
}
now := time.Now().UTC()
- m, err := h.GetMachineByMachineKey(mKey.HexString())
+ machine, err := h.GetMachineByMachineKey(machineKey.HexString())
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
newMachine := Machine{
Expiry: &time.Time{},
- MachineKey: mKey.HexString(),
+ MachineKey: machineKey.HexString(),
Name: req.Hostinfo.Hostname,
}
if err := h.db.Create(&newMachine).Error; err != nil {
@@ -95,88 +103,96 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Str("handler", "Registration").
Err(err).
Msg("Could not create row")
- machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).Inc()
+ machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name).
+ Inc()
+
return
}
- m = &newMachine
+ machine = &newMachine
}
- if !m.Registered && req.Auth.AuthKey != "" {
- h.handleAuthKey(c, h.db, mKey, req, *m)
+ if !machine.Registered && req.Auth.AuthKey != "" {
+ h.handleAuthKey(ctx, h.db, machineKey, req, *machine)
+
return
}
resp := tailcfg.RegisterResponse{}
// We have the updated key!
- if m.NodeKey == wgkey.Key(req.NodeKey).HexString() {
-
+ if machine.NodeKey == wgkey.Key(req.NodeKey).HexString() {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
log.Info().
Str("handler", "Registration").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Client requested logout")
- m.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
- h.db.Save(&m)
+ machine.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
+ h.db.Save(&machine)
resp.AuthURL = ""
resp.MachineAuthorized = false
- resp.User = *m.Namespace.toUser()
- respBody, err := encode(resp, &mKey, h.privateKey)
+ resp.User = *machine.Namespace.toUser()
+ respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
- c.String(http.StatusInternalServerError, "")
+ ctx.String(http.StatusInternalServerError, "")
+
return
}
- c.Data(200, "application/json; charset=utf-8", respBody)
+ ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
+
return
}
- if m.Registered && m.Expiry.UTC().After(now) {
+ if machine.Registered && machine.Expiry.UTC().After(now) {
// The machine registration is valid, respond with redirect to /map
log.Debug().
Str("handler", "Registration").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Client is registered and we have the current NodeKey. All clear to /map")
resp.AuthURL = ""
resp.MachineAuthorized = true
- resp.User = *m.Namespace.toUser()
- resp.Login = *m.Namespace.toLogin()
+ resp.User = *machine.Namespace.toUser()
+ resp.Login = *machine.Namespace.toLogin()
- respBody, err := encode(resp, &mKey, h.privateKey)
+ respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
- machineRegistrations.WithLabelValues("update", "web", "error", m.Namespace.Name).Inc()
- c.String(http.StatusInternalServerError, "")
+ machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name).
+ Inc()
+ ctx.String(http.StatusInternalServerError, "")
+
return
}
- machineRegistrations.WithLabelValues("update", "web", "success", m.Namespace.Name).Inc()
- c.Data(200, "application/json; charset=utf-8", respBody)
+ machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
+ Inc()
+ ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
+
return
}
// The client has registered before, but has expired
log.Debug().
Str("handler", "Registration").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Machine registration has expired. Sending a authurl to register")
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
- strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
+ strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
- strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
+ strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
}
// When a client connects, it may request a specific expiry time in its
@@ -185,102 +201,120 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// into two steps (which cant pass arbitrary data between them easily) and needs to be
// retrieved again after the user has authenticated. After the authentication flow
// completes, RequestedExpiry is copied into Expiry.
- m.RequestedExpiry = &req.Expiry
+ machine.RequestedExpiry = &req.Expiry
- h.db.Save(&m)
+ h.db.Save(&machine)
- respBody, err := encode(resp, &mKey, h.privateKey)
+ respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
- machineRegistrations.WithLabelValues("new", "web", "error", m.Namespace.Name).Inc()
- c.String(http.StatusInternalServerError, "")
+ machineRegistrations.WithLabelValues("new", "web", "error", machine.Namespace.Name).
+ Inc()
+ ctx.String(http.StatusInternalServerError, "")
+
return
}
- machineRegistrations.WithLabelValues("new", "web", "success", m.Namespace.Name).Inc()
- c.Data(200, "application/json; charset=utf-8", respBody)
+ machineRegistrations.WithLabelValues("new", "web", "success", machine.Namespace.Name).
+ Inc()
+ ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
+
return
}
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
- if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) {
+ if machine.NodeKey == wgkey.Key(req.OldNodeKey).HexString() &&
+ machine.Expiry.UTC().After(now) {
log.Debug().
Str("handler", "Registration").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("We have the OldNodeKey in the database. This is a key refresh")
- m.NodeKey = wgkey.Key(req.NodeKey).HexString()
- h.db.Save(&m)
+ machine.NodeKey = wgkey.Key(req.NodeKey).HexString()
+ h.db.Save(&machine)
resp.AuthURL = ""
- resp.User = *m.Namespace.toUser()
- respBody, err := encode(resp, &mKey, h.privateKey)
+ resp.User = *machine.Namespace.toUser()
+ respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
- c.String(http.StatusInternalServerError, "Extremely sad!")
+ ctx.String(http.StatusInternalServerError, "Extremely sad!")
+
return
}
- c.Data(200, "application/json; charset=utf-8", respBody)
+ ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
+
return
}
// The machine registration is new, redirect the client to the registration URL
log.Debug().
Str("handler", "Registration").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("The node is sending us a new NodeKey, sending auth url")
if h.cfg.OIDC.Issuer != "" {
- resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
+ resp.AuthURL = fmt.Sprintf(
+ "%s/oidc/register/%s",
+ strings.TrimSuffix(h.cfg.ServerURL, "/"),
+ machineKey.HexString(),
+ )
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
- strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
+ strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.HexString())
}
// save the requested expiry time for retrieval later in the authentication flow
- m.RequestedExpiry = &req.Expiry
- m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
- h.db.Save(&m)
+ machine.RequestedExpiry = &req.Expiry
+ machine.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
+ h.db.Save(&machine)
- respBody, err := encode(resp, &mKey, h.privateKey)
+ respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
- c.String(http.StatusInternalServerError, "")
+ ctx.String(http.StatusInternalServerError, "")
+
return
}
- c.Data(200, "application/json; charset=utf-8", respBody)
+ ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
}
-func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Machine) ([]byte, error) {
+func (h *Headscale) getMapResponse(
+ machineKey wgkey.Key,
+ req tailcfg.MapRequest,
+ machine *Machine,
+) ([]byte, error) {
log.Trace().
Str("func", "getMapResponse").
Str("machine", req.Hostinfo.Hostname).
Msg("Creating Map response")
- node, err := m.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
+ node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil {
log.Error().
Str("func", "getMapResponse").
Err(err).
Msg("Cannot convert to node")
+
return nil, err
}
- peers, err := h.getPeers(m)
+ peers, err := h.getPeers(machine)
if err != nil {
log.Error().
Str("func", "getMapResponse").
Err(err).
Msg("Cannot fetch peers")
+
return nil, err
}
- profiles := getMapResponseUserProfiles(*m, peers)
+ profiles := getMapResponseUserProfiles(*machine, peers)
nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil {
@@ -288,17 +322,16 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma
Str("func", "getMapResponse").
Err(err).
Msg("Failed to convert peers to Tailscale nodes")
+
return nil, err
}
- dnsConfig, err := getMapResponseDNSConfig(h.cfg.DNSConfig, h.cfg.BaseDomain, *m, peers)
- if err != nil {
- log.Error().
- Str("func", "getMapResponse").
- Err(err).
- Msg("Failed generate the DNSConfig")
- return nil, err
- }
+ dnsConfig := getMapResponseDNSConfig(
+ h.cfg.DNSConfig,
+ h.cfg.BaseDomain,
+ *machine,
+ peers,
+ )
resp := tailcfg.MapResponse{
KeepAlive: false,
@@ -323,66 +356,71 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma
encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil)
- respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
+ respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
} else {
- respBody, err = encode(resp, &mKey, h.privateKey)
+ respBody, err = encode(resp, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
}
// declare the incoming size on the first 4 bytes
- data := make([]byte, 4)
+ data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
+
return data, nil
}
-func (h *Headscale) getMapKeepAliveResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Machine) ([]byte, error) {
- resp := tailcfg.MapResponse{
+func (h *Headscale) getMapKeepAliveResponse(
+ machineKey wgkey.Key,
+ mapRequest tailcfg.MapRequest,
+) ([]byte, error) {
+ mapResponse := tailcfg.MapResponse{
KeepAlive: true,
}
var respBody []byte
var err error
- if req.Compress == "zstd" {
- src, _ := json.Marshal(resp)
+ if mapRequest.Compress == "zstd" {
+ src, _ := json.Marshal(mapResponse)
encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil)
- respBody, err = encodeMsg(srcCompressed, &mKey, h.privateKey)
+ respBody, err = encodeMsg(srcCompressed, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
} else {
- respBody, err = encode(resp, &mKey, h.privateKey)
+ respBody, err = encode(mapResponse, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
}
- data := make([]byte, 4)
+ data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
+
return data, nil
}
func (h *Headscale) handleAuthKey(
- c *gin.Context,
+ ctx *gin.Context,
db *gorm.DB,
idKey wgkey.Key,
- req tailcfg.RegisterRequest,
- m Machine,
+ reqisterRequest tailcfg.RegisterRequest,
+ machine Machine,
) {
log.Debug().
Str("func", "handleAuthKey").
- Str("machine", req.Hostinfo.Hostname).
- Msgf("Processing auth key for %s", req.Hostinfo.Hostname)
+ Str("machine", reqisterRequest.Hostinfo.Hostname).
+ Msgf("Processing auth key for %s", reqisterRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{}
- pak, err := h.checkKeyValidity(req.Auth.AuthKey)
+ pak, err := h.checkKeyValidity(reqisterRequest.Auth.AuthKey)
if err != nil {
log.Error().
Str("func", "handleAuthKey").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Err(err).
Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false
@@ -390,48 +428,56 @@ func (h *Headscale) handleAuthKey(
if err != nil {
log.Error().
Str("func", "handleAuthKey").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Err(err).
Msg("Cannot encode message")
- c.String(http.StatusInternalServerError, "")
- machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
+ ctx.String(http.StatusInternalServerError, "")
+ machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
+ Inc()
+
return
}
- c.Data(401, "application/json; charset=utf-8", respBody)
+ ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody)
log.Error().
Str("func", "handleAuthKey").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Failed authentication via AuthKey")
- machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
+ machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
+ Inc()
+
return
}
log.Debug().
Str("func", "handleAuthKey").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Authentication key was valid, proceeding to acquire an IP address")
ip, err := h.getAvailableIP()
if err != nil {
log.Error().
Str("func", "handleAuthKey").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Failed to find an available IP")
- machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
+ machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
+ Inc()
+
return
}
log.Info().
Str("func", "handleAuthKey").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("ip", ip.String()).
- Msgf("Assigning %s to %s", ip, m.Name)
+ Msgf("Assigning %s to %s", ip, machine.Name)
- m.AuthKeyID = uint(pak.ID)
- m.IPAddress = ip.String()
- m.NamespaceID = pak.NamespaceID
- m.NodeKey = wgkey.Key(req.NodeKey).HexString() // we update it just in case
- m.Registered = true
- m.RegisterMethod = "authKey"
- db.Save(&m)
+ machine.AuthKeyID = uint(pak.ID)
+ machine.IPAddress = ip.String()
+ machine.NamespaceID = pak.NamespaceID
+ machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey).
+ HexString()
+ // we update it just in case
+ machine.Registered = true
+ machine.RegisterMethod = "authKey"
+ db.Save(&machine)
pak.Used = true
db.Save(&pak)
@@ -442,18 +488,21 @@ func (h *Headscale) handleAuthKey(
if err != nil {
log.Error().
Str("func", "handleAuthKey").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Err(err).
Msg("Cannot encode message")
- machineRegistrations.WithLabelValues("new", "authkey", "error", m.Namespace.Name).Inc()
- c.String(http.StatusInternalServerError, "Extremely sad!")
+ machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name).
+ Inc()
+ ctx.String(http.StatusInternalServerError, "Extremely sad!")
+
return
}
- machineRegistrations.WithLabelValues("new", "authkey", "success", m.Namespace.Name).Inc()
- c.Data(200, "application/json; charset=utf-8", respBody)
+ machineRegistrations.WithLabelValues("new", "authkey", "success", machine.Namespace.Name).
+ Inc()
+ ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody)
log.Info().
Str("func", "handleAuthKey").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("ip", ip.String()).
Msg("Successfully authenticated via AuthKey")
}
diff --git a/app.go b/app.go
index c226062..08b67fe 100644
--- a/app.go
+++ b/app.go
@@ -18,20 +18,19 @@ import (
"time"
"github.com/coreos/go-oidc/v3/oidc"
- "github.com/patrickmn/go-cache"
- "golang.org/x/oauth2"
-
"github.com/gin-gonic/gin"
- "github.com/grpc-ecosystem/go-grpc-middleware"
+ grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
- "github.com/philip-bui/grpc-zerolog"
+ "github.com/patrickmn/go-cache"
+ zerolog "github.com/philip-bui/grpc-zerolog"
zl "github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/soheilhy/cmux"
ginprometheus "github.com/zsais/go-gin-prometheus"
"golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert"
+ "golang.org/x/oauth2"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
@@ -48,7 +47,16 @@ import (
)
const (
- AUTH_PREFIX = "Bearer "
+ AuthPrefix = "Bearer "
+ Postgres = "postgresql"
+ Sqlite = "sqlite3"
+ updateInterval = 5000
+ HTTPReadTimeout = 30 * time.Second
+
+ errUnsupportedDatabase = Error("unsupported DB")
+ errUnsupportedLetsEncryptChallengeType = Error(
+ "unknown value for Lets Encrypt challenge type",
+ )
)
// Config contains the initial Headscale configuration.
@@ -151,16 +159,22 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
var dbString string
switch cfg.DBtype {
- case "postgres":
- dbString = fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=disable", cfg.DBhost,
- cfg.DBport, cfg.DBname, cfg.DBuser, cfg.DBpass)
- case "sqlite3":
+ case Postgres:
+ dbString = fmt.Sprintf(
+ "host=%s port=%d dbname=%s user=%s password=%s sslmode=disable",
+ cfg.DBhost,
+ cfg.DBport,
+ cfg.DBname,
+ cfg.DBuser,
+ cfg.DBpass,
+ )
+ case Sqlite:
dbString = cfg.DBpath
default:
- return nil, errors.New("unsupported DB")
+ return nil, errUnsupportedDatabase
}
- h := Headscale{
+ app := Headscale{
cfg: cfg,
dbType: cfg.DBtype,
dbString: dbString,
@@ -169,33 +183,32 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
aclRules: tailcfg.FilterAllowAll, // default allowall
}
- err = h.initDB()
+ err = app.initDB()
if err != nil {
return nil, err
}
if cfg.OIDC.Issuer != "" {
- err = h.initOIDC()
+ err = app.initOIDC()
if err != nil {
return nil, err
}
}
- if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS
- magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain)
- if err != nil {
- return nil, err
- }
+ if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS
+ magicDNSDomains := generateMagicDNSRootDomains(
+ app.cfg.IPPrefix,
+ )
// we might have routes already from Split DNS
- if h.cfg.DNSConfig.Routes == nil {
- h.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
+ if app.cfg.DNSConfig.Routes == nil {
+ app.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
}
for _, d := range magicDNSDomains {
- h.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
+ app.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
}
}
- return &h, nil
+ return &app, nil
}
// Redirect to our TLS url.
@@ -221,30 +234,37 @@ func (h *Headscale) expireEphemeralNodesWorker() {
return
}
- for _, ns := range namespaces {
- machines, err := h.ListMachinesInNamespace(ns.Name)
+ for _, namespace := range namespaces {
+ machines, err := h.ListMachinesInNamespace(namespace.Name)
if err != nil {
- log.Error().Err(err).Str("namespace", ns.Name).Msg("Error listing machines in namespace")
+ log.Error().
+ Err(err).
+ Str("namespace", namespace.Name).
+ Msg("Error listing machines in namespace")
return
}
- for _, m := range machines {
- if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral &&
- time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
- log.Info().Str("machine", m.Name).Msg("Ephemeral client removed from database")
+ for _, machine := range machines {
+ if machine.AuthKey != nil && machine.LastSeen != nil &&
+ machine.AuthKey.Ephemeral &&
+ time.Now().
+ After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
+ log.Info().
+ Str("machine", machine.Name).
+ Msg("Ephemeral client removed from database")
- err = h.db.Unscoped().Delete(m).Error
+ err = h.db.Unscoped().Delete(machine).Error
if err != nil {
log.Error().
Err(err).
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("🤮 Cannot delete ephemeral machine from the database")
}
}
}
- h.setLastStateChangeToNow(ns.Name)
+ h.setLastStateChangeToNow(namespace.Name)
}
}
@@ -266,36 +286,56 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (interface{}, error) {
-
// Check if the request is coming from the on-server client.
// This is not secure, but it is to maintain maintainability
// with the "legacy" database-based client
// It is also neede for grpc-gateway to be able to connect to
// the server
- p, _ := peer.FromContext(ctx)
+ client, _ := peer.FromContext(ctx)
- log.Trace().Caller().Str("client_address", p.Addr.String()).Msg("Client is trying to authenticate")
+ log.Trace().
+ Caller().
+ Str("client_address", client.Addr.String()).
+ Msg("Client is trying to authenticate")
- md, ok := metadata.FromIncomingContext(ctx)
+ meta, ok := metadata.FromIncomingContext(ctx)
if !ok {
- log.Error().Caller().Str("client_address", p.Addr.String()).Msg("Retrieving metadata is failed")
- return ctx, status.Errorf(codes.InvalidArgument, "Retrieving metadata is failed")
+ log.Error().
+ Caller().
+ Str("client_address", client.Addr.String()).
+ Msg("Retrieving metadata is failed")
+
+ return ctx, status.Errorf(
+ codes.InvalidArgument,
+ "Retrieving metadata is failed",
+ )
}
- authHeader, ok := md["authorization"]
+ authHeader, ok := meta["authorization"]
if !ok {
- log.Error().Caller().Str("client_address", p.Addr.String()).Msg("Authorization token is not supplied")
- return ctx, status.Errorf(codes.Unauthenticated, "Authorization token is not supplied")
+ log.Error().
+ Caller().
+ Str("client_address", client.Addr.String()).
+ Msg("Authorization token is not supplied")
+
+ return ctx, status.Errorf(
+ codes.Unauthenticated,
+ "Authorization token is not supplied",
+ )
}
token := authHeader[0]
- if !strings.HasPrefix(token, AUTH_PREFIX) {
+ if !strings.HasPrefix(token, AuthPrefix) {
log.Error().
Caller().
- Str("client_address", p.Addr.String()).
+ Str("client_address", client.Addr.String()).
Msg(`missing "Bearer " prefix in "Authorization" header`)
- return ctx, status.Error(codes.Unauthenticated, `missing "Bearer " prefix in "Authorization" header`)
+
+ return ctx, status.Error(
+ codes.Unauthenticated,
+ `missing "Bearer " prefix in "Authorization" header`,
+ )
}
// TODO(kradalby): Implement API key backend:
@@ -307,35 +347,38 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
// Currently all other than localhost traffic is unauthorized, this is intentional to allow
// us to make use of gRPC for our CLI, but not having to implement any of the remote capabilities
// and API key auth
- return ctx, status.Error(codes.Unauthenticated, "Authentication is not implemented yet")
+ return ctx, status.Error(
+ codes.Unauthenticated,
+ "Authentication is not implemented yet",
+ )
- //if strings.TrimPrefix(token, AUTH_PREFIX) != a.Token {
- // log.Error().Caller().Str("client_address", p.Addr.String()).Msg("invalid token")
- // return ctx, status.Error(codes.Unauthenticated, "invalid token")
- //}
+ // if strings.TrimPrefix(token, AUTH_PREFIX) != a.Token {
+ // log.Error().Caller().Str("client_address", p.Addr.String()).Msg("invalid token")
+ // return ctx, status.Error(codes.Unauthenticated, "invalid token")
+ // }
// return handler(ctx, req)
}
-func (h *Headscale) httpAuthenticationMiddleware(c *gin.Context) {
+func (h *Headscale) httpAuthenticationMiddleware(ctx *gin.Context) {
log.Trace().
Caller().
- Str("client_address", c.ClientIP()).
+ Str("client_address", ctx.ClientIP()).
Msg("HTTP authentication invoked")
- authHeader := c.GetHeader("authorization")
+ authHeader := ctx.GetHeader("authorization")
- if !strings.HasPrefix(authHeader, AUTH_PREFIX) {
+ if !strings.HasPrefix(authHeader, AuthPrefix) {
log.Error().
Caller().
- Str("client_address", c.ClientIP()).
+ Str("client_address", ctx.ClientIP()).
Msg(`missing "Bearer " prefix in "Authorization" header`)
- c.AbortWithStatus(http.StatusUnauthorized)
+ ctx.AbortWithStatus(http.StatusUnauthorized)
return
}
- c.AbortWithStatus(http.StatusUnauthorized)
+ ctx.AbortWithStatus(http.StatusUnauthorized)
// TODO(kradalby): Implement API key backend
// Currently all traffic is unauthorized, this is intentional to allow
@@ -359,6 +402,7 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error {
if _, err := os.Stat(h.cfg.UnixSocket); errors.Is(err, os.ErrNotExist) {
return nil
}
+
return os.Remove(h.cfg.UnixSocket)
}
@@ -401,14 +445,17 @@ func (h *Headscale) Serve() error {
// Create the cmux object that will multiplex 2 protocols on the same port.
// The two following listeners will be served on the same port below gracefully.
- m := cmux.New(networkListener)
+ networkMutex := cmux.New(networkListener)
// Match gRPC requests here
- grpcListener := m.MatchWithWriters(
+ grpcListener := networkMutex.MatchWithWriters(
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"),
- cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc+proto"),
+ cmux.HTTP2MatchHeaderFieldSendSettings(
+ "content-type",
+ "application/grpc+proto",
+ ),
)
// Otherwise match regular http requests.
- httpListener := m.Match(cmux.Any())
+ httpListener := networkMutex.Match(cmux.Any())
grpcGatewayMux := runtime.NewServeMux()
@@ -431,30 +478,33 @@ func (h *Headscale) Serve() error {
return err
}
- r := gin.Default()
+ router := gin.Default()
- p := ginprometheus.NewPrometheus("gin")
- p.Use(r)
+ prometheus := ginprometheus.NewPrometheus("gin")
+ prometheus.Use(router)
- r.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) })
- r.GET("/key", h.KeyHandler)
- r.GET("/register", h.RegisterWebAPI)
- r.POST("/machine/:id/map", h.PollNetMapHandler)
- r.POST("/machine/:id", h.RegistrationHandler)
- r.GET("/oidc/register/:mkey", h.RegisterOIDC)
- r.GET("/oidc/callback", h.OIDCCallback)
- r.GET("/apple", h.AppleMobileConfig)
- r.GET("/apple/:platform", h.ApplePlatformConfig)
- r.GET("/swagger", SwaggerUI)
- r.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
+ router.GET(
+ "/health",
+ func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
+ )
+ router.GET("/key", h.KeyHandler)
+ router.GET("/register", h.RegisterWebAPI)
+ router.POST("/machine/:id/map", h.PollNetMapHandler)
+ router.POST("/machine/:id", h.RegistrationHandler)
+ router.GET("/oidc/register/:mkey", h.RegisterOIDC)
+ router.GET("/oidc/callback", h.OIDCCallback)
+ router.GET("/apple", h.AppleMobileConfig)
+ router.GET("/apple/:platform", h.ApplePlatformConfig)
+ router.GET("/swagger", SwaggerUI)
+ router.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
- api := r.Group("/api")
+ api := router.Group("/api")
api.Use(h.httpAuthenticationMiddleware)
{
api.Any("/v1/*any", gin.WrapF(grpcGatewayMux.ServeHTTP))
}
- r.NoRoute(stdoutHandler)
+ router.NoRoute(stdoutHandler)
// Fetch an initial DERP Map before we start serving
h.DERPMap = GetDERPMap(h.cfg.DERP)
@@ -466,14 +516,13 @@ func (h *Headscale) Serve() error {
}
// I HATE THIS
- updateMillisecondsWait := int64(5000)
- go h.watchForKVUpdates(updateMillisecondsWait)
- go h.expireEphemeralNodes(updateMillisecondsWait)
+ go h.watchForKVUpdates(updateInterval)
+ go h.expireEphemeralNodes(updateInterval)
httpServer := &http.Server{
Addr: h.cfg.Addr,
- Handler: r,
- ReadTimeout: 30 * time.Second,
+ Handler: router,
+ ReadTimeout: HTTPReadTimeout,
// Go does not handle timeouts in HTTP very well, and there is
// no good way to handle streaming timeouts, therefore we need to
// keep this at unlimited and be careful to clean up connections
@@ -519,36 +568,40 @@ func (h *Headscale) Serve() error {
reflection.Register(grpcServer)
reflection.Register(grpcSocket)
- g := new(errgroup.Group)
+ errorGroup := new(errgroup.Group)
- g.Go(func() error { return grpcSocket.Serve(socketListener) })
+ errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) })
// TODO(kradalby): Verify if we need the same TLS setup for gRPC as HTTP
- g.Go(func() error { return grpcServer.Serve(grpcListener) })
+ errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
if tlsConfig != nil {
- g.Go(func() error {
+ errorGroup.Go(func() error {
tlsl := tls.NewListener(httpListener, tlsConfig)
+
return httpServer.Serve(tlsl)
})
} else {
- g.Go(func() error { return httpServer.Serve(httpListener) })
+ errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
}
- g.Go(func() error { return m.Serve() })
+ errorGroup.Go(func() error { return networkMutex.Serve() })
- log.Info().Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr)
+ log.Info().
+ Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr)
- return g.Wait()
+ return errorGroup.Wait()
}
func (h *Headscale) getTLSSettings() (*tls.Config, error) {
+ var err error
if h.cfg.TLSLetsEncryptHostname != "" {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
- log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
+ log.Warn().
+ Msg("Listening with TLS but ServerURL does not start with https://")
}
- m := autocert.Manager{
+ certManager := autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist(h.cfg.TLSLetsEncryptHostname),
Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir),
@@ -558,40 +611,44 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
Email: h.cfg.ACMEEmail,
}
- if h.cfg.TLSLetsEncryptChallengeType == "TLS-ALPN-01" {
+ switch h.cfg.TLSLetsEncryptChallengeType {
+ case "TLS-ALPN-01":
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
// The RFC requires that the validation is done on port 443; in other words, headscale
// must be reachable on port 443.
- return m.TLSConfig(), nil
- } else if h.cfg.TLSLetsEncryptChallengeType == "HTTP-01" {
+ return certManager.TLSConfig(), nil
+
+ case "HTTP-01":
// Configuration via autocert with HTTP-01. This requires listening on
// port 80 for the certificate validation in addition to the headscale
// service, which can be configured to run on any other port.
go func() {
log.Fatal().
- Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, m.HTTPHandler(http.HandlerFunc(h.redirect)))).
+ Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))).
Msg("failed to set up a HTTP server")
}()
- return m.TLSConfig(), nil
- } else {
- return nil, errors.New("unknown value for TLSLetsEncryptChallengeType")
+ return certManager.TLSConfig(), nil
+
+ default:
+ return nil, errUnsupportedLetsEncryptChallengeType
}
} else if h.cfg.TLSCertPath == "" {
if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
}
- return nil, nil
+ return nil, err
} else {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
}
- var err error
- tlsConfig := &tls.Config{}
- tlsConfig.ClientAuth = tls.RequireAnyClientCert
- tlsConfig.NextProtos = []string{"http/1.1"}
- tlsConfig.Certificates = make([]tls.Certificate, 1)
+ tlsConfig := &tls.Config{
+ ClientAuth: tls.RequireAnyClientCert,
+ NextProtos: []string{"http/1.1"},
+ Certificates: make([]tls.Certificate, 1),
+ MinVersion: tls.VersionTLS12,
+ }
tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
return tlsConfig, err
@@ -628,13 +685,13 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
}
}
-func stdoutHandler(c *gin.Context) {
- b, _ := io.ReadAll(c.Request.Body)
+func stdoutHandler(ctx *gin.Context) {
+ body, _ := io.ReadAll(ctx.Request.Body)
log.Trace().
- Interface("header", c.Request.Header).
- Interface("proto", c.Request.Proto).
- Interface("url", c.Request.URL).
- Bytes("body", b).
+ Interface("header", ctx.Request.Header).
+ Interface("proto", ctx.Request.Proto).
+ Interface("url", ctx.Request.URL).
+ Bytes("body", body).
Msg("Request did not match")
}
diff --git a/app_test.go b/app_test.go
index 5e53f1c..947062b 100644
--- a/app_test.go
+++ b/app_test.go
@@ -17,8 +17,10 @@ var _ = check.Suite(&Suite{})
type Suite struct{}
-var tmpDir string
-var h Headscale
+var (
+ tmpDir string
+ app Headscale
+)
func (s *Suite) SetUpTest(c *check.C) {
s.ResetDB(c)
@@ -41,18 +43,18 @@ func (s *Suite) ResetDB(c *check.C) {
IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"),
}
- h = Headscale{
+ app = Headscale{
cfg: cfg,
dbType: "sqlite3",
dbString: tmpDir + "/headscale_test.db",
}
- err = h.initDB()
+ err = app.initDB()
if err != nil {
c.Fatal(err)
}
- db, err := h.openDB()
+ db, err := app.openDB()
if err != nil {
c.Fatal(err)
}
- h.db = db
+ app.db = db
}
diff --git a/apple_mobileconfig.go b/apple_mobileconfig.go
index f3956e3..2e454df 100644
--- a/apple_mobileconfig.go
+++ b/apple_mobileconfig.go
@@ -5,16 +5,15 @@ import (
"net/http"
"text/template"
- "github.com/rs/zerolog/log"
-
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
+ "github.com/rs/zerolog/log"
)
// AppleMobileConfig shows a simple message in the browser to point to the CLI
-// Listens in /register
-func (h *Headscale) AppleMobileConfig(c *gin.Context) {
- t := template.Must(template.New("apple").Parse(`
+// Listens in /register.
+func (h *Headscale) AppleMobileConfig(ctx *gin.Context) {
+ appleTemplate := template.Must(template.New("apple").Parse(`
Apple configuration profiles
@@ -56,7 +55,7 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
Or
Use your terminal to configure the default setting for Tailscale by issuing:
- defaults write io.tailscale.ipn.macos ControlURL {{.Url}}
+ defaults write io.tailscale.ipn.macos ControlURL {{.URL}}
Restart Tailscale.app and log in.
@@ -64,24 +63,29 @@ func (h *Headscale) AppleMobileConfig(c *gin.Context) {
`))
config := map[string]interface{}{
- "Url": h.cfg.ServerURL,
+ "URL": h.cfg.ServerURL,
}
var payload bytes.Buffer
- if err := t.Execute(&payload, config); err != nil {
+ if err := appleTemplate.Execute(&payload, config); err != nil {
log.Error().
Str("handler", "AppleMobileConfig").
Err(err).
Msg("Could not render Apple index template")
- c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple index template"))
+ ctx.Data(
+ http.StatusInternalServerError,
+ "text/html; charset=utf-8",
+ []byte("Could not render Apple index template"),
+ )
+
return
}
- c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
+ ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
}
-func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
- platform := c.Param("platform")
+func (h *Headscale) ApplePlatformConfig(ctx *gin.Context) {
+ platform := ctx.Param("platform")
id, err := uuid.NewV4()
if err != nil {
@@ -89,23 +93,33 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Failed not create UUID")
- c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Failed to create UUID"))
+ ctx.Data(
+ http.StatusInternalServerError,
+ "text/html; charset=utf-8",
+ []byte("Failed to create UUID"),
+ )
+
return
}
- contentId, err := uuid.NewV4()
+ contentID, err := uuid.NewV4()
if err != nil {
log.Error().
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Failed not create UUID")
- c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Failed to create UUID"))
+ ctx.Data(
+ http.StatusInternalServerError,
+ "text/html; charset=utf-8",
+ []byte("Failed to create UUID"),
+ )
+
return
}
platformConfig := AppleMobilePlatformConfig{
- UUID: contentId,
- Url: h.cfg.ServerURL,
+ UUID: contentID,
+ URL: h.cfg.ServerURL,
}
var payload bytes.Buffer
@@ -117,7 +131,12 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple macOS template")
- c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple macOS template"))
+ ctx.Data(
+ http.StatusInternalServerError,
+ "text/html; charset=utf-8",
+ []byte("Could not render Apple macOS template"),
+ )
+
return
}
case "ios":
@@ -126,17 +145,27 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple iOS template")
- c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple iOS template"))
+ ctx.Data(
+ http.StatusInternalServerError,
+ "text/html; charset=utf-8",
+ []byte("Could not render Apple iOS template"),
+ )
+
return
}
default:
- c.Data(http.StatusOK, "text/html; charset=utf-8", []byte("Invalid platform, only ios and macos is supported"))
+ ctx.Data(
+ http.StatusOK,
+ "text/html; charset=utf-8",
+ []byte("Invalid platform, only ios and macos is supported"),
+ )
+
return
}
config := AppleMobileConfig{
UUID: id,
- Url: h.cfg.ServerURL,
+ URL: h.cfg.ServerURL,
Payload: payload.String(),
}
@@ -146,25 +175,35 @@ func (h *Headscale) ApplePlatformConfig(c *gin.Context) {
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple platform template")
- c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Apple platform template"))
+ ctx.Data(
+ http.StatusInternalServerError,
+ "text/html; charset=utf-8",
+ []byte("Could not render Apple platform template"),
+ )
+
return
}
- c.Data(http.StatusOK, "application/x-apple-aspen-config; charset=utf-8", content.Bytes())
+ ctx.Data(
+ http.StatusOK,
+ "application/x-apple-aspen-config; charset=utf-8",
+ content.Bytes(),
+ )
}
type AppleMobileConfig struct {
UUID uuid.UUID
- Url string
+ URL string
Payload string
}
type AppleMobilePlatformConfig struct {
UUID uuid.UUID
- Url string
+ URL string
}
-var commonTemplate = template.Must(template.New("mobileconfig").Parse(`
+var commonTemplate = template.Must(
+ template.New("mobileconfig").Parse(`
@@ -173,7 +212,7 @@ var commonTemplate = template.Must(template.New("mobileconfig").Parse(`PayloadDisplayName
Headscale
PayloadDescription
- Configure Tailscale login server to: {{.Url}}
+ Configure Tailscale login server to: {{.URL}}
PayloadIdentifier
com.github.juanfont.headscale
PayloadRemovalDisallowed
@@ -187,7 +226,8 @@ var commonTemplate = template.Must(template.New("mobileconfig").Parse(`
-`))
+`),
+)
var iosTemplate = template.Must(template.New("iosTemplate").Parse(`
@@ -203,7 +243,7 @@ var iosTemplate = template.Must(template.New("iosTemplate").Parse(`
ControlURL
- {{.Url}}
+ {{.URL}}
`))
@@ -221,6 +261,6 @@ var macosTemplate = template.Must(template.New("macosTemplate").Parse(`
ControlURL
- {{.Url}}
+ {{.URL}}
`))
diff --git a/cli_test.go b/cli_test.go
index 291b5df..44ef9f0 100644
--- a/cli_test.go
+++ b/cli_test.go
@@ -7,31 +7,34 @@ import (
)
func (s *Suite) TestRegisterMachine(c *check.C) {
- n, err := h.CreateNamespace("test")
+ namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
now := time.Now().UTC()
- m := Machine{
+ machine := Machine{
ID: 0,
MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
- NamespaceID: n.ID,
+ NamespaceID: namespace.ID,
IPAddress: "10.0.0.1",
Expiry: &now,
RequestedExpiry: &now,
}
- h.db.Save(&m)
+ app.db.Save(&machine)
- _, err = h.GetMachine("test", "testmachine")
+ _, err = app.GetMachine("test", "testmachine")
c.Assert(err, check.IsNil)
- m2, err := h.RegisterMachine("8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", n.Name)
+ machineAfterRegistering, err := app.RegisterMachine(
+ "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
+ namespace.Name,
+ )
c.Assert(err, check.IsNil)
- c.Assert(m2.Registered, check.Equals, true)
+ c.Assert(machineAfterRegistering.Registered, check.Equals, true)
- _, err = m2.GetHostInfo()
+ _, err = machineAfterRegistering.GetHostInfo()
c.Assert(err, check.IsNil)
}
diff --git a/cmd/headscale/cli/debug.go b/cmd/headscale/cli/debug.go
index e140156..46bdb9e 100644
--- a/cmd/headscale/cli/debug.go
+++ b/cmd/headscale/cli/debug.go
@@ -27,7 +27,8 @@ func init() {
if err != nil {
log.Fatal().Err(err).Msg("")
}
- createNodeCmd.Flags().StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise")
+ createNodeCmd.Flags().
+ StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to advertise")
debugCmd.AddCommand(createNodeCmd)
}
@@ -47,6 +48,7 @@ var createNodeCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
+
return
}
@@ -56,19 +58,34 @@ var createNodeCmd = &cobra.Command{
name, err := cmd.Flags().GetString("name")
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Error getting node from flag: %s", err), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Error getting node from flag: %s", err),
+ output,
+ )
+
return
}
machineKey, err := cmd.Flags().GetString("key")
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Error getting key from flag: %s", err), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Error getting key from flag: %s", err),
+ output,
+ )
+
return
}
routes, err := cmd.Flags().GetStringSlice("route")
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Error getting routes from flag: %s", err), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Error getting routes from flag: %s", err),
+ output,
+ )
+
return
}
@@ -81,7 +98,12 @@ var createNodeCmd = &cobra.Command{
response, err := client.DebugCreateMachine(ctx, request)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Cannot create machine: %s", status.Convert(err).Message()), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Cannot create machine: %s", status.Convert(err).Message()),
+ output,
+ )
+
return
}
diff --git a/cmd/headscale/cli/namespaces.go b/cmd/headscale/cli/namespaces.go
index 0e5eeb4..361e9be 100644
--- a/cmd/headscale/cli/namespaces.go
+++ b/cmd/headscale/cli/namespaces.go
@@ -4,6 +4,7 @@ import (
"fmt"
survey "github.com/AlecAivazis/survey/v2"
+ "github.com/juanfont/headscale"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/pterm/pterm"
"github.com/rs/zerolog/log"
@@ -19,6 +20,10 @@ func init() {
namespaceCmd.AddCommand(renameNamespaceCmd)
}
+const (
+ errMissingParameter = headscale.Error("missing parameters")
+)
+
var namespaceCmd = &cobra.Command{
Use: "namespaces",
Short: "Manage the namespaces of Headscale",
@@ -29,8 +34,9 @@ var createNamespaceCmd = &cobra.Command{
Short: "Creates a new namespace",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
- return fmt.Errorf("Missing parameters")
+ return errMissingParameter
}
+
return nil
},
Run: func(cmd *cobra.Command, args []string) {
@@ -49,7 +55,15 @@ var createNamespaceCmd = &cobra.Command{
log.Trace().Interface("request", request).Msg("Sending CreateNamespace request")
response, err := client.CreateNamespace(ctx, request)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Cannot create namespace: %s", status.Convert(err).Message()), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf(
+ "Cannot create namespace: %s",
+ status.Convert(err).Message(),
+ ),
+ output,
+ )
+
return
}
@@ -62,8 +76,9 @@ var destroyNamespaceCmd = &cobra.Command{
Short: "Destroys a namespace",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
- return fmt.Errorf("Missing parameters")
+ return errMissingParameter
}
+
return nil
},
Run: func(cmd *cobra.Command, args []string) {
@@ -81,7 +96,12 @@ var destroyNamespaceCmd = &cobra.Command{
_, err := client.GetNamespace(ctx, request)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Error: %s", status.Convert(err).Message()), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Error: %s", status.Convert(err).Message()),
+ output,
+ )
+
return
}
@@ -89,7 +109,10 @@ var destroyNamespaceCmd = &cobra.Command{
force, _ := cmd.Flags().GetBool("force")
if !force {
prompt := &survey.Confirm{
- Message: fmt.Sprintf("Do you want to remove the namespace '%s' and any associated preauthkeys?", namespaceName),
+ Message: fmt.Sprintf(
+ "Do you want to remove the namespace '%s' and any associated preauthkeys?",
+ namespaceName,
+ ),
}
err := survey.AskOne(prompt, &confirm)
if err != nil {
@@ -102,7 +125,15 @@ var destroyNamespaceCmd = &cobra.Command{
response, err := client.DeleteNamespace(ctx, request)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Cannot destroy namespace: %s", status.Convert(err).Message()), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf(
+ "Cannot destroy namespace: %s",
+ status.Convert(err).Message(),
+ ),
+ output,
+ )
+
return
}
SuccessOutput(response, "Namespace destroyed", output)
@@ -126,19 +157,25 @@ var listNamespacesCmd = &cobra.Command{
response, err := client.ListNamespaces(ctx, request)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Cannot get namespaces: %s", status.Convert(err).Message()), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Cannot get namespaces: %s", status.Convert(err).Message()),
+ output,
+ )
+
return
}
if output != "" {
SuccessOutput(response.Namespaces, "", output)
+
return
}
- d := pterm.TableData{{"ID", "Name", "Created"}}
+ tableData := pterm.TableData{{"ID", "Name", "Created"}}
for _, namespace := range response.GetNamespaces() {
- d = append(
- d,
+ tableData = append(
+ tableData,
[]string{
namespace.GetId(),
namespace.GetName(),
@@ -146,9 +183,14 @@ var listNamespacesCmd = &cobra.Command{
},
)
}
- err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
+ err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Failed to render pterm table: %s", err),
+ output,
+ )
+
return
}
},
@@ -158,9 +200,11 @@ var renameNamespaceCmd = &cobra.Command{
Use: "rename OLD_NAME NEW_NAME",
Short: "Renames a namespace",
Args: func(cmd *cobra.Command, args []string) error {
- if len(args) < 2 {
- return fmt.Errorf("Missing parameters")
+ expectedArguments := 2
+ if len(args) < expectedArguments {
+ return errMissingParameter
}
+
return nil
},
Run: func(cmd *cobra.Command, args []string) {
@@ -177,7 +221,15 @@ var renameNamespaceCmd = &cobra.Command{
response, err := client.RenameNamespace(ctx, request)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Cannot rename namespace: %s", status.Convert(err).Message()), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf(
+ "Cannot rename namespace: %s",
+ status.Convert(err).Message(),
+ ),
+ output,
+ )
+
return
}
diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go
index 9b3f424..218d25b 100644
--- a/cmd/headscale/cli/nodes.go
+++ b/cmd/headscale/cli/nodes.go
@@ -7,6 +7,7 @@ import (
"time"
survey "github.com/AlecAivazis/survey/v2"
+ "github.com/juanfont/headscale"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/pterm/pterm"
"github.com/spf13/cobra"
@@ -77,6 +78,7 @@ var registerNodeCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
+
return
}
@@ -86,7 +88,12 @@ var registerNodeCmd = &cobra.Command{
machineKey, err := cmd.Flags().GetString("key")
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Error getting machine key from flag: %s", err), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Error getting machine key from flag: %s", err),
+ output,
+ )
+
return
}
@@ -97,7 +104,15 @@ var registerNodeCmd = &cobra.Command{
response, err := client.RegisterMachine(ctx, request)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Cannot register machine: %s\n", status.Convert(err).Message()), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf(
+ "Cannot register machine: %s\n",
+ status.Convert(err).Message(),
+ ),
+ output,
+ )
+
return
}
@@ -113,6 +128,7 @@ var listNodesCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
+
return
}
@@ -126,24 +142,36 @@ var listNodesCmd = &cobra.Command{
response, err := client.ListMachines(ctx, request)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
+ output,
+ )
+
return
}
if output != "" {
SuccessOutput(response.Machines, "", output)
+
return
}
- d, err := nodesToPtables(namespace, response.Machines)
+ tableData, err := nodesToPtables(namespace, response.Machines)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
+
return
}
- err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
+ err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Failed to render pterm table: %s", err),
+ output,
+ )
+
return
}
},
@@ -155,9 +183,14 @@ var deleteNodeCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
- id, err := cmd.Flags().GetInt("identifier")
+ identifier, err := cmd.Flags().GetInt("identifier")
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Error converting ID to integer: %s", err),
+ output,
+ )
+
return
}
@@ -166,24 +199,35 @@ var deleteNodeCmd = &cobra.Command{
defer conn.Close()
getRequest := &v1.GetMachineRequest{
- MachineId: uint64(id),
+ MachineId: uint64(identifier),
}
getResponse, err := client.GetMachine(ctx, getRequest)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf(
+ "Error getting node node: %s",
+ status.Convert(err).Message(),
+ ),
+ output,
+ )
+
return
}
deleteRequest := &v1.DeleteMachineRequest{
- MachineId: uint64(id),
+ MachineId: uint64(identifier),
}
confirm := false
force, _ := cmd.Flags().GetBool("force")
if !force {
prompt := &survey.Confirm{
- Message: fmt.Sprintf("Do you want to remove the node %s?", getResponse.GetMachine().Name),
+ Message: fmt.Sprintf(
+ "Do you want to remove the node %s?",
+ getResponse.GetMachine().Name,
+ ),
}
err = survey.AskOne(prompt, &confirm)
if err != nil {
@@ -195,13 +239,26 @@ var deleteNodeCmd = &cobra.Command{
response, err := client.DeleteMachine(ctx, deleteRequest)
if output != "" {
SuccessOutput(response, "", output)
+
return
}
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Error deleting node: %s", status.Convert(err).Message()), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf(
+ "Error deleting node: %s",
+ status.Convert(err).Message(),
+ ),
+ output,
+ )
+
return
}
- SuccessOutput(map[string]string{"Result": "Node deleted"}, "Node deleted", output)
+ SuccessOutput(
+ map[string]string{"Result": "Node deleted"},
+ "Node deleted",
+ output,
+ )
} else {
SuccessOutput(map[string]string{"Result": "Node not deleted"}, "Node not deleted", output)
}
@@ -210,12 +267,12 @@ var deleteNodeCmd = &cobra.Command{
func sharingWorker(
cmd *cobra.Command,
- args []string,
) (string, *v1.Machine, *v1.Namespace, error) {
output, _ := cmd.Flags().GetString("output")
namespaceStr, err := cmd.Flags().GetString("namespace")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
+
return "", nil, nil, err
}
@@ -223,19 +280,25 @@ func sharingWorker(
defer cancel()
defer conn.Close()
- id, err := cmd.Flags().GetInt("identifier")
+ identifier, err := cmd.Flags().GetInt("identifier")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting ID to integer: %s", err), output)
+
return "", nil, nil, err
}
machineRequest := &v1.GetMachineRequest{
- MachineId: uint64(id),
+ MachineId: uint64(identifier),
}
machineResponse, err := client.GetMachine(ctx, machineRequest)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()),
+ output,
+ )
+
return "", nil, nil, err
}
@@ -245,7 +308,12 @@ func sharingWorker(
namespaceResponse, err := client.GetNamespace(ctx, namespaceRequest)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Error getting node node: %s", status.Convert(err).Message()),
+ output,
+ )
+
return "", nil, nil, err
}
@@ -256,9 +324,14 @@ var shareMachineCmd = &cobra.Command{
Use: "share",
Short: "Shares a node from the current namespace to the specified one",
Run: func(cmd *cobra.Command, args []string) {
- output, machine, namespace, err := sharingWorker(cmd, args)
+ output, machine, namespace, err := sharingWorker(cmd)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Failed to fetch namespace or machine: %s", err), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Failed to fetch namespace or machine: %s", err),
+ output,
+ )
+
return
}
@@ -273,7 +346,12 @@ var shareMachineCmd = &cobra.Command{
response, err := client.ShareMachine(ctx, request)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Error sharing node: %s", status.Convert(err).Message()), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Error sharing node: %s", status.Convert(err).Message()),
+ output,
+ )
+
return
}
@@ -285,9 +363,14 @@ var unshareMachineCmd = &cobra.Command{
Use: "unshare",
Short: "Unshares a node from the specified namespace",
Run: func(cmd *cobra.Command, args []string) {
- output, machine, namespace, err := sharingWorker(cmd, args)
+ output, machine, namespace, err := sharingWorker(cmd)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Failed to fetch namespace or machine: %s", err), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Failed to fetch namespace or machine: %s", err),
+ output,
+ )
+
return
}
@@ -302,7 +385,12 @@ var unshareMachineCmd = &cobra.Command{
response, err := client.UnshareMachine(ctx, request)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Error unsharing node: %s", status.Convert(err).Message()), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Error unsharing node: %s", status.Convert(err).Message()),
+ output,
+ )
+
return
}
@@ -310,8 +398,22 @@ var unshareMachineCmd = &cobra.Command{
},
}
-func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.TableData, error) {
- d := pterm.TableData{{"ID", "Name", "NodeKey", "Namespace", "IP address", "Ephemeral", "Last seen", "Online"}}
+func nodesToPtables(
+ currentNamespace string,
+ machines []*v1.Machine,
+) (pterm.TableData, error) {
+ tableData := pterm.TableData{
+ {
+ "ID",
+ "Name",
+ "NodeKey",
+ "Namespace",
+ "IP address",
+ "Ephemeral",
+ "Last seen",
+ "Online",
+ },
+ }
for _, machine := range machines {
var ephemeral bool
@@ -331,7 +433,9 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl
nodeKey := tailcfg.NodeKey(nKey)
var online string
- if lastSeen.After(time.Now().Add(-5 * time.Minute)) { // TODO: Find a better way to reliably show if online
+ if lastSeen.After(
+ time.Now().Add(-5 * time.Minute),
+ ) { // TODO: Find a better way to reliably show if online
online = pterm.LightGreen("true")
} else {
online = pterm.LightRed("false")
@@ -344,10 +448,10 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl
// Shared into this namespace
namespace = pterm.LightYellow(machine.Namespace.Name)
}
- d = append(
- d,
+ tableData = append(
+ tableData,
[]string{
- strconv.FormatUint(machine.Id, 10),
+ strconv.FormatUint(machine.Id, headscale.Base10),
machine.Name,
nodeKey.ShortString(),
namespace,
@@ -358,5 +462,6 @@ func nodesToPtables(currentNamespace string, machines []*v1.Machine) (pterm.Tabl
},
)
}
- return d, nil
+
+ return tableData, nil
}
diff --git a/cmd/headscale/cli/preauthkeys.go b/cmd/headscale/cli/preauthkeys.go
index a07e961..9d5c838 100644
--- a/cmd/headscale/cli/preauthkeys.go
+++ b/cmd/headscale/cli/preauthkeys.go
@@ -12,6 +12,10 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
)
+const (
+ DefaultPreAuthKeyExpiry = 24 * time.Hour
+)
+
func init() {
rootCmd.AddCommand(preauthkeysCmd)
preauthkeysCmd.PersistentFlags().StringP("namespace", "n", "", "Namespace")
@@ -22,10 +26,12 @@ func init() {
preauthkeysCmd.AddCommand(listPreAuthKeys)
preauthkeysCmd.AddCommand(createPreAuthKeyCmd)
preauthkeysCmd.AddCommand(expirePreAuthKeyCmd)
- createPreAuthKeyCmd.PersistentFlags().Bool("reusable", false, "Make the preauthkey reusable")
- createPreAuthKeyCmd.PersistentFlags().Bool("ephemeral", false, "Preauthkey for ephemeral nodes")
+ createPreAuthKeyCmd.PersistentFlags().
+ Bool("reusable", false, "Make the preauthkey reusable")
+ createPreAuthKeyCmd.PersistentFlags().
+ Bool("ephemeral", false, "Preauthkey for ephemeral nodes")
createPreAuthKeyCmd.Flags().
- DurationP("expiration", "e", 24*time.Hour, "Human-readable expiration of the key (30m, 24h, 365d...)")
+ DurationP("expiration", "e", DefaultPreAuthKeyExpiry, "Human-readable expiration of the key (30m, 24h, 365d...)")
}
var preauthkeysCmd = &cobra.Command{
@@ -39,9 +45,10 @@ var listPreAuthKeys = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
- n, err := cmd.Flags().GetString("namespace")
+ namespace, err := cmd.Flags().GetString("namespace")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
+
return
}
@@ -50,48 +57,61 @@ var listPreAuthKeys = &cobra.Command{
defer conn.Close()
request := &v1.ListPreAuthKeysRequest{
- Namespace: n,
+ Namespace: namespace,
}
response, err := client.ListPreAuthKeys(ctx, request)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Error getting the list of keys: %s", err), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Error getting the list of keys: %s", err),
+ output,
+ )
+
return
}
if output != "" {
SuccessOutput(response.PreAuthKeys, "", output)
+
return
}
- d := pterm.TableData{{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"}}
- for _, k := range response.PreAuthKeys {
+ tableData := pterm.TableData{
+ {"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"},
+ }
+ for _, key := range response.PreAuthKeys {
expiration := "-"
- if k.GetExpiration() != nil {
- expiration = k.Expiration.AsTime().Format("2006-01-02 15:04:05")
+ if key.GetExpiration() != nil {
+ expiration = key.Expiration.AsTime().Format("2006-01-02 15:04:05")
}
var reusable string
- if k.GetEphemeral() {
+ if key.GetEphemeral() {
reusable = "N/A"
} else {
- reusable = fmt.Sprintf("%v", k.GetReusable())
+ reusable = fmt.Sprintf("%v", key.GetReusable())
}
- d = append(d, []string{
- k.GetId(),
- k.GetKey(),
+ tableData = append(tableData, []string{
+ key.GetId(),
+ key.GetKey(),
reusable,
- strconv.FormatBool(k.GetEphemeral()),
- strconv.FormatBool(k.GetUsed()),
+ strconv.FormatBool(key.GetEphemeral()),
+ strconv.FormatBool(key.GetUsed()),
expiration,
- k.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
+ key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
})
}
- err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
+ err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Failed to render pterm table: %s", err),
+ output,
+ )
+
return
}
},
@@ -106,6 +126,7 @@ var createPreAuthKeyCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
+
return
}
@@ -139,7 +160,12 @@ var createPreAuthKeyCmd = &cobra.Command{
response, err := client.CreatePreAuthKey(ctx, request)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err),
+ output,
+ )
+
return
}
@@ -152,8 +178,9 @@ var expirePreAuthKeyCmd = &cobra.Command{
Short: "Expire a preauthkey",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
- return fmt.Errorf("missing parameters")
+ return errMissingParameter
}
+
return nil
},
Run: func(cmd *cobra.Command, args []string) {
@@ -161,6 +188,7 @@ var expirePreAuthKeyCmd = &cobra.Command{
namespace, err := cmd.Flags().GetString("namespace")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting namespace: %s", err), output)
+
return
}
@@ -175,7 +203,12 @@ var expirePreAuthKeyCmd = &cobra.Command{
response, err := client.ExpirePreAuthKey(ctx, request)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err),
+ output,
+ )
+
return
}
diff --git a/cmd/headscale/cli/root.go b/cmd/headscale/cli/root.go
index e5115f3..99b1514 100644
--- a/cmd/headscale/cli/root.go
+++ b/cmd/headscale/cli/root.go
@@ -10,7 +10,8 @@ import (
func init() {
rootCmd.PersistentFlags().
StringP("output", "o", "", "Output format. Empty for human-readable, 'json', 'json-line' or 'yaml'")
- rootCmd.PersistentFlags().Bool("force", false, "Disable prompts and forces the execution")
+ rootCmd.PersistentFlags().
+ Bool("force", false, "Disable prompts and forces the execution")
}
var rootCmd = &cobra.Command{
diff --git a/cmd/headscale/cli/routes.go b/cmd/headscale/cli/routes.go
index 9f2e4e2..ced1a0b 100644
--- a/cmd/headscale/cli/routes.go
+++ b/cmd/headscale/cli/routes.go
@@ -21,7 +21,8 @@ func init() {
}
routesCmd.AddCommand(listRoutesCmd)
- enableRouteCmd.Flags().StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to enable")
+ enableRouteCmd.Flags().
+ StringSliceP("route", "r", []string{}, "List (or repeated flags) of routes to enable")
enableRouteCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
err = enableRouteCmd.MarkFlagRequired("identifier")
if err != nil {
@@ -44,9 +45,14 @@ var listRoutesCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
- machineId, err := cmd.Flags().GetUint64("identifier")
+ machineID, err := cmd.Flags().GetUint64("identifier")
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Error getting machine id from flag: %s", err), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Error getting machine id from flag: %s", err),
+ output,
+ )
+
return
}
@@ -55,29 +61,41 @@ var listRoutesCmd = &cobra.Command{
defer conn.Close()
request := &v1.GetMachineRouteRequest{
- MachineId: machineId,
+ MachineId: machineID,
}
response, err := client.GetMachineRoute(ctx, request)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()),
+ output,
+ )
+
return
}
if output != "" {
SuccessOutput(response.Routes, "", output)
+
return
}
- d := routesToPtables(response.Routes)
+ tableData := routesToPtables(response.Routes)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
+
return
}
- err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
+ err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Failed to render pterm table: %s", err),
+ output,
+ )
+
return
}
},
@@ -93,15 +111,26 @@ omit the route you do not want to enable.
`,
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
- machineId, err := cmd.Flags().GetUint64("identifier")
+
+ machineID, err := cmd.Flags().GetUint64("identifier")
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Error getting machine id from flag: %s", err), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Error getting machine id from flag: %s", err),
+ output,
+ )
+
return
}
routes, err := cmd.Flags().GetStringSlice("route")
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Error getting routes from flag: %s", err), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Error getting routes from flag: %s", err),
+ output,
+ )
+
return
}
@@ -110,45 +139,61 @@ omit the route you do not want to enable.
defer conn.Close()
request := &v1.EnableMachineRoutesRequest{
- MachineId: machineId,
+ MachineId: machineID,
Routes: routes,
}
response, err := client.EnableMachineRoutes(ctx, request)
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Cannot register machine: %s\n", status.Convert(err).Message()), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf(
+ "Cannot register machine: %s\n",
+ status.Convert(err).Message(),
+ ),
+ output,
+ )
+
return
}
if output != "" {
SuccessOutput(response.Routes, "", output)
+
return
}
- d := routesToPtables(response.Routes)
+ tableData := routesToPtables(response.Routes)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
+
return
}
- err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
+ err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
- ErrorOutput(err, fmt.Sprintf("Failed to render pterm table: %s", err), output)
+ ErrorOutput(
+ err,
+ fmt.Sprintf("Failed to render pterm table: %s", err),
+ output,
+ )
+
return
}
},
}
-// routesToPtables converts the list of routes to a nice table
+// routesToPtables converts the list of routes to a nice table.
func routesToPtables(routes *v1.Routes) pterm.TableData {
- d := pterm.TableData{{"Route", "Enabled"}}
+ tableData := pterm.TableData{{"Route", "Enabled"}}
for _, route := range routes.GetAdvertisedRoutes() {
enabled := isStringInSlice(routes.EnabledRoutes, route)
- d = append(d, []string{route, strconv.FormatBool(enabled)})
+ tableData = append(tableData, []string{route, strconv.FormatBool(enabled)})
}
- return d
+
+ return tableData
}
func isStringInSlice(strs []string, s string) bool {
diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go
index 30f1a78..958fb89 100644
--- a/cmd/headscale/cli/utils.go
+++ b/cmd/headscale/cli/utils.go
@@ -52,9 +52,8 @@ func LoadConfig(path string) error {
viper.SetDefault("cli.insecure", false)
viper.SetDefault("cli.timeout", "5s")
- err := viper.ReadInConfig()
- if err != nil {
- return fmt.Errorf("Fatal error reading config file: %s \n", err)
+ if err := viper.ReadInConfig(); err != nil {
+ return fmt.Errorf("fatal error reading config file: %w", err)
}
// Collect any validation errors and return them all at once
@@ -82,6 +81,7 @@ func LoadConfig(path string) error {
errorText += "Fatal config error: server_url must start with https:// or http://\n"
}
if errorText != "" {
+ //nolint
return errors.New(strings.TrimSuffix(errorText, "\n"))
} else {
return nil
@@ -149,9 +149,14 @@ func GetDNSConfig() (*tailcfg.DNSConfig, string) {
if viper.IsSet("dns_config.restricted_nameservers") {
if len(dnsConfig.Nameservers) > 0 {
dnsConfig.Routes = make(map[string][]dnstype.Resolver)
- restrictedDNS := viper.GetStringMapStringSlice("dns_config.restricted_nameservers")
+ restrictedDNS := viper.GetStringMapStringSlice(
+ "dns_config.restricted_nameservers",
+ )
for domain, restrictedNameservers := range restrictedDNS {
- restrictedResolvers := make([]dnstype.Resolver, len(restrictedNameservers))
+ restrictedResolvers := make(
+ []dnstype.Resolver,
+ len(restrictedNameservers),
+ )
for index, nameserverStr := range restrictedNameservers {
nameserver, err := netaddr.ParseIP(nameserverStr)
if err != nil {
@@ -208,6 +213,7 @@ func absPath(path string) string {
path = filepath.Join(dir, path)
}
}
+
return path
}
@@ -219,7 +225,9 @@ func getHeadscaleConfig() headscale.Config {
"10h",
) // use 10h here because it is the length of a standard business day plus a small amount of leeway
if viper.GetDuration("max_machine_registration_duration") >= time.Second {
- maxMachineRegistrationDuration = viper.GetDuration("max_machine_registration_duration")
+ maxMachineRegistrationDuration = viper.GetDuration(
+ "max_machine_registration_duration",
+ )
}
// defaultMachineRegistrationDuration is the default time assigned to a machine registration if one is not
@@ -229,7 +237,9 @@ func getHeadscaleConfig() headscale.Config {
"8h",
) // use 8h here because it's the length of a standard business day
if viper.GetDuration("default_machine_registration_duration") >= time.Second {
- defaultMachineRegistrationDuration = viper.GetDuration("default_machine_registration_duration")
+ defaultMachineRegistrationDuration = viper.GetDuration(
+ "default_machine_registration_duration",
+ )
}
dnsConfig, baseDomain := GetDNSConfig()
@@ -244,7 +254,9 @@ func getHeadscaleConfig() headscale.Config {
DERP: derpConfig,
- EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"),
+ EphemeralNodeInactivityTimeout: viper.GetDuration(
+ "ephemeral_node_inactivity_timeout",
+ ),
DBtype: viper.GetString("db_type"),
DBpath: absPath(viper.GetString("db_path")),
@@ -254,9 +266,11 @@ func getHeadscaleConfig() headscale.Config {
DBuser: viper.GetString("db_user"),
DBpass: viper.GetString("db_pass"),
- TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"),
- TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"),
- TLSLetsEncryptCacheDir: absPath(viper.GetString("tls_letsencrypt_cache_dir")),
+ TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"),
+ TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"),
+ TLSLetsEncryptCacheDir: absPath(
+ viper.GetString("tls_letsencrypt_cache_dir"),
+ ),
TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),
TLSCertPath: absPath(viper.GetString("tls_cert_path")),
@@ -292,11 +306,14 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
// to avoid races
minInactivityTimeout, _ := time.ParseDuration("65s")
if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout {
+ // TODO: Find a better way to return this text
+ //nolint
err := fmt.Errorf(
- "ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s\n",
+ "ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s",
viper.GetString("ephemeral_node_inactivity_timeout"),
minInactivityTimeout,
)
+
return nil, err
}
@@ -304,7 +321,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
cfg.OIDC.MatchMap = loadOIDCMatchMap()
- h, err := headscale.NewHeadscale(cfg)
+ app, err := headscale.NewHeadscale(cfg)
if err != nil {
return nil, err
}
@@ -313,7 +330,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
if viper.GetString("acl_policy_path") != "" {
aclPath := absPath(viper.GetString("acl_policy_path"))
- err = h.LoadACLPolicy(aclPath)
+ err = app.LoadACLPolicy(aclPath)
if err != nil {
log.Error().
Str("path", aclPath).
@@ -322,7 +339,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
}
}
- return h, nil
+ return app, nil
}
func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
@@ -342,7 +359,6 @@ func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.
// If the address is not set, we assume that we are on the server hosting headscale.
if address == "" {
-
log.Debug().
Str("socket", cfg.UnixSocket).
Msgf("HEADSCALE_CLI_ADDRESS environment is not set, connecting to unix socket.")
@@ -402,10 +418,13 @@ func SuccessOutput(result interface{}, override string, outputFormat string) {
log.Fatal().Err(err)
}
default:
+ //nolint
fmt.Println(override)
+
return
}
+ //nolint
fmt.Println(string(j))
}
@@ -423,6 +442,7 @@ func HasMachineOutputFlag() bool {
return true
}
}
+
return false
}
@@ -431,7 +451,10 @@ type tokenAuth struct {
}
// Return value is mapped to request headers.
-func (t tokenAuth) GetRequestMetadata(ctx context.Context, in ...string) (map[string]string, error) {
+func (t tokenAuth) GetRequestMetadata(
+ ctx context.Context,
+ in ...string,
+) (map[string]string, error) {
return map[string]string{
"authorization": "Bearer " + t.token,
}, nil
diff --git a/cmd/headscale/headscale.go b/cmd/headscale/headscale.go
index ed4644f..d6bf216 100644
--- a/cmd/headscale/headscale.go
+++ b/cmd/headscale/headscale.go
@@ -23,6 +23,8 @@ func main() {
colors = true
case termcolor.LevelBasic:
colors = true
+ case termcolor.LevelNone:
+ colors = false
default:
// no color, return text as is.
colors = false
@@ -41,8 +43,7 @@ func main() {
NoColor: !colors,
})
- err := cli.LoadConfig("")
- if err != nil {
+ if err := cli.LoadConfig(""); err != nil {
log.Fatal().Err(err)
}
@@ -63,13 +64,15 @@ func main() {
}
if !viper.GetBool("disable_check_updates") && !machineOutput {
- if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && cli.Version != "dev" {
+ if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
+ cli.Version != "dev" {
githubTag := &latest.GithubTag{
Owner: "juanfont",
Repository: "headscale",
}
res, err := latest.Check(githubTag, cli.Version)
if err == nil && res.Outdated {
+ //nolint
fmt.Printf(
"An updated version of Headscale has been found (%s vs. your current %s). Check it out https://github.com/juanfont/headscale/releases\n",
res.Current,
diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go
index e3a5713..1166d48 100644
--- a/cmd/headscale/headscale_test.go
+++ b/cmd/headscale/headscale_test.go
@@ -1,7 +1,6 @@
package main
import (
- "fmt"
"io/ioutil"
"os"
"path/filepath"
@@ -40,7 +39,10 @@ func (*Suite) TestConfigLoading(c *check.C) {
}
// Symlink the example config file
- err = os.Symlink(filepath.Clean(path+"/../../config-example.yaml"), filepath.Join(tmpDir, "config.yaml"))
+ err = os.Symlink(
+ filepath.Clean(path+"/../../config-example.yaml"),
+ filepath.Join(tmpDir, "config.yaml"),
+ )
if err != nil {
c.Fatal(err)
}
@@ -74,7 +76,10 @@ func (*Suite) TestDNSConfigLoading(c *check.C) {
}
// Symlink the example config file
- err = os.Symlink(filepath.Clean(path+"/../../config-example.yaml"), filepath.Join(tmpDir, "config.yaml"))
+ err = os.Symlink(
+ filepath.Clean(path+"/../../config-example.yaml"),
+ filepath.Join(tmpDir, "config.yaml"),
+ )
if err != nil {
c.Fatal(err)
}
@@ -94,7 +99,7 @@ func (*Suite) TestDNSConfigLoading(c *check.C) {
func writeConfig(c *check.C, tmpDir string, configYaml []byte) {
// Populate a custom config file
configFile := filepath.Join(tmpDir, "config.yaml")
- err := ioutil.WriteFile(configFile, configYaml, 0o644)
+ err := ioutil.WriteFile(configFile, configYaml, 0o600)
if err != nil {
c.Fatalf("Couldn't write file %s", configFile)
}
@@ -106,7 +111,6 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
c.Fatal(err)
}
// defer os.RemoveAll(tmpDir)
- fmt.Println(tmpDir)
configYaml := []byte(
"---\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"\"\ntls_cert_path: \"abc.pem\"",
@@ -128,8 +132,11 @@ func (*Suite) TestTLSConfigValidation(c *check.C) {
check.Matches,
".*Fatal config error: the only supported values for tls_letsencrypt_challenge_type are.*",
)
- c.Assert(tmp, check.Matches, ".*Fatal config error: server_url must start with https:// or http://.*")
- fmt.Println(tmp)
+ c.Assert(
+ tmp,
+ check.Matches,
+ ".*Fatal config error: server_url must start with https:// or http://.*",
+ )
// Check configuration validation errors (2)
configYaml = []byte(
diff --git a/db.go b/db.go
index 42c5eee..5136325 100644
--- a/db.go
+++ b/db.go
@@ -9,7 +9,10 @@ import (
"gorm.io/gorm/logger"
)
-const dbVersion = "1"
+const (
+ dbVersion = "1"
+ errValueNotFound = Error("not found")
+)
// KV is a key-value store in a psql table. For future use...
type KV struct {
@@ -24,7 +27,7 @@ func (h *Headscale) initDB() error {
}
h.db = db
- if h.dbType == "postgres" {
+ if h.dbType == Postgres {
db.Exec("create extension if not exists \"uuid-ossp\";")
}
err = db.AutoMigrate(&Machine{})
@@ -50,6 +53,7 @@ func (h *Headscale) initDB() error {
}
err = h.setValue("db_version", dbVersion)
+
return err
}
@@ -65,12 +69,12 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
}
switch h.dbType {
- case "sqlite3":
+ case Sqlite:
db, err = gorm.Open(sqlite.Open(h.dbString), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: log,
})
- case "postgres":
+ case Postgres:
db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: log,
@@ -84,28 +88,33 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
return db, nil
}
-// getValue returns the value for the given key in KV
+// getValue returns the value for the given key in KV.
func (h *Headscale) getValue(key string) (string, error) {
var row KV
- if result := h.db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) {
- return "", errors.New("not found")
+ if result := h.db.First(&row, "key = ?", key); errors.Is(
+ result.Error,
+ gorm.ErrRecordNotFound,
+ ) {
+ return "", errValueNotFound
}
+
return row.Value, nil
}
-// setValue sets value for the given key in KV
+// setValue sets value for the given key in KV.
func (h *Headscale) setValue(key string, value string) error {
- kv := KV{
+ keyValue := KV{
Key: key,
Value: value,
}
- _, err := h.getValue(key)
- if err == nil {
- h.db.Model(&kv).Where("key = ?", key).Update("value", value)
+ if _, err := h.getValue(key); err == nil {
+ h.db.Model(&keyValue).Where("key = ?", key).Update("value", value)
+
return nil
}
- h.db.Create(kv)
+ h.db.Create(keyValue)
+
return nil
}
diff --git a/derp-example.yaml b/derp-example.yaml
index bbf7cc8..a9901be 100644
--- a/derp-example.yaml
+++ b/derp-example.yaml
@@ -1,15 +1,15 @@
# If you plan to somehow use headscale, please deploy your own DERP infra: https://tailscale.com/kb/1118/custom-derp-servers/
-regions:
+regions:
900:
regionid: 900
regioncode: custom
regionname: My Region
nodes:
- - name: 1a
- regionid: 1
- hostname: myderp.mydomain.no
- ipv4: 123.123.123.123
- ipv6: "2604:a880:400:d1::828:b001"
- stunport: 0
- stunonly: false
- derptestport: 0
+ - name: 1a
+ regionid: 1
+ hostname: myderp.mydomain.no
+ ipv4: 123.123.123.123
+ ipv6: "2604:a880:400:d1::828:b001"
+ stunport: 0
+ stunonly: false
+ derptestport: 0
diff --git a/derp.go b/derp.go
index 7f65832..63e448d 100644
--- a/derp.go
+++ b/derp.go
@@ -1,6 +1,7 @@
package headscale
import (
+ "context"
"encoding/json"
"io"
"io/ioutil"
@@ -10,9 +11,7 @@ import (
"time"
"github.com/rs/zerolog/log"
-
"gopkg.in/yaml.v2"
-
"tailscale.com/tailcfg"
)
@@ -28,14 +27,24 @@ func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) {
return nil, err
}
err = yaml.Unmarshal(b, &derpMap)
+
return &derpMap, err
}
func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
- client := http.Client{
- Timeout: 10 * time.Second,
+ ctx, cancel := context.WithTimeout(context.Background(), HTTPReadTimeout)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(ctx, "GET", addr.String(), nil)
+ if err != nil {
+ return nil, err
}
- resp, err := client.Get(addr.String())
+
+ client := http.Client{
+ Timeout: HTTPReadTimeout,
+ }
+
+ resp, err := client.Do(req)
if err != nil {
return nil, err
}
@@ -48,6 +57,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
var derpMap tailcfg.DERPMap
err = json.Unmarshal(body, &derpMap)
+
return &derpMap, err
}
@@ -55,7 +65,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
// DERPMap, it will _only_ look at the Regions, an integer.
// If a region exists in two of the given DERPMaps, the region
// form the _last_ DERPMap will be preserved.
-// An empty DERPMap list will result in a DERPMap with no regions
+// An empty DERPMap list will result in a DERPMap with no regions.
func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap {
result := tailcfg.DERPMap{
OmitDefaultRegions: false,
@@ -86,6 +96,7 @@ func GetDERPMap(cfg DERPConfig) *tailcfg.DERPMap {
Str("path", path).
Err(err).
Msg("Could not load DERP map from path")
+
break
}
@@ -104,6 +115,7 @@ func GetDERPMap(cfg DERPConfig) *tailcfg.DERPMap {
Str("url", addr.String()).
Err(err).
Msg("Could not load DERP map from path")
+
break
}
diff --git a/dns.go b/dns.go
index c7ca32a..af6f989 100644
--- a/dns.go
+++ b/dns.go
@@ -10,6 +10,10 @@ import (
"tailscale.com/util/dnsname"
)
+const (
+ ByteSize = 8
+)
+
// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`.
// This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS
// server (listening in 100.100.100.100 udp/53) should be used for.
@@ -30,7 +34,9 @@ import (
// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask).
// This allows us to then calculate the subnets included in the subsequent class block and generate the entries.
-func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) ([]dnsname.FQDN, error) {
+func generateMagicDNSRootDomains(
+ ipPrefix netaddr.IPPrefix,
+) []dnsname.FQDN {
// TODO(juanfont): we are not handing out IPv6 addresses yet
// and in fact this is Tailscale.com's range (note the fd7a:115c:a1e0: range in the fc00::/7 network)
ipv6base := dnsname.FQDN("0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.")
@@ -41,15 +47,15 @@ func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) (
maskBits, _ := netRange.Mask.Size()
// lastOctet is the last IP byte covered by the mask
- lastOctet := maskBits / 8
+ lastOctet := maskBits / ByteSize
// wildcardBits is the number of bits not under the mask in the lastOctet
- wildcardBits := 8 - maskBits%8
+ wildcardBits := ByteSize - maskBits%ByteSize
// min is the value in the lastOctet byte of the IP
// max is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1
min := uint(netRange.IP[lastOctet])
- max := uint((min + 1< config/private.key
cp config.yaml.[sqlite|postgres].example config/config.yaml
-
+
cp derp-example.yaml config/derp.yaml
```
@@ -81,16 +85,19 @@
-p 127.0.0.1:8080:8080 \
headscale/headscale:x.x.x headscale serve
```
+
## Nodes configuration
If you used tailscale.com before in your nodes, make sure you clear the tailscaled data folder
- ```shell
- systemctl stop tailscaled
- rm -fr /var/lib/tailscale
- systemctl start tailscaled
- ```
+```shell
+systemctl stop tailscaled
+rm -fr /var/lib/tailscale
+systemctl start tailscaled
+```
+
### Adding node based on MACHINEKEY
+
1. Add your first machine
```shell
diff --git a/grpcv1.go b/grpcv1.go
index 998f0c0..40419c3 100644
--- a/grpcv1.go
+++ b/grpcv1.go
@@ -155,7 +155,10 @@ func (api headscaleV1APIServer) RegisterMachine(
ctx context.Context,
request *v1.RegisterMachineRequest,
) (*v1.RegisterMachineResponse, error) {
- log.Trace().Str("namespace", request.GetNamespace()).Str("machine_key", request.GetKey()).Msg("Registering machine")
+ log.Trace().
+ Str("namespace", request.GetNamespace()).
+ Str("machine_key", request.GetKey()).
+ Msg("Registering machine")
machine, err := api.h.RegisterMachine(
request.GetKey(),
request.GetNamespace(),
@@ -208,7 +211,9 @@ func (api headscaleV1APIServer) ListMachines(
return nil, err
}
- sharedMachines, err := api.h.ListSharedMachinesInNamespace(request.GetNamespace())
+ sharedMachines, err := api.h.ListSharedMachinesInNamespace(
+ request.GetNamespace(),
+ )
if err != nil {
return nil, err
}
@@ -333,12 +338,16 @@ func (api headscaleV1APIServer) DebugCreateMachine(
return nil, err
}
- routes, err := stringToIpPrefix(request.GetRoutes())
+ routes, err := stringToIPPrefix(request.GetRoutes())
if err != nil {
return nil, err
}
- log.Trace().Caller().Interface("route-prefix", routes).Interface("route-str", request.GetRoutes()).Msg("")
+ log.Trace().
+ Caller().
+ Interface("route-prefix", routes).
+ Interface("route-str", request.GetRoutes()).
+ Msg("")
hostinfo := tailcfg.Hostinfo{
RoutableIPs: routes,
diff --git a/integration_cli_test.go b/integration_cli_test.go
index e1b1672..898e2cd 100644
--- a/integration_cli_test.go
+++ b/integration_cli_test.go
@@ -88,6 +88,7 @@ func (s *IntegrationCLITestSuite) SetupTest() {
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("status code not OK")
}
+
return nil
}); err != nil {
// TODO(kradalby): If we cannot access headscale, or any other fatal error during
@@ -109,7 +110,10 @@ func (s *IntegrationCLITestSuite) TearDownTest() {
}
}
-func (s *IntegrationCLITestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) {
+func (s *IntegrationCLITestSuite) HandleStats(
+ suiteName string,
+ stats *suite.SuiteInformation,
+) {
s.stats = stats
}
@@ -144,7 +148,6 @@ func (s *IntegrationCLITestSuite) TestNamespaceCommand() {
namespaces := make([]*v1.Namespace, len(names))
for index, namespaceName := range names {
-
namespace, err := s.createNamespace(namespaceName)
assert.Nil(s.T(), err)
@@ -298,11 +301,26 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().After(time.Now()))
assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().After(time.Now()))
- assert.True(s.T(), listedPreAuthKeys[0].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
- assert.True(s.T(), listedPreAuthKeys[1].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
- assert.True(s.T(), listedPreAuthKeys[2].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
- assert.True(s.T(), listedPreAuthKeys[3].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
- assert.True(s.T(), listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)))
+ assert.True(
+ s.T(),
+ listedPreAuthKeys[0].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
+ )
+ assert.True(
+ s.T(),
+ listedPreAuthKeys[1].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
+ )
+ assert.True(
+ s.T(),
+ listedPreAuthKeys[2].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
+ )
+ assert.True(
+ s.T(),
+ listedPreAuthKeys[3].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
+ )
+ assert.True(
+ s.T(),
+ listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)),
+ )
// Expire three keys
for i := 0; i < 3; i++ {
@@ -341,11 +359,26 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() {
err = json.Unmarshal([]byte(listAfterExpireResult), &listedAfterExpirePreAuthKeys)
assert.Nil(s.T(), err)
- assert.True(s.T(), listedAfterExpirePreAuthKeys[0].Expiration.AsTime().Before(time.Now()))
- assert.True(s.T(), listedAfterExpirePreAuthKeys[1].Expiration.AsTime().Before(time.Now()))
- assert.True(s.T(), listedAfterExpirePreAuthKeys[2].Expiration.AsTime().Before(time.Now()))
- assert.True(s.T(), listedAfterExpirePreAuthKeys[3].Expiration.AsTime().After(time.Now()))
- assert.True(s.T(), listedAfterExpirePreAuthKeys[4].Expiration.AsTime().After(time.Now()))
+ assert.True(
+ s.T(),
+ listedAfterExpirePreAuthKeys[0].Expiration.AsTime().Before(time.Now()),
+ )
+ assert.True(
+ s.T(),
+ listedAfterExpirePreAuthKeys[1].Expiration.AsTime().Before(time.Now()),
+ )
+ assert.True(
+ s.T(),
+ listedAfterExpirePreAuthKeys[2].Expiration.AsTime().Before(time.Now()),
+ )
+ assert.True(
+ s.T(),
+ listedAfterExpirePreAuthKeys[3].Expiration.AsTime().After(time.Now()),
+ )
+ assert.True(
+ s.T(),
+ listedAfterExpirePreAuthKeys[4].Expiration.AsTime().After(time.Now()),
+ )
}
func (s *IntegrationCLITestSuite) TestPreAuthKeyCommandWithoutExpiry() {
@@ -689,7 +722,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Nil(s.T(), err)
var listOnlySharedMachineNamespace []v1.Machine
- err = json.Unmarshal([]byte(listOnlySharedMachineNamespaceResult), &listOnlySharedMachineNamespace)
+ err = json.Unmarshal(
+ []byte(listOnlySharedMachineNamespaceResult),
+ &listOnlySharedMachineNamespace,
+ )
assert.Nil(s.T(), err)
assert.Len(s.T(), listOnlySharedMachineNamespace, 2)
@@ -738,7 +774,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Nil(s.T(), err)
var listOnlyMachineNamespaceAfterDelete []v1.Machine
- err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterDeleteResult), &listOnlyMachineNamespaceAfterDelete)
+ err = json.Unmarshal(
+ []byte(listOnlyMachineNamespaceAfterDeleteResult),
+ &listOnlyMachineNamespaceAfterDelete,
+ )
assert.Nil(s.T(), err)
assert.Len(s.T(), listOnlyMachineNamespaceAfterDelete, 4)
@@ -789,7 +828,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Nil(s.T(), err)
var listOnlyMachineNamespaceAfterShare []v1.Machine
- err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterShareResult), &listOnlyMachineNamespaceAfterShare)
+ err = json.Unmarshal(
+ []byte(listOnlyMachineNamespaceAfterShareResult),
+ &listOnlyMachineNamespaceAfterShare,
+ )
assert.Nil(s.T(), err)
assert.Len(s.T(), listOnlyMachineNamespaceAfterShare, 5)
@@ -846,7 +888,10 @@ func (s *IntegrationCLITestSuite) TestNodeCommand() {
assert.Nil(s.T(), err)
var listOnlyMachineNamespaceAfterUnshare []v1.Machine
- err = json.Unmarshal([]byte(listOnlyMachineNamespaceAfterUnshareResult), &listOnlyMachineNamespaceAfterUnshare)
+ err = json.Unmarshal(
+ []byte(listOnlyMachineNamespaceAfterUnshareResult),
+ &listOnlyMachineNamespaceAfterUnshare,
+ )
assert.Nil(s.T(), err)
assert.Len(s.T(), listOnlyMachineNamespaceAfterUnshare, 4)
@@ -1010,5 +1055,9 @@ func (s *IntegrationCLITestSuite) TestRouteCommand() {
)
assert.Nil(s.T(), err)
- assert.Contains(s.T(), string(failEnableNonAdvertisedRoute), "route (route-machine) is not available on node")
+ assert.Contains(
+ s.T(),
+ string(failEnableNonAdvertisedRoute),
+ "route (route-machine) is not available on node",
+ )
}
diff --git a/integration_common_test.go b/integration_common_test.go
index 71d4866..31bae51 100644
--- a/integration_common_test.go
+++ b/integration_common_test.go
@@ -12,12 +12,18 @@ import (
"github.com/ory/dockertest/v3/docker"
)
-func ExecuteCommand(resource *dockertest.Resource, cmd []string, env []string) (string, error) {
+const DOCKER_EXECUTE_TIMEOUT = 10 * time.Second
+
+func ExecuteCommand(
+ resource *dockertest.Resource,
+ cmd []string,
+ env []string,
+) (string, error) {
var stdout bytes.Buffer
var stderr bytes.Buffer
// TODO(kradalby): Make configurable
- timeout := 10 * time.Second
+ timeout := DOCKER_EXECUTE_TIMEOUT
type result struct {
exitCode int
@@ -51,11 +57,13 @@ func ExecuteCommand(resource *dockertest.Resource, cmd []string, env []string) (
fmt.Println("Command: ", cmd)
fmt.Println("stdout: ", stdout.String())
fmt.Println("stderr: ", stderr.String())
+
return "", fmt.Errorf("command failed with: %s", stderr.String())
}
return stdout.String(), nil
case <-time.After(timeout):
+
return "", fmt.Errorf("command timed out after %s", timeout)
}
}
diff --git a/integration_test.go b/integration_test.go
index c4f6bb4..a9125a0 100644
--- a/integration_test.go
+++ b/integration_test.go
@@ -23,10 +23,9 @@ import (
"github.com/ory/dockertest/v3/docker"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
+ "inet.af/netaddr"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/ipn/ipnstate"
-
- "inet.af/netaddr"
)
var tailscaleVersions = []string{"1.16.2", "1.14.3", "1.12.3"}
@@ -89,7 +88,10 @@ func TestIntegrationTestSuite(t *testing.T) {
}
}
-func (s *IntegrationTestSuite) saveLog(resource *dockertest.Resource, basePath string) error {
+func (s *IntegrationTestSuite) saveLog(
+ resource *dockertest.Resource,
+ basePath string,
+) error {
err := os.MkdirAll(basePath, os.ModePerm)
if err != nil {
return err
@@ -118,12 +120,20 @@ func (s *IntegrationTestSuite) saveLog(resource *dockertest.Resource, basePath s
fmt.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
- err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stdout.log"), []byte(stdout.String()), 0o644)
+ err = ioutil.WriteFile(
+ path.Join(basePath, resource.Container.Name+".stdout.log"),
+ []byte(stdout.String()),
+ 0o644,
+ )
if err != nil {
return err
}
- err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stderr.log"), []byte(stdout.String()), 0o644)
+ err = ioutil.WriteFile(
+ path.Join(basePath, resource.Container.Name+".stderr.log"),
+ []byte(stdout.String()),
+ 0o644,
+ )
if err != nil {
return err
}
@@ -144,24 +154,38 @@ func (s *IntegrationTestSuite) tailscaleContainer(
},
},
}
- hostname := fmt.Sprintf("%s-tailscale-%s-%s", namespace, strings.Replace(version, ".", "-", -1), identifier)
+ hostname := fmt.Sprintf(
+ "%s-tailscale-%s-%s",
+ namespace,
+ strings.Replace(version, ".", "-", -1),
+ identifier,
+ )
tailscaleOptions := &dockertest.RunOptions{
Name: hostname,
Networks: []*dockertest.Network{&s.network},
- Cmd: []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"},
+ Cmd: []string{
+ "tailscaled",
+ "--tun=userspace-networking",
+ "--socks5-server=localhost:1055",
+ },
}
- pts, err := s.pool.BuildAndRunWithBuildOptions(tailscaleBuildOptions, tailscaleOptions, DockerRestartPolicy)
+ pts, err := s.pool.BuildAndRunWithBuildOptions(
+ tailscaleBuildOptions,
+ tailscaleOptions,
+ DockerRestartPolicy,
+ )
if err != nil {
log.Fatalf("Could not start resource: %s", err)
}
fmt.Printf("Created %s container\n", hostname)
+
return hostname, pts
}
func (s *IntegrationTestSuite) SetupSuite() {
var err error
- h = Headscale{
+ app = Headscale{
dbType: "sqlite3",
dbString: "integration_test_db.sqlite3",
}
@@ -210,7 +234,11 @@ func (s *IntegrationTestSuite) SetupSuite() {
for i := 0; i < scales.count; i++ {
version := tailscaleVersions[i%len(tailscaleVersions)]
- hostname, container := s.tailscaleContainer(namespace, fmt.Sprint(i), version)
+ hostname, container := s.tailscaleContainer(
+ namespace,
+ fmt.Sprint(i),
+ version,
+ )
scales.tailscales[hostname] = *container
}
}
@@ -220,13 +248,16 @@ func (s *IntegrationTestSuite) SetupSuite() {
if err := s.pool.Retry(func() error {
url := fmt.Sprintf("http://%s/health", hostEndpoint)
+
resp, err := http.Get(url)
if err != nil {
return err
}
+
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("status code not OK")
}
+
return nil
}); err != nil {
// TODO(kradalby): If we cannot access headscale, or any other fatal error during
@@ -273,7 +304,10 @@ func (s *IntegrationTestSuite) SetupSuite() {
headscaleEndpoint := "http://headscale:8080"
- fmt.Printf("Joining tailscale containers to headscale at %s\n", headscaleEndpoint)
+ fmt.Printf(
+ "Joining tailscale containers to headscale at %s\n",
+ headscaleEndpoint,
+ )
for hostname, tailscale := range scales.tailscales {
command := []string{
"tailscale",
@@ -307,7 +341,10 @@ func (s *IntegrationTestSuite) SetupSuite() {
func (s *IntegrationTestSuite) TearDownSuite() {
}
-func (s *IntegrationTestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) {
+func (s *IntegrationTestSuite) HandleStats(
+ suiteName string,
+ stats *suite.SuiteInformation,
+) {
s.stats = stats
}
@@ -427,7 +464,13 @@ func (s *IntegrationTestSuite) TestPingAllPeers() {
ip.String(),
}
- fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
+ fmt.Printf(
+ "Pinging from %s (%s) to %s (%s)\n",
+ hostname,
+ ips[hostname],
+ peername,
+ ip,
+ )
result, err := ExecuteCommand(
&tailscale,
command,
@@ -449,7 +492,15 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
result, err := ExecuteCommand(
&s.headscale,
- []string{"headscale", "nodes", "list", "--output", "json", "--namespace", "shared"},
+ []string{
+ "headscale",
+ "nodes",
+ "list",
+ "--output",
+ "json",
+ "--namespace",
+ "shared",
+ },
[]string{},
)
assert.Nil(s.T(), err)
@@ -459,7 +510,6 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
assert.Nil(s.T(), err)
for _, machine := range machineList {
-
result, err := ExecuteCommand(
&s.headscale,
[]string{
@@ -520,7 +570,13 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
ip.String(),
}
- fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, mainIps[hostname], peername, ip)
+ fmt.Printf(
+ "Pinging from %s (%s) to %s (%s)\n",
+ hostname,
+ mainIps[hostname],
+ peername,
+ ip,
+ )
result, err := ExecuteCommand(
&tailscale,
command,
@@ -553,7 +609,6 @@ func (s *IntegrationTestSuite) TestTailDrop() {
for peername, ip := range ips {
s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
if peername != hostname {
-
// Under normal circumstances, we should be able to send a file
// using `tailscale file cp` - but not in userspace networking mode
// So curl!
@@ -578,9 +633,19 @@ func (s *IntegrationTestSuite) TestTailDrop() {
"PUT",
"--upload-file",
fmt.Sprintf("/tmp/file_from_%s", hostname),
- fmt.Sprintf("%s/v0/put/file_from_%s", peerAPI, hostname),
+ fmt.Sprintf(
+ "%s/v0/put/file_from_%s",
+ peerAPI,
+ hostname,
+ ),
}
- fmt.Printf("Sending file from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
+ fmt.Printf(
+ "Sending file from %s (%s) to %s (%s)\n",
+ hostname,
+ ips[hostname],
+ peername,
+ ip,
+ )
_, err = ExecuteCommand(
&tailscale,
command,
@@ -621,7 +686,13 @@ func (s *IntegrationTestSuite) TestTailDrop() {
"ls",
fmt.Sprintf("/tmp/file_from_%s", peername),
}
- fmt.Printf("Checking file in %s (%s) from %s (%s)\n", hostname, ips[hostname], peername, ip)
+ fmt.Printf(
+ "Checking file in %s (%s) from %s (%s)\n",
+ hostname,
+ ips[hostname],
+ peername,
+ ip,
+ )
result, err := ExecuteCommand(
&tailscale,
command,
@@ -629,7 +700,11 @@ func (s *IntegrationTestSuite) TestTailDrop() {
)
assert.Nil(t, err)
fmt.Printf("Result for %s: %s\n", peername, result)
- assert.Equal(t, result, fmt.Sprintf("/tmp/file_from_%s\n", peername))
+ assert.Equal(
+ t,
+ result,
+ fmt.Sprintf("/tmp/file_from_%s\n", peername),
+ )
}
})
}
@@ -696,10 +771,13 @@ func getIPs(tailscales map[string]dockertest.Resource) (map[string]netaddr.IP, e
ips[hostname] = ip
}
+
return ips, nil
}
-func getAPIURLs(tailscales map[string]dockertest.Resource) (map[netaddr.IP]string, error) {
+func getAPIURLs(
+ tailscales map[string]dockertest.Resource,
+) (map[netaddr.IP]string, error) {
fts := make(map[netaddr.IP]string)
for _, tailscale := range tailscales {
command := []string{
@@ -733,5 +811,6 @@ func getAPIURLs(tailscales map[string]dockertest.Resource) (map[netaddr.IP]strin
}
}
}
+
return fts, nil
}
diff --git a/k8s/README.md b/k8s/README.md
index 45574b4..78e9ef2 100644
--- a/k8s/README.md
+++ b/k8s/README.md
@@ -24,6 +24,7 @@ Configure DERP servers by editing `base/site/derp.yaml` if needed.
You'll somehow need to get `headscale:latest` into your cluster image registry.
An easy way to do this with k3s:
+
- Reconfigure k3s to use docker instead of containerd (`k3s server --docker`)
- `docker build -t headscale:latest ..` from here
@@ -61,7 +62,7 @@ Use the wrapper script to remotely operate headscale to perform administrative
tasks like creating namespaces, authkeys, etc.
```
-[c@nix-slate:~/Projects/headscale/k8s]$ ./headscale.bash
+[c@nix-slate:~/Projects/headscale/k8s]$ ./headscale.bash
headscale is an open source implementation of the Tailscale control server
diff --git a/k8s/base/ingress.yaml b/k8s/base/ingress.yaml
index a279bc1..51da342 100644
--- a/k8s/base/ingress.yaml
+++ b/k8s/base/ingress.yaml
@@ -6,13 +6,13 @@ metadata:
kubernetes.io/ingress.class: traefik
spec:
rules:
- - host: $(PUBLIC_HOSTNAME)
- http:
- paths:
- - backend:
- service:
- name: headscale
- port:
- number: 8080
- path: /
- pathType: Prefix
+ - host: $(PUBLIC_HOSTNAME)
+ http:
+ paths:
+ - backend:
+ service:
+ name: headscale
+ port:
+ number: 8080
+ path: /
+ pathType: Prefix
diff --git a/k8s/base/kustomization.yaml b/k8s/base/kustomization.yaml
index 54d66e5..93278f7 100644
--- a/k8s/base/kustomization.yaml
+++ b/k8s/base/kustomization.yaml
@@ -1,42 +1,42 @@
namespace: headscale
resources:
-- configmap.yaml
-- ingress.yaml
-- service.yaml
+ - configmap.yaml
+ - ingress.yaml
+ - service.yaml
generatorOptions:
disableNameSuffixHash: true
configMapGenerator:
-- name: headscale-site
- files:
- - derp.yaml=site/derp.yaml
- envs:
- - site/public.env
-- name: headscale-etc
- literals:
- - config.json={}
+ - name: headscale-site
+ files:
+ - derp.yaml=site/derp.yaml
+ envs:
+ - site/public.env
+ - name: headscale-etc
+ literals:
+ - config.json={}
secretGenerator:
-- name: headscale
- files:
- - secrets/private-key
+ - name: headscale
+ files:
+ - secrets/private-key
vars:
-- name: PUBLIC_PROTO
- objRef:
- kind: ConfigMap
- name: headscale-site
- apiVersion: v1
- fieldRef:
- fieldPath: data.public-proto
-- name: PUBLIC_HOSTNAME
- objRef:
- kind: ConfigMap
- name: headscale-site
- apiVersion: v1
- fieldRef:
- fieldPath: data.public-hostname
-- name: CONTACT_EMAIL
- objRef:
- kind: ConfigMap
- name: headscale-site
- apiVersion: v1
- fieldRef:
- fieldPath: data.contact-email
+ - name: PUBLIC_PROTO
+ objRef:
+ kind: ConfigMap
+ name: headscale-site
+ apiVersion: v1
+ fieldRef:
+ fieldPath: data.public-proto
+ - name: PUBLIC_HOSTNAME
+ objRef:
+ kind: ConfigMap
+ name: headscale-site
+ apiVersion: v1
+ fieldRef:
+ fieldPath: data.public-hostname
+ - name: CONTACT_EMAIL
+ objRef:
+ kind: ConfigMap
+ name: headscale-site
+ apiVersion: v1
+ fieldRef:
+ fieldPath: data.contact-email
diff --git a/k8s/base/service.yaml b/k8s/base/service.yaml
index 7fdf738..39e6725 100644
--- a/k8s/base/service.yaml
+++ b/k8s/base/service.yaml
@@ -8,6 +8,6 @@ spec:
selector:
app: headscale
ports:
- - name: http
- targetPort: http
- port: 8080
+ - name: http
+ targetPort: http
+ port: 8080
diff --git a/k8s/postgres/deployment.yaml b/k8s/postgres/deployment.yaml
index dd45d05..661d87e 100644
--- a/k8s/postgres/deployment.yaml
+++ b/k8s/postgres/deployment.yaml
@@ -13,66 +13,66 @@ spec:
app: headscale
spec:
containers:
- - name: headscale
- image: "headscale:latest"
- imagePullPolicy: IfNotPresent
- command: ["/go/bin/headscale", "serve"]
- env:
- - name: SERVER_URL
- value: $(PUBLIC_PROTO)://$(PUBLIC_HOSTNAME)
- - name: LISTEN_ADDR
- valueFrom:
- configMapKeyRef:
- name: headscale-config
- key: listen_addr
- - name: PRIVATE_KEY_PATH
- value: /vol/secret/private-key
- - name: DERP_MAP_PATH
- value: /vol/config/derp.yaml
- - name: EPHEMERAL_NODE_INACTIVITY_TIMEOUT
- valueFrom:
- configMapKeyRef:
- name: headscale-config
- key: ephemeral_node_inactivity_timeout
- - name: DB_TYPE
- value: postgres
- - name: DB_HOST
- value: postgres.headscale.svc.cluster.local
- - name: DB_PORT
- value: "5432"
- - name: DB_USER
- value: headscale
- - name: DB_PASS
- valueFrom:
- secretKeyRef:
- name: postgresql
- key: password
- - name: DB_NAME
- value: headscale
- ports:
- - name: http
- protocol: TCP
- containerPort: 8080
- livenessProbe:
- tcpSocket:
- port: http
- initialDelaySeconds: 30
- timeoutSeconds: 5
- periodSeconds: 15
- volumeMounts:
- - name: config
- mountPath: /vol/config
- - name: secret
- mountPath: /vol/secret
- - name: etc
- mountPath: /etc/headscale
+ - name: headscale
+ image: "headscale:latest"
+ imagePullPolicy: IfNotPresent
+ command: ["/go/bin/headscale", "serve"]
+ env:
+ - name: SERVER_URL
+ value: $(PUBLIC_PROTO)://$(PUBLIC_HOSTNAME)
+ - name: LISTEN_ADDR
+ valueFrom:
+ configMapKeyRef:
+ name: headscale-config
+ key: listen_addr
+ - name: PRIVATE_KEY_PATH
+ value: /vol/secret/private-key
+ - name: DERP_MAP_PATH
+ value: /vol/config/derp.yaml
+ - name: EPHEMERAL_NODE_INACTIVITY_TIMEOUT
+ valueFrom:
+ configMapKeyRef:
+ name: headscale-config
+ key: ephemeral_node_inactivity_timeout
+ - name: DB_TYPE
+ value: postgres
+ - name: DB_HOST
+ value: postgres.headscale.svc.cluster.local
+ - name: DB_PORT
+ value: "5432"
+ - name: DB_USER
+ value: headscale
+ - name: DB_PASS
+ valueFrom:
+ secretKeyRef:
+ name: postgresql
+ key: password
+ - name: DB_NAME
+ value: headscale
+ ports:
+ - name: http
+ protocol: TCP
+ containerPort: 8080
+ livenessProbe:
+ tcpSocket:
+ port: http
+ initialDelaySeconds: 30
+ timeoutSeconds: 5
+ periodSeconds: 15
+ volumeMounts:
+ - name: config
+ mountPath: /vol/config
+ - name: secret
+ mountPath: /vol/secret
+ - name: etc
+ mountPath: /etc/headscale
volumes:
- - name: config
- configMap:
- name: headscale-site
- - name: etc
- configMap:
- name: headscale-etc
- - name: secret
- secret:
- secretName: headscale
+ - name: config
+ configMap:
+ name: headscale-site
+ - name: etc
+ configMap:
+ name: headscale-etc
+ - name: secret
+ secret:
+ secretName: headscale
diff --git a/k8s/postgres/kustomization.yaml b/k8s/postgres/kustomization.yaml
index 8bd6c40..e732e3b 100644
--- a/k8s/postgres/kustomization.yaml
+++ b/k8s/postgres/kustomization.yaml
@@ -1,13 +1,13 @@
namespace: headscale
bases:
-- ../base
+ - ../base
resources:
-- deployment.yaml
-- postgres-service.yaml
-- postgres-statefulset.yaml
+ - deployment.yaml
+ - postgres-service.yaml
+ - postgres-statefulset.yaml
generatorOptions:
disableNameSuffixHash: true
secretGenerator:
-- name: postgresql
- files:
- - secrets/password
+ - name: postgresql
+ files:
+ - secrets/password
diff --git a/k8s/postgres/postgres-service.yaml b/k8s/postgres/postgres-service.yaml
index e2f486c..6252e7f 100644
--- a/k8s/postgres/postgres-service.yaml
+++ b/k8s/postgres/postgres-service.yaml
@@ -8,6 +8,6 @@ spec:
selector:
app: postgres
ports:
- - name: postgres
- targetPort: postgres
- port: 5432
+ - name: postgres
+ targetPort: postgres
+ port: 5432
diff --git a/k8s/postgres/postgres-statefulset.yaml b/k8s/postgres/postgres-statefulset.yaml
index 25285c5..b81c9bf 100644
--- a/k8s/postgres/postgres-statefulset.yaml
+++ b/k8s/postgres/postgres-statefulset.yaml
@@ -14,36 +14,36 @@ spec:
app: postgres
spec:
containers:
- - name: postgres
- image: "postgres:13"
- imagePullPolicy: IfNotPresent
- env:
- - name: POSTGRES_PASSWORD
- valueFrom:
- secretKeyRef:
- name: postgresql
- key: password
- - name: POSTGRES_USER
- value: headscale
- ports:
- name: postgres
- protocol: TCP
- containerPort: 5432
- livenessProbe:
- tcpSocket:
- port: 5432
- initialDelaySeconds: 30
- timeoutSeconds: 5
- periodSeconds: 15
- volumeMounts:
- - name: pgdata
- mountPath: /var/lib/postgresql/data
+ image: "postgres:13"
+ imagePullPolicy: IfNotPresent
+ env:
+ - name: POSTGRES_PASSWORD
+ valueFrom:
+ secretKeyRef:
+ name: postgresql
+ key: password
+ - name: POSTGRES_USER
+ value: headscale
+ ports:
+ - name: postgres
+ protocol: TCP
+ containerPort: 5432
+ livenessProbe:
+ tcpSocket:
+ port: 5432
+ initialDelaySeconds: 30
+ timeoutSeconds: 5
+ periodSeconds: 15
+ volumeMounts:
+ - name: pgdata
+ mountPath: /var/lib/postgresql/data
volumeClaimTemplates:
- - metadata:
- name: pgdata
- spec:
- storageClassName: local-path
- accessModes: ["ReadWriteOnce"]
- resources:
- requests:
- storage: 1Gi
+ - metadata:
+ name: pgdata
+ spec:
+ storageClassName: local-path
+ accessModes: ["ReadWriteOnce"]
+ resources:
+ requests:
+ storage: 1Gi
diff --git a/k8s/production-tls/ingress-patch.yaml b/k8s/production-tls/ingress-patch.yaml
index 387c736..9e6177f 100644
--- a/k8s/production-tls/ingress-patch.yaml
+++ b/k8s/production-tls/ingress-patch.yaml
@@ -6,6 +6,6 @@ metadata:
traefik.ingress.kubernetes.io/router.tls: "true"
spec:
tls:
- - hosts:
- - $(PUBLIC_HOSTNAME)
- secretName: production-cert
+ - hosts:
+ - $(PUBLIC_HOSTNAME)
+ secretName: production-cert
diff --git a/k8s/production-tls/kustomization.yaml b/k8s/production-tls/kustomization.yaml
index f57cb54..d3147f5 100644
--- a/k8s/production-tls/kustomization.yaml
+++ b/k8s/production-tls/kustomization.yaml
@@ -1,9 +1,9 @@
namespace: headscale
bases:
-- ../base
+ - ../base
resources:
-- production-issuer.yaml
+ - production-issuer.yaml
patches:
-- path: ingress-patch.yaml
- target:
- kind: Ingress
+ - path: ingress-patch.yaml
+ target:
+ kind: Ingress
diff --git a/k8s/production-tls/production-issuer.yaml b/k8s/production-tls/production-issuer.yaml
index 7ae9131..f436090 100644
--- a/k8s/production-tls/production-issuer.yaml
+++ b/k8s/production-tls/production-issuer.yaml
@@ -11,6 +11,6 @@ spec:
# Secret resource used to store the account's private key.
name: letsencrypt-production-acc-key
solvers:
- - http01:
- ingress:
- class: traefik
+ - http01:
+ ingress:
+ class: traefik
diff --git a/k8s/sqlite/kustomization.yaml b/k8s/sqlite/kustomization.yaml
index 5be451c..ca79941 100644
--- a/k8s/sqlite/kustomization.yaml
+++ b/k8s/sqlite/kustomization.yaml
@@ -1,5 +1,5 @@
namespace: headscale
bases:
-- ../base
+ - ../base
resources:
-- statefulset.yaml
+ - statefulset.yaml
diff --git a/k8s/sqlite/statefulset.yaml b/k8s/sqlite/statefulset.yaml
index 9075e00..71077da 100644
--- a/k8s/sqlite/statefulset.yaml
+++ b/k8s/sqlite/statefulset.yaml
@@ -14,66 +14,66 @@ spec:
app: headscale
spec:
containers:
- - name: headscale
- image: "headscale:latest"
- imagePullPolicy: IfNotPresent
- command: ["/go/bin/headscale", "serve"]
- env:
- - name: SERVER_URL
- value: $(PUBLIC_PROTO)://$(PUBLIC_HOSTNAME)
- - name: LISTEN_ADDR
- valueFrom:
- configMapKeyRef:
- name: headscale-config
- key: listen_addr
- - name: PRIVATE_KEY_PATH
- value: /vol/secret/private-key
- - name: DERP_MAP_PATH
- value: /vol/config/derp.yaml
- - name: EPHEMERAL_NODE_INACTIVITY_TIMEOUT
- valueFrom:
- configMapKeyRef:
- name: headscale-config
- key: ephemeral_node_inactivity_timeout
- - name: DB_TYPE
- value: sqlite3
- - name: DB_PATH
- value: /vol/data/db.sqlite
- ports:
- - name: http
- protocol: TCP
- containerPort: 8080
- livenessProbe:
- tcpSocket:
- port: http
- initialDelaySeconds: 30
- timeoutSeconds: 5
- periodSeconds: 15
- volumeMounts:
- - name: config
- mountPath: /vol/config
- - name: data
- mountPath: /vol/data
- - name: secret
- mountPath: /vol/secret
- - name: etc
- mountPath: /etc/headscale
+ - name: headscale
+ image: "headscale:latest"
+ imagePullPolicy: IfNotPresent
+ command: ["/go/bin/headscale", "serve"]
+ env:
+ - name: SERVER_URL
+ value: $(PUBLIC_PROTO)://$(PUBLIC_HOSTNAME)
+ - name: LISTEN_ADDR
+ valueFrom:
+ configMapKeyRef:
+ name: headscale-config
+ key: listen_addr
+ - name: PRIVATE_KEY_PATH
+ value: /vol/secret/private-key
+ - name: DERP_MAP_PATH
+ value: /vol/config/derp.yaml
+ - name: EPHEMERAL_NODE_INACTIVITY_TIMEOUT
+ valueFrom:
+ configMapKeyRef:
+ name: headscale-config
+ key: ephemeral_node_inactivity_timeout
+ - name: DB_TYPE
+ value: sqlite3
+ - name: DB_PATH
+ value: /vol/data/db.sqlite
+ ports:
+ - name: http
+ protocol: TCP
+ containerPort: 8080
+ livenessProbe:
+ tcpSocket:
+ port: http
+ initialDelaySeconds: 30
+ timeoutSeconds: 5
+ periodSeconds: 15
+ volumeMounts:
+ - name: config
+ mountPath: /vol/config
+ - name: data
+ mountPath: /vol/data
+ - name: secret
+ mountPath: /vol/secret
+ - name: etc
+ mountPath: /etc/headscale
volumes:
- - name: config
- configMap:
- name: headscale-site
- - name: etc
- configMap:
- name: headscale-etc
- - name: secret
- secret:
- secretName: headscale
+ - name: config
+ configMap:
+ name: headscale-site
+ - name: etc
+ configMap:
+ name: headscale-etc
+ - name: secret
+ secret:
+ secretName: headscale
volumeClaimTemplates:
- - metadata:
- name: data
- spec:
- storageClassName: local-path
- accessModes: ["ReadWriteOnce"]
- resources:
- requests:
- storage: 1Gi
+ - metadata:
+ name: data
+ spec:
+ storageClassName: local-path
+ accessModes: ["ReadWriteOnce"]
+ resources:
+ requests:
+ storage: 1Gi
diff --git a/k8s/staging-tls/ingress-patch.yaml b/k8s/staging-tls/ingress-patch.yaml
index f97974b..5a1daf0 100644
--- a/k8s/staging-tls/ingress-patch.yaml
+++ b/k8s/staging-tls/ingress-patch.yaml
@@ -6,6 +6,6 @@ metadata:
traefik.ingress.kubernetes.io/router.tls: "true"
spec:
tls:
- - hosts:
- - $(PUBLIC_HOSTNAME)
- secretName: staging-cert
+ - hosts:
+ - $(PUBLIC_HOSTNAME)
+ secretName: staging-cert
diff --git a/k8s/staging-tls/kustomization.yaml b/k8s/staging-tls/kustomization.yaml
index 931f27d..0900c58 100644
--- a/k8s/staging-tls/kustomization.yaml
+++ b/k8s/staging-tls/kustomization.yaml
@@ -1,9 +1,9 @@
namespace: headscale
bases:
-- ../base
+ - ../base
resources:
-- staging-issuer.yaml
+ - staging-issuer.yaml
patches:
-- path: ingress-patch.yaml
- target:
- kind: Ingress
+ - path: ingress-patch.yaml
+ target:
+ kind: Ingress
diff --git a/k8s/staging-tls/staging-issuer.yaml b/k8s/staging-tls/staging-issuer.yaml
index 95325f6..cf29041 100644
--- a/k8s/staging-tls/staging-issuer.yaml
+++ b/k8s/staging-tls/staging-issuer.yaml
@@ -11,6 +11,6 @@ spec:
# Secret resource used to store the account's private key.
name: letsencrypt-staging-acc-key
solvers:
- - http01:
- ingress:
- class: traefik
+ - http01:
+ ingress:
+ class: traefik
diff --git a/machine.go b/machine.go
index 557ab5b..dcd7970 100644
--- a/machine.go
+++ b/machine.go
@@ -10,10 +10,9 @@ import (
"time"
"github.com/fatih/set"
+ v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/timestamppb"
-
- v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"gorm.io/datatypes"
"gorm.io/gorm"
"inet.af/netaddr"
@@ -21,7 +20,13 @@ import (
"tailscale.com/types/wgkey"
)
-// Machine is a Headscale client
+const (
+ errMachineNotFound = Error("machine not found")
+ errMachineAlreadyRegistered = Error("machine already registered")
+ errMachineRouteIsNotAvailable = Error("route is not available on machine")
+)
+
+// Machine is a Headscale client.
type Machine struct {
ID uint64 `gorm:"primary_key"`
MachineKey string `gorm:"type:varchar(64);unique_index"`
@@ -56,53 +61,58 @@ type (
MachinesP []*Machine
)
-// For the time being this method is rather naive
-func (m Machine) isAlreadyRegistered() bool {
- return m.Registered
+// For the time being this method is rather naive.
+func (machine Machine) isAlreadyRegistered() bool {
+ return machine.Registered
}
-// isExpired returns whether the machine registration has expired
-func (m Machine) isExpired() bool {
- return time.Now().UTC().After(*m.Expiry)
+// isExpired returns whether the machine registration has expired.
+func (machine Machine) isExpired() bool {
+ return time.Now().UTC().After(*machine.Expiry)
}
// If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration,
// or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause
// a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the
// expiry time.
-func (h *Headscale) updateMachineExpiry(m *Machine) {
- if m.isExpired() {
+func (h *Headscale) updateMachineExpiry(machine *Machine) {
+ if machine.isExpired() {
now := time.Now().UTC()
- maxExpiry := now.Add(h.cfg.MaxMachineRegistrationDuration) // calculate the maximum expiry
- defaultExpiry := now.Add(h.cfg.DefaultMachineRegistrationDuration) // calculate the default expiry
+ maxExpiry := now.Add(
+ h.cfg.MaxMachineRegistrationDuration,
+ ) // calculate the maximum expiry
+ defaultExpiry := now.Add(
+ h.cfg.DefaultMachineRegistrationDuration,
+ ) // calculate the default expiry
// clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied
- if maxExpiry.Before(*m.RequestedExpiry) {
+ if maxExpiry.Before(*machine.RequestedExpiry) {
log.Debug().
Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration)
- m.Expiry = &maxExpiry
- } else if m.RequestedExpiry.IsZero() {
+ machine.Expiry = &maxExpiry
+ } else if machine.RequestedExpiry.IsZero() {
log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration)
- m.Expiry = &defaultExpiry
+ machine.Expiry = &defaultExpiry
} else {
- log.Debug().Msgf("Using requested machine registration expiry time: %v", m.RequestedExpiry)
- m.Expiry = m.RequestedExpiry
+ log.Debug().Msgf("Using requested machine registration expiry time: %v", machine.RequestedExpiry)
+ machine.Expiry = machine.RequestedExpiry
}
- h.db.Save(&m)
+ h.db.Save(&machine)
}
}
-func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
+func (h *Headscale) getDirectPeers(machine *Machine) (Machines, error) {
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Finding direct peers")
machines := Machines{}
if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered",
- m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
+ machine.NamespaceID, machine.MachineKey).Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db")
+
return Machines{}, err
}
@@ -110,21 +120,22 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msgf("Found direct machines: %s", machines.String())
+
return machines, nil
}
-// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for
-func (h *Headscale) getShared(m *Machine) (Machines, error) {
+// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for.
+func (h *Headscale) getShared(machine *Machine) (Machines, error) {
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Finding shared peers")
sharedMachines := []SharedMachine{}
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("namespace_id = ?",
- m.NamespaceID).Find(&sharedMachines).Error; err != nil {
+ machine.NamespaceID).Find(&sharedMachines).Error; err != nil {
return Machines{}, err
}
@@ -137,27 +148,30 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) {
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msgf("Found shared peers: %s", peers.String())
+
return peers, nil
}
-// getSharedTo fetches the machines of the namespaces this machine is shared in
-func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
+// getSharedTo fetches the machines of the namespaces this machine is shared in.
+func (h *Headscale) getSharedTo(machine *Machine) (Machines, error) {
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Finding peers in namespaces this machine is shared with")
sharedMachines := []SharedMachine{}
if err := h.db.Preload("Namespace").Preload("Machine").Preload("Machine.Namespace").Where("machine_id = ?",
- m.ID).Find(&sharedMachines).Error; err != nil {
+ machine.ID).Find(&sharedMachines).Error; err != nil {
return Machines{}, err
}
peers := make(Machines, 0)
for _, sharedMachine := range sharedMachines {
- namespaceMachines, err := h.ListMachinesInNamespace(sharedMachine.Namespace.Name)
+ namespaceMachines, err := h.ListMachinesInNamespace(
+ sharedMachine.Namespace.Name,
+ )
if err != nil {
return Machines{}, err
}
@@ -168,36 +182,40 @@ func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msgf("Found peers we are shared with: %s", peers.String())
+
return peers, nil
}
-func (h *Headscale) getPeers(m *Machine) (Machines, error) {
- direct, err := h.getDirectPeers(m)
+func (h *Headscale) getPeers(machine *Machine) (Machines, error) {
+ direct, err := h.getDirectPeers(machine)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot fetch peers")
+
return Machines{}, err
}
- shared, err := h.getShared(m)
+ shared, err := h.getShared(machine)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot fetch peers")
+
return Machines{}, err
}
- sharedTo, err := h.getSharedTo(m)
+ sharedTo, err := h.getSharedTo(machine)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot fetch peers")
+
return Machines{}, err
}
@@ -208,7 +226,7 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msgf("Found total peers: %s", peers.String())
return peers, nil
@@ -219,10 +237,11 @@ func (h *Headscale) ListMachines() ([]Machine, error) {
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Find(&machines).Error; err != nil {
return nil, err
}
+
return machines, nil
}
-// GetMachine finds a Machine by name and namespace and returns the Machine struct
+// GetMachine finds a Machine by name and namespace and returns the Machine struct.
func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error) {
machines, err := h.ListMachinesInNamespace(namespace)
if err != nil {
@@ -234,73 +253,77 @@ func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error)
return &m, nil
}
}
- return nil, fmt.Errorf("machine not found")
+
+ return nil, errMachineNotFound
}
-// GetMachineByID finds a Machine by ID and returns the Machine struct
+// GetMachineByID finds a Machine by ID and returns the Machine struct.
func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
m := Machine{}
if result := h.db.Preload("Namespace").Find(&Machine{ID: id}).First(&m); result.Error != nil {
return nil, result.Error
}
+
return &m, nil
}
-// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct
-func (h *Headscale) GetMachineByMachineKey(mKey string) (*Machine, error) {
+// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
+func (h *Headscale) GetMachineByMachineKey(machineKey string) (*Machine, error) {
m := Machine{}
- if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey); result.Error != nil {
+ if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", machineKey); result.Error != nil {
return nil, result.Error
}
+
return &m, nil
}
// UpdateMachine takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database.
-func (h *Headscale) UpdateMachine(m *Machine) error {
- if result := h.db.Find(m).First(&m); result.Error != nil {
+func (h *Headscale) UpdateMachine(machine *Machine) error {
+ if result := h.db.Find(machine).First(&machine); result.Error != nil {
return result.Error
}
+
return nil
}
-// DeleteMachine softs deletes a Machine from the database
-func (h *Headscale) DeleteMachine(m *Machine) error {
- err := h.RemoveSharedMachineFromAllNamespaces(m)
- if err != nil && err != errorMachineNotShared {
+// DeleteMachine softs deletes a Machine from the database.
+func (h *Headscale) DeleteMachine(machine *Machine) error {
+ err := h.RemoveSharedMachineFromAllNamespaces(machine)
+ if err != nil && errors.Is(err, errMachineNotShared) {
return err
}
- m.Registered = false
- namespaceID := m.NamespaceID
- h.db.Save(&m) // we mark it as unregistered, just in case
- if err := h.db.Delete(&m).Error; err != nil {
+ machine.Registered = false
+ namespaceID := machine.NamespaceID
+ h.db.Save(&machine) // we mark it as unregistered, just in case
+ if err := h.db.Delete(&machine).Error; err != nil {
return err
}
return h.RequestMapUpdates(namespaceID)
}
-// HardDeleteMachine hard deletes a Machine from the database
-func (h *Headscale) HardDeleteMachine(m *Machine) error {
- err := h.RemoveSharedMachineFromAllNamespaces(m)
- if err != nil && err != errorMachineNotShared {
+// HardDeleteMachine hard deletes a Machine from the database.
+func (h *Headscale) HardDeleteMachine(machine *Machine) error {
+ err := h.RemoveSharedMachineFromAllNamespaces(machine)
+ if err != nil && errors.Is(err, errMachineNotShared) {
return err
}
- namespaceID := m.NamespaceID
- if err := h.db.Unscoped().Delete(&m).Error; err != nil {
+ namespaceID := machine.NamespaceID
+ if err := h.db.Unscoped().Delete(&machine).Error; err != nil {
return err
}
return h.RequestMapUpdates(namespaceID)
}
-// GetHostInfo returns a Hostinfo struct for the machine
-func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
+// GetHostInfo returns a Hostinfo struct for the machine.
+func (machine *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
hostinfo := tailcfg.Hostinfo{}
- if len(m.HostInfo) != 0 {
- hi, err := m.HostInfo.MarshalJSON()
+ if len(machine.HostInfo) != 0 {
+ hi, err := machine.HostInfo.MarshalJSON()
if err != nil {
return nil, err
}
@@ -309,21 +332,21 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
return nil, err
}
}
+
return &hostinfo, nil
}
-func (h *Headscale) isOutdated(m *Machine) bool {
- err := h.UpdateMachine(m)
- if err != nil {
+func (h *Headscale) isOutdated(machine *Machine) bool {
+ if err := h.UpdateMachine(machine); err != nil {
// It does not seem meaningful to propagate this error as the end result
// will have to be that the machine has to be considered outdated.
return true
}
- sharedMachines, _ := h.getShared(m)
+ sharedMachines, _ := h.getShared(machine)
namespaceSet := set.New(set.ThreadSafe)
- namespaceSet.Add(m.Namespace.Name)
+ namespaceSet.Add(machine.Namespace.Name)
// Check if any of our shared namespaces has updates that we have
// not propagated.
@@ -333,27 +356,30 @@ func (h *Headscale) isOutdated(m *Machine) bool {
namespaces := make([]string, namespaceSet.Size())
for index, namespace := range namespaceSet.List() {
- namespaces[index] = namespace.(string)
+ if name, ok := namespace.(string); ok {
+ namespaces[index] = name
+ }
}
lastChange := h.getLastStateChange(namespaces...)
log.Trace().
Caller().
- Str("machine", m.Name).
- Time("last_successful_update", *m.LastSuccessfulUpdate).
+ Str("machine", machine.Name).
+ Time("last_successful_update", *machine.LastSuccessfulUpdate).
Time("last_state_change", lastChange).
- Msgf("Checking if %s is missing updates", m.Name)
- return m.LastSuccessfulUpdate.Before(lastChange)
+ Msgf("Checking if %s is missing updates", machine.Name)
+
+ return machine.LastSuccessfulUpdate.Before(lastChange)
}
-func (m Machine) String() string {
- return m.Name
+func (machine Machine) String() string {
+ return machine.Name
}
-func (ms Machines) String() string {
- temp := make([]string, len(ms))
+func (machines Machines) String() string {
+ temp := make([]string, len(machines))
- for index, machine := range ms {
+ for index, machine := range machines {
temp[index] = machine.Name
}
@@ -361,24 +387,24 @@ func (ms Machines) String() string {
}
// TODO(kradalby): Remove when we have generics...
-func (ms MachinesP) String() string {
- temp := make([]string, len(ms))
+func (machines MachinesP) String() string {
+ temp := make([]string, len(machines))
- for index, machine := range ms {
+ for index, machine := range machines {
temp[index] = machine.Name
}
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
}
-func (ms Machines) toNodes(
+func (machines Machines) toNodes(
baseDomain string,
dnsConfig *tailcfg.DNSConfig,
includeRoutes bool,
) ([]*tailcfg.Node, error) {
- nodes := make([]*tailcfg.Node, len(ms))
+ nodes := make([]*tailcfg.Node, len(machines))
- for index, machine := range ms {
+ for index, machine := range machines {
node, err := machine.toNode(baseDomain, dnsConfig, includeRoutes)
if err != nil {
return nil, err
@@ -391,20 +417,25 @@ func (ms Machines) toNodes(
}
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
-// as per the expected behaviour in the official SaaS
-func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, includeRoutes bool) (*tailcfg.Node, error) {
- nKey, err := wgkey.ParseHex(m.NodeKey)
+// as per the expected behaviour in the official SaaS.
+func (machine Machine) toNode(
+ baseDomain string,
+ dnsConfig *tailcfg.DNSConfig,
+ includeRoutes bool,
+) (*tailcfg.Node, error) {
+ nodeKey, err := wgkey.ParseHex(machine.NodeKey)
if err != nil {
return nil, err
}
- mKey, err := wgkey.ParseHex(m.MachineKey)
+
+ machineKey, err := wgkey.ParseHex(machine.MachineKey)
if err != nil {
return nil, err
}
var discoKey tailcfg.DiscoKey
- if m.DiscoKey != "" {
- dKey, err := wgkey.ParseHex(m.DiscoKey)
+ if machine.DiscoKey != "" {
+ dKey, err := wgkey.ParseHex(machine.DiscoKey)
if err != nil {
return nil, err
}
@@ -414,23 +445,27 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
}
addrs := []netaddr.IPPrefix{}
- ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", m.IPAddress))
+ ip, err := netaddr.ParseIPPrefix(fmt.Sprintf("%s/32", machine.IPAddress))
if err != nil {
log.Trace().
Caller().
- Str("ip", m.IPAddress).
- Msgf("Failed to parse IP Prefix from IP: %s", m.IPAddress)
+ Str("ip", machine.IPAddress).
+ Msgf("Failed to parse IP Prefix from IP: %s", machine.IPAddress)
+
return nil, err
}
addrs = append(addrs, ip) // missing the ipv6 ?
allowedIPs := []netaddr.IPPrefix{}
- allowedIPs = append(allowedIPs, ip) // we append the node own IP, as it is required by the clients
+ allowedIPs = append(
+ allowedIPs,
+ ip,
+ ) // we append the node own IP, as it is required by the clients
if includeRoutes {
routesStr := []string{}
- if len(m.EnabledRoutes) != 0 {
- allwIps, err := m.EnabledRoutes.MarshalJSON()
+ if len(machine.EnabledRoutes) != 0 {
+ allwIps, err := machine.EnabledRoutes.MarshalJSON()
if err != nil {
return nil, err
}
@@ -450,8 +485,8 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
}
endpoints := []string{}
- if len(m.Endpoints) != 0 {
- be, err := m.Endpoints.MarshalJSON()
+ if len(machine.Endpoints) != 0 {
+ be, err := machine.Endpoints.MarshalJSON()
if err != nil {
return nil, err
}
@@ -462,8 +497,8 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
}
hostinfo := tailcfg.Hostinfo{}
- if len(m.HostInfo) != 0 {
- hi, err := m.HostInfo.MarshalJSON()
+ if len(machine.HostInfo) != 0 {
+ hi, err := machine.HostInfo.MarshalJSON()
if err != nil {
return nil, err
}
@@ -481,29 +516,34 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
}
var keyExpiry time.Time
- if m.Expiry != nil {
- keyExpiry = *m.Expiry
+ if machine.Expiry != nil {
+ keyExpiry = *machine.Expiry
} else {
keyExpiry = time.Time{}
}
var hostname string
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
- hostname = fmt.Sprintf("%s.%s.%s", m.Name, m.Namespace.Name, baseDomain)
+ hostname = fmt.Sprintf(
+ "%s.%s.%s",
+ machine.Name,
+ machine.Namespace.Name,
+ baseDomain,
+ )
} else {
- hostname = m.Name
+ hostname = machine.Name
}
- n := tailcfg.Node{
- ID: tailcfg.NodeID(m.ID), // this is the actual ID
+ node := tailcfg.Node{
+ ID: tailcfg.NodeID(machine.ID), // this is the actual ID
StableID: tailcfg.StableNodeID(
- strconv.FormatUint(m.ID, 10),
+ strconv.FormatUint(machine.ID, Base10),
), // in headscale, unlike tailcontrol server, IDs are permanent
Name: hostname,
- User: tailcfg.UserID(m.NamespaceID),
- Key: tailcfg.NodeKey(nKey),
+ User: tailcfg.UserID(machine.NamespaceID),
+ Key: tailcfg.NodeKey(nodeKey),
KeyExpiry: keyExpiry,
- Machine: tailcfg.MachineKey(mKey),
+ Machine: tailcfg.MachineKey(machineKey),
DiscoKey: discoKey,
Addresses: addrs,
AllowedIPs: allowedIPs,
@@ -511,81 +551,90 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
DERP: derp,
Hostinfo: hostinfo,
- Created: m.CreatedAt,
- LastSeen: m.LastSeen,
+ Created: machine.CreatedAt,
+ LastSeen: machine.LastSeen,
KeepAlive: true,
- MachineAuthorized: m.Registered,
+ MachineAuthorized: machine.Registered,
Capabilities: []string{tailcfg.CapabilityFileSharing},
}
- return &n, nil
+
+ return &node, nil
}
-func (m *Machine) toProto() *v1.Machine {
- machine := &v1.Machine{
- Id: m.ID,
- MachineKey: m.MachineKey,
+func (machine *Machine) toProto() *v1.Machine {
+ machineProto := &v1.Machine{
+ Id: machine.ID,
+ MachineKey: machine.MachineKey,
- NodeKey: m.NodeKey,
- DiscoKey: m.DiscoKey,
- IpAddress: m.IPAddress,
- Name: m.Name,
- Namespace: m.Namespace.toProto(),
+ NodeKey: machine.NodeKey,
+ DiscoKey: machine.DiscoKey,
+ IpAddress: machine.IPAddress,
+ Name: machine.Name,
+ Namespace: machine.Namespace.toProto(),
- Registered: m.Registered,
+ Registered: machine.Registered,
// TODO(kradalby): Implement register method enum converter
// RegisterMethod: ,
- CreatedAt: timestamppb.New(m.CreatedAt),
+ CreatedAt: timestamppb.New(machine.CreatedAt),
}
- if m.AuthKey != nil {
- machine.PreAuthKey = m.AuthKey.toProto()
+ if machine.AuthKey != nil {
+ machineProto.PreAuthKey = machine.AuthKey.toProto()
}
- if m.LastSeen != nil {
- machine.LastSeen = timestamppb.New(*m.LastSeen)
+ if machine.LastSeen != nil {
+ machineProto.LastSeen = timestamppb.New(*machine.LastSeen)
}
- if m.LastSuccessfulUpdate != nil {
- machine.LastSuccessfulUpdate = timestamppb.New(*m.LastSuccessfulUpdate)
+ if machine.LastSuccessfulUpdate != nil {
+ machineProto.LastSuccessfulUpdate = timestamppb.New(
+ *machine.LastSuccessfulUpdate,
+ )
}
- if m.Expiry != nil {
- machine.Expiry = timestamppb.New(*m.Expiry)
+ if machine.Expiry != nil {
+ machineProto.Expiry = timestamppb.New(*machine.Expiry)
}
- return machine
+ return machineProto
}
-// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey
-func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, error) {
- ns, err := h.GetNamespace(namespace)
+// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
+func (h *Headscale) RegisterMachine(
+ key string,
+ namespaceName string,
+) (*Machine, error) {
+ namespace, err := h.GetNamespace(namespaceName)
if err != nil {
return nil, err
}
- mKey, err := wgkey.ParseHex(key)
+ machineKey, err := wgkey.ParseHex(key)
if err != nil {
return nil, err
}
- m := Machine{}
- if result := h.db.First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
- return nil, errors.New("Machine not found")
+ machine := Machine{}
+ if result := h.db.First(&machine, "machine_key = ?", machineKey.HexString()); errors.Is(
+ result.Error,
+ gorm.ErrRecordNotFound,
+ ) {
+ return nil, errMachineNotFound
}
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Attempting to register machine")
- if m.isAlreadyRegistered() {
- err := errors.New("Machine already registered")
+ if machine.isAlreadyRegistered() {
+ err := errMachineAlreadyRegistered
log.Error().
Caller().
Err(err).
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Attempting to register machine")
return nil, err
@@ -596,42 +645,44 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
log.Error().
Caller().
Err(err).
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Could not find IP for the new machine")
+
return nil, err
}
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("ip", ip.String()).
Msg("Found IP for host")
- m.IPAddress = ip.String()
- m.NamespaceID = ns.ID
- m.Registered = true
- m.RegisterMethod = "cli"
- h.db.Save(&m)
+ machine.IPAddress = ip.String()
+ machine.NamespaceID = namespace.ID
+ machine.Registered = true
+ machine.RegisterMethod = "cli"
+ h.db.Save(&machine)
log.Trace().
Caller().
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("ip", ip.String()).
Msg("Machine registered with the database")
- return &m, nil
+ return &machine, nil
}
-func (m *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
- hostInfo, err := m.GetHostInfo()
+func (machine *Machine) GetAdvertisedRoutes() ([]netaddr.IPPrefix, error) {
+ hostInfo, err := machine.GetHostInfo()
if err != nil {
return nil, err
}
+
return hostInfo.RoutableIPs, nil
}
-func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
- data, err := m.EnabledRoutes.MarshalJSON()
+func (machine *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
+ data, err := machine.EnabledRoutes.MarshalJSON()
if err != nil {
return nil, err
}
@@ -654,13 +705,13 @@ func (m *Machine) GetEnabledRoutes() ([]netaddr.IPPrefix, error) {
return routes, nil
}
-func (m *Machine) IsRoutesEnabled(routeStr string) bool {
+func (machine *Machine) IsRoutesEnabled(routeStr string) bool {
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return false
}
- enabledRoutes, err := m.GetEnabledRoutes()
+ enabledRoutes, err := machine.GetEnabledRoutes()
if err != nil {
return false
}
@@ -670,12 +721,13 @@ func (m *Machine) IsRoutesEnabled(routeStr string) bool {
return true
}
}
+
return false
}
// EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the
// previous list of routes.
-func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
+func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
newRoutes := make([]netaddr.IPPrefix, len(routeStrs))
for index, routeStr := range routeStrs {
route, err := netaddr.ParseIPPrefix(routeStr)
@@ -686,14 +738,18 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
newRoutes[index] = route
}
- availableRoutes, err := m.GetAdvertisedRoutes()
+ availableRoutes, err := machine.GetAdvertisedRoutes()
if err != nil {
return err
}
for _, newRoute := range newRoutes {
- if !containsIpPrefix(availableRoutes, newRoute) {
- return fmt.Errorf("route (%s) is not available on node %s", m.Name, newRoute)
+ if !containsIPPrefix(availableRoutes, newRoute) {
+ return fmt.Errorf(
+ "route (%s) is not available on node %s: %w",
+ machine.Name,
+ newRoute, errMachineRouteIsNotAvailable,
+ )
}
}
@@ -702,10 +758,10 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
return err
}
- m.EnabledRoutes = datatypes.JSON(routes)
- h.db.Save(&m)
+ machine.EnabledRoutes = datatypes.JSON(routes)
+ h.db.Save(&machine)
- err = h.RequestMapUpdates(m.NamespaceID)
+ err = h.RequestMapUpdates(machine.NamespaceID)
if err != nil {
return err
}
@@ -713,13 +769,13 @@ func (h *Headscale) EnableRoutes(m *Machine, routeStrs ...string) error {
return nil
}
-func (m *Machine) RoutesToProto() (*v1.Routes, error) {
- availableRoutes, err := m.GetAdvertisedRoutes()
+func (machine *Machine) RoutesToProto() (*v1.Routes, error) {
+ availableRoutes, err := machine.GetAdvertisedRoutes()
if err != nil {
return nil, err
}
- enabledRoutes, err := m.GetEnabledRoutes()
+ enabledRoutes, err := machine.GetEnabledRoutes()
if err != nil {
return nil, err
}
diff --git a/machine_test.go b/machine_test.go
index dfe84d3..cf36740 100644
--- a/machine_test.go
+++ b/machine_test.go
@@ -8,152 +8,159 @@ import (
)
func (s *Suite) TestGetMachine(c *check.C) {
- n, err := h.CreateNamespace("test")
+ namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
- pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
+ pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
- _, err = h.GetMachine("test", "testmachine")
+ _, err = app.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
- m := &Machine{
+ machine := &Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
- NamespaceID: n.ID,
+ NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
}
- h.db.Save(m)
+ app.db.Save(machine)
- m1, err := h.GetMachine("test", "testmachine")
+ machineFromDB, err := app.GetMachine("test", "testmachine")
c.Assert(err, check.IsNil)
- _, err = m1.GetHostInfo()
+ _, err = machineFromDB.GetHostInfo()
c.Assert(err, check.IsNil)
}
func (s *Suite) TestGetMachineByID(c *check.C) {
- n, err := h.CreateNamespace("test")
+ namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
- pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
+ pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
- _, err = h.GetMachineByID(0)
+ _, err = app.GetMachineByID(0)
c.Assert(err, check.NotNil)
- m := Machine{
+ machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
- NamespaceID: n.ID,
+ NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
}
- h.db.Save(&m)
+ app.db.Save(&machine)
- m1, err := h.GetMachineByID(0)
+ machineByID, err := app.GetMachineByID(0)
c.Assert(err, check.IsNil)
- _, err = m1.GetHostInfo()
+ _, err = machineByID.GetHostInfo()
c.Assert(err, check.IsNil)
}
func (s *Suite) TestDeleteMachine(c *check.C) {
- n, err := h.CreateNamespace("test")
+ namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
- m := Machine{
+ machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
- NamespaceID: n.ID,
+ NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(1),
}
- h.db.Save(&m)
- err = h.DeleteMachine(&m)
+ app.db.Save(&machine)
+
+ err = app.DeleteMachine(&machine)
c.Assert(err, check.IsNil)
- v, err := h.getValue("namespaces_pending_updates")
+
+ namespacesPendingUpdates, err := app.getValue("namespaces_pending_updates")
c.Assert(err, check.IsNil)
+
names := []string{}
- err = json.Unmarshal([]byte(v), &names)
+ err = json.Unmarshal([]byte(namespacesPendingUpdates), &names)
c.Assert(err, check.IsNil)
- c.Assert(names, check.DeepEquals, []string{n.Name})
- h.checkForNamespacesPendingUpdates()
- v, _ = h.getValue("namespaces_pending_updates")
- c.Assert(v, check.Equals, "")
- _, err = h.GetMachine(n.Name, "testmachine")
+ c.Assert(names, check.DeepEquals, []string{namespace.Name})
+
+ app.checkForNamespacesPendingUpdates()
+
+ namespacesPendingUpdates, _ = app.getValue("namespaces_pending_updates")
+ c.Assert(namespacesPendingUpdates, check.Equals, "")
+ _, err = app.GetMachine(namespace.Name, "testmachine")
c.Assert(err, check.NotNil)
}
func (s *Suite) TestHardDeleteMachine(c *check.C) {
- n, err := h.CreateNamespace("test")
+ namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
- m := Machine{
+ machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine3",
- NamespaceID: n.ID,
+ NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(1),
}
- h.db.Save(&m)
- err = h.HardDeleteMachine(&m)
+ app.db.Save(&machine)
+
+ err = app.HardDeleteMachine(&machine)
c.Assert(err, check.IsNil)
- _, err = h.GetMachine(n.Name, "testmachine3")
+
+ _, err = app.GetMachine(namespace.Name, "testmachine3")
c.Assert(err, check.NotNil)
}
func (s *Suite) TestGetDirectPeers(c *check.C) {
- n, err := h.CreateNamespace("test")
+ namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
- pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
+ pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
- _, err = h.GetMachineByID(0)
+ _, err = app.GetMachineByID(0)
c.Assert(err, check.NotNil)
- for i := 0; i <= 10; i++ {
- m := Machine{
- ID: uint64(i),
- MachineKey: "foo" + strconv.Itoa(i),
- NodeKey: "bar" + strconv.Itoa(i),
- DiscoKey: "faa" + strconv.Itoa(i),
- Name: "testmachine" + strconv.Itoa(i),
- NamespaceID: n.ID,
+ for index := 0; index <= 10; index++ {
+ machine := Machine{
+ ID: uint64(index),
+ MachineKey: "foo" + strconv.Itoa(index),
+ NodeKey: "bar" + strconv.Itoa(index),
+ DiscoKey: "faa" + strconv.Itoa(index),
+ Name: "testmachine" + strconv.Itoa(index),
+ NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
}
- h.db.Save(&m)
+ app.db.Save(&machine)
}
- m1, err := h.GetMachineByID(0)
+ machine0ByID, err := app.GetMachineByID(0)
c.Assert(err, check.IsNil)
- _, err = m1.GetHostInfo()
+ _, err = machine0ByID.GetHostInfo()
c.Assert(err, check.IsNil)
- peers, err := h.getDirectPeers(m1)
+ peersOfMachine0, err := app.getDirectPeers(machine0ByID)
c.Assert(err, check.IsNil)
- c.Assert(len(peers), check.Equals, 9)
- c.Assert(peers[0].Name, check.Equals, "testmachine2")
- c.Assert(peers[5].Name, check.Equals, "testmachine7")
- c.Assert(peers[8].Name, check.Equals, "testmachine10")
+ c.Assert(len(peersOfMachine0), check.Equals, 9)
+ c.Assert(peersOfMachine0[0].Name, check.Equals, "testmachine2")
+ c.Assert(peersOfMachine0[5].Name, check.Equals, "testmachine7")
+ c.Assert(peersOfMachine0[8].Name, check.Equals, "testmachine10")
}
diff --git a/metrics.go b/metrics.go
index 0d3dca3..f0ce16e 100644
--- a/metrics.go
+++ b/metrics.go
@@ -32,7 +32,7 @@ var (
Name: "update_request_sent_to_node_total",
Help: "The number of calls/messages issued on a specific nodes update channel",
}, []string{"namespace", "machine", "status"})
- //TODO(kradalby): This is very debugging, we might want to remove it.
+ // TODO(kradalby): This is very debugging, we might want to remove it.
updateRequestsReceivedOnChannel = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "update_request_received_on_channel_total",
diff --git a/namespaces.go b/namespaces.go
index e5d1783..e512068 100644
--- a/namespaces.go
+++ b/namespaces.go
@@ -15,9 +15,9 @@ import (
)
const (
- errorNamespaceExists = Error("Namespace already exists")
- errorNamespaceNotFound = Error("Namespace not found")
- errorNamespaceNotEmptyOfNodes = Error("Namespace not empty: node(s) found")
+ errNamespaceExists = Error("Namespace already exists")
+ errNamespaceNotFound = Error("Namespace not found")
+ errNamespaceNotEmptyOfNodes = Error("Namespace not empty: node(s) found")
)
// Namespace is the way Headscale implements the concept of users in Tailscale
@@ -30,51 +30,53 @@ type Namespace struct {
}
// CreateNamespace creates a new Namespace. Returns error if could not be created
-// or another namespace already exists
+// or another namespace already exists.
func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
- n := Namespace{}
- if err := h.db.Where("name = ?", name).First(&n).Error; err == nil {
- return nil, errorNamespaceExists
+ namespace := Namespace{}
+ if err := h.db.Where("name = ?", name).First(&namespace).Error; err == nil {
+ return nil, errNamespaceExists
}
- n.Name = name
- if err := h.db.Create(&n).Error; err != nil {
+ namespace.Name = name
+ if err := h.db.Create(&namespace).Error; err != nil {
log.Error().
Str("func", "CreateNamespace").
Err(err).
Msg("Could not create row")
+
return nil, err
}
- return &n, nil
+
+ return &namespace, nil
}
// DestroyNamespace destroys a Namespace. Returns error if the Namespace does
// not exist or if there are machines associated with it.
func (h *Headscale) DestroyNamespace(name string) error {
- n, err := h.GetNamespace(name)
+ namespace, err := h.GetNamespace(name)
if err != nil {
- return errorNamespaceNotFound
+ return errNamespaceNotFound
}
- m, err := h.ListMachinesInNamespace(name)
+ machines, err := h.ListMachinesInNamespace(name)
if err != nil {
return err
}
- if len(m) > 0 {
- return errorNamespaceNotEmptyOfNodes
+ if len(machines) > 0 {
+ return errNamespaceNotEmptyOfNodes
}
keys, err := h.ListPreAuthKeys(name)
if err != nil {
return err
}
- for _, p := range keys {
- err = h.DestroyPreAuthKey(&p)
+ for _, key := range keys {
+ err = h.DestroyPreAuthKey(key)
if err != nil {
return err
}
}
- if result := h.db.Unscoped().Delete(&n); result.Error != nil {
+ if result := h.db.Unscoped().Delete(&namespace); result.Error != nil {
return result.Error
}
@@ -84,25 +86,25 @@ func (h *Headscale) DestroyNamespace(name string) error {
// RenameNamespace renames a Namespace. Returns error if the Namespace does
// not exist or if another Namespace exists with the new name.
func (h *Headscale) RenameNamespace(oldName, newName string) error {
- n, err := h.GetNamespace(oldName)
+ oldNamespace, err := h.GetNamespace(oldName)
if err != nil {
return err
}
_, err = h.GetNamespace(newName)
if err == nil {
- return errorNamespaceExists
+ return errNamespaceExists
}
- if !errors.Is(err, errorNamespaceNotFound) {
+ if !errors.Is(err, errNamespaceNotFound) {
return err
}
- n.Name = newName
+ oldNamespace.Name = newName
- if result := h.db.Save(&n); result.Error != nil {
+ if result := h.db.Save(&oldNamespace); result.Error != nil {
return result.Error
}
- err = h.RequestMapUpdates(n.ID)
+ err = h.RequestMapUpdates(oldNamespace.ID)
if err != nil {
return err
}
@@ -110,39 +112,45 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error {
return nil
}
-// GetNamespace fetches a namespace by name
+// GetNamespace fetches a namespace by name.
func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
- n := Namespace{}
- if result := h.db.First(&n, "name = ?", name); errors.Is(result.Error, gorm.ErrRecordNotFound) {
- return nil, errorNamespaceNotFound
+ namespace := Namespace{}
+ if result := h.db.First(&namespace, "name = ?", name); errors.Is(
+ result.Error,
+ gorm.ErrRecordNotFound,
+ ) {
+ return nil, errNamespaceNotFound
}
- return &n, nil
+
+ return &namespace, nil
}
-// ListNamespaces gets all the existing namespaces
+// ListNamespaces gets all the existing namespaces.
func (h *Headscale) ListNamespaces() ([]Namespace, error) {
namespaces := []Namespace{}
if err := h.db.Find(&namespaces).Error; err != nil {
return nil, err
}
+
return namespaces, nil
}
-// ListMachinesInNamespace gets all the nodes in a given namespace
+// ListMachinesInNamespace gets all the nodes in a given namespace.
func (h *Headscale) ListMachinesInNamespace(name string) ([]Machine, error) {
- n, err := h.GetNamespace(name)
+ namespace, err := h.GetNamespace(name)
if err != nil {
return nil, err
}
machines := []Machine{}
- if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
+ if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: namespace.ID}).Find(&machines).Error; err != nil {
return nil, err
}
+
return machines, nil
}
-// ListSharedMachinesInNamespace returns all the machines that are shared to the specified namespace
+// ListSharedMachinesInNamespace returns all the machines that are shared to the specified namespace.
func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error) {
namespace, err := h.GetNamespace(name)
if err != nil {
@@ -155,48 +163,61 @@ func (h *Headscale) ListSharedMachinesInNamespace(name string) ([]Machine, error
machines := []Machine{}
for _, sharedMachine := range sharedMachines {
- machine, err := h.GetMachineByID(sharedMachine.MachineID) // otherwise not everything comes filled
+ machine, err := h.GetMachineByID(
+ sharedMachine.MachineID,
+ ) // otherwise not everything comes filled
if err != nil {
return nil, err
}
machines = append(machines, *machine)
}
+
return machines, nil
}
-// SetMachineNamespace assigns a Machine to a namespace
-func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error {
- n, err := h.GetNamespace(namespaceName)
+// SetMachineNamespace assigns a Machine to a namespace.
+func (h *Headscale) SetMachineNamespace(machine *Machine, namespaceName string) error {
+ namespace, err := h.GetNamespace(namespaceName)
if err != nil {
return err
}
- m.NamespaceID = n.ID
- h.db.Save(&m)
+ machine.NamespaceID = namespace.ID
+ h.db.Save(&machine)
+
return nil
}
-// RequestMapUpdates signals the KV worker to update the maps for this namespace
+// TODO(kradalby): Remove the need for this.
+// RequestMapUpdates signals the KV worker to update the maps for this namespace.
func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
namespace := Namespace{}
if err := h.db.First(&namespace, namespaceID).Error; err != nil {
return err
}
- v, err := h.getValue("namespaces_pending_updates")
- if err != nil || v == "" {
- err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name))
+ namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
+ if err != nil || namespacesPendingUpdates == "" {
+ err = h.setValue(
+ "namespaces_pending_updates",
+ fmt.Sprintf(`["%s"]`, namespace.Name),
+ )
if err != nil {
return err
}
+
return nil
}
names := []string{}
- err = json.Unmarshal([]byte(v), &names)
+ err = json.Unmarshal([]byte(namespacesPendingUpdates), &names)
if err != nil {
- err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name))
+ err = h.setValue(
+ "namespaces_pending_updates",
+ fmt.Sprintf(`["%s"]`, namespace.Name),
+ )
if err != nil {
return err
}
+
return nil
}
@@ -207,22 +228,24 @@ func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
Str("func", "RequestMapUpdates").
Err(err).
Msg("Could not marshal namespaces_pending_updates")
+
return err
}
+
return h.setValue("namespaces_pending_updates", string(data))
}
func (h *Headscale) checkForNamespacesPendingUpdates() {
- v, err := h.getValue("namespaces_pending_updates")
+ namespacesPendingUpdates, err := h.getValue("namespaces_pending_updates")
if err != nil {
return
}
- if v == "" {
+ if namespacesPendingUpdates == "" {
return
}
namespaces := []string{}
- err = json.Unmarshal([]byte(v), &namespaces)
+ err = json.Unmarshal([]byte(namespacesPendingUpdates), &namespaces)
if err != nil {
return
}
@@ -233,24 +256,25 @@ func (h *Headscale) checkForNamespacesPendingUpdates() {
Msg("Sending updates to nodes in namespacespace")
h.setLastStateChangeToNow(namespace)
}
- newV, err := h.getValue("namespaces_pending_updates")
+ newPendingUpdateValue, err := h.getValue("namespaces_pending_updates")
if err != nil {
return
}
- if v == newV { // only clear when no changes, so we notified everybody
+ if namespacesPendingUpdates == newPendingUpdateValue { // only clear when no changes, so we notified everybody
err = h.setValue("namespaces_pending_updates", "")
if err != nil {
log.Error().
Str("func", "checkForNamespacesPendingUpdates").
Err(err).
Msg("Could not save to KV")
+
return
}
}
}
func (n *Namespace) toUser() *tailcfg.User {
- u := tailcfg.User{
+ user := tailcfg.User{
ID: tailcfg.UserID(n.ID),
LoginName: n.Name,
DisplayName: n.Name,
@@ -259,25 +283,27 @@ func (n *Namespace) toUser() *tailcfg.User {
Logins: []tailcfg.LoginID{},
Created: time.Time{},
}
- return &u
+
+ return &user
}
func (n *Namespace) toLogin() *tailcfg.Login {
- l := tailcfg.Login{
+ login := tailcfg.Login{
ID: tailcfg.LoginID(n.ID),
LoginName: n.Name,
DisplayName: n.Name,
ProfilePicURL: "",
Domain: "headscale.net",
}
- return &l
+
+ return &login
}
-func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile {
+func getMapResponseUserProfiles(machine Machine, peers Machines) []tailcfg.UserProfile {
namespaceMap := make(map[string]Namespace)
- namespaceMap[m.Namespace.Name] = m.Namespace
- for _, p := range peers {
- namespaceMap[p.Namespace.Name] = p.Namespace // not worth checking if already is there
+ namespaceMap[machine.Namespace.Name] = machine.Namespace
+ for _, peer := range peers {
+ namespaceMap[peer.Namespace.Name] = peer.Namespace // not worth checking if already is there
}
profiles := []tailcfg.UserProfile{}
@@ -289,12 +315,13 @@ func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile
DisplayName: namespace.Name,
})
}
+
return profiles
}
func (n *Namespace) toProto() *v1.Namespace {
return &v1.Namespace{
- Id: strconv.FormatUint(uint64(n.ID), 10),
+ Id: strconv.FormatUint(uint64(n.ID), Base10),
Name: n.Name,
CreatedAt: timestamppb.New(n.CreatedAt),
}
diff --git a/namespaces_test.go b/namespaces_test.go
index 72193ee..bbae98f 100644
--- a/namespaces_test.go
+++ b/namespaces_test.go
@@ -7,207 +7,232 @@ import (
)
func (s *Suite) TestCreateAndDestroyNamespace(c *check.C) {
- n, err := h.CreateNamespace("test")
+ namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
- c.Assert(n.Name, check.Equals, "test")
+ c.Assert(namespace.Name, check.Equals, "test")
- ns, err := h.ListNamespaces()
+ namespaces, err := app.ListNamespaces()
c.Assert(err, check.IsNil)
- c.Assert(len(ns), check.Equals, 1)
+ c.Assert(len(namespaces), check.Equals, 1)
- err = h.DestroyNamespace("test")
+ err = app.DestroyNamespace("test")
c.Assert(err, check.IsNil)
- _, err = h.GetNamespace("test")
+ _, err = app.GetNamespace("test")
c.Assert(err, check.NotNil)
}
func (s *Suite) TestDestroyNamespaceErrors(c *check.C) {
- err := h.DestroyNamespace("test")
- c.Assert(err, check.Equals, errorNamespaceNotFound)
+ err := app.DestroyNamespace("test")
+ c.Assert(err, check.Equals, errNamespaceNotFound)
- n, err := h.CreateNamespace("test")
+ namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
- pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
+ pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
- err = h.DestroyNamespace("test")
+ err = app.DestroyNamespace("test")
c.Assert(err, check.IsNil)
- result := h.db.Preload("Namespace").First(&pak, "key = ?", pak.Key)
+ result := app.db.Preload("Namespace").First(&pak, "key = ?", pak.Key)
// destroying a namespace also deletes all associated preauthkeys
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
- n, err = h.CreateNamespace("test")
+ namespace, err = app.CreateNamespace("test")
c.Assert(err, check.IsNil)
- pak, err = h.CreatePreAuthKey(n.Name, false, false, nil)
+ pak, err = app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
- m := Machine{
+ machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
- NamespaceID: n.ID,
+ NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
}
- h.db.Save(&m)
+ app.db.Save(&machine)
- err = h.DestroyNamespace("test")
- c.Assert(err, check.Equals, errorNamespaceNotEmptyOfNodes)
+ err = app.DestroyNamespace("test")
+ c.Assert(err, check.Equals, errNamespaceNotEmptyOfNodes)
}
func (s *Suite) TestRenameNamespace(c *check.C) {
- n, err := h.CreateNamespace("test")
+ namespaceTest, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
- c.Assert(n.Name, check.Equals, "test")
+ c.Assert(namespaceTest.Name, check.Equals, "test")
- ns, err := h.ListNamespaces()
+ namespaces, err := app.ListNamespaces()
c.Assert(err, check.IsNil)
- c.Assert(len(ns), check.Equals, 1)
+ c.Assert(len(namespaces), check.Equals, 1)
- err = h.RenameNamespace("test", "test_renamed")
+ err = app.RenameNamespace("test", "test_renamed")
c.Assert(err, check.IsNil)
- _, err = h.GetNamespace("test")
- c.Assert(err, check.Equals, errorNamespaceNotFound)
+ _, err = app.GetNamespace("test")
+ c.Assert(err, check.Equals, errNamespaceNotFound)
- _, err = h.GetNamespace("test_renamed")
+ _, err = app.GetNamespace("test_renamed")
c.Assert(err, check.IsNil)
- err = h.RenameNamespace("test_does_not_exit", "test")
- c.Assert(err, check.Equals, errorNamespaceNotFound)
+ err = app.RenameNamespace("test_does_not_exit", "test")
+ c.Assert(err, check.Equals, errNamespaceNotFound)
- n2, err := h.CreateNamespace("test2")
+ namespaceTest2, err := app.CreateNamespace("test2")
c.Assert(err, check.IsNil)
- c.Assert(n2.Name, check.Equals, "test2")
+ c.Assert(namespaceTest2.Name, check.Equals, "test2")
- err = h.RenameNamespace("test2", "test_renamed")
- c.Assert(err, check.Equals, errorNamespaceExists)
+ err = app.RenameNamespace("test2", "test_renamed")
+ c.Assert(err, check.Equals, errNamespaceExists)
}
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
- n1, err := h.CreateNamespace("shared1")
+ namespaceShared1, err := app.CreateNamespace("shared1")
c.Assert(err, check.IsNil)
- n2, err := h.CreateNamespace("shared2")
+ namespaceShared2, err := app.CreateNamespace("shared2")
c.Assert(err, check.IsNil)
- n3, err := h.CreateNamespace("shared3")
+ namespaceShared3, err := app.CreateNamespace("shared3")
c.Assert(err, check.IsNil)
- pak1n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
+ preAuthKeyShared1, err := app.CreatePreAuthKey(
+ namespaceShared1.Name,
+ false,
+ false,
+ nil,
+ )
c.Assert(err, check.IsNil)
- pak2n2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
+ preAuthKeyShared2, err := app.CreatePreAuthKey(
+ namespaceShared2.Name,
+ false,
+ false,
+ nil,
+ )
c.Assert(err, check.IsNil)
- pak3n3, err := h.CreatePreAuthKey(n3.Name, false, false, nil)
+ preAuthKeyShared3, err := app.CreatePreAuthKey(
+ namespaceShared3.Name,
+ false,
+ false,
+ nil,
+ )
c.Assert(err, check.IsNil)
- pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
+ preAuthKey2Shared1, err := app.CreatePreAuthKey(
+ namespaceShared1.Name,
+ false,
+ false,
+ nil,
+ )
c.Assert(err, check.IsNil)
- _, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
+ _, err = app.GetMachine(namespaceShared1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil)
- m1 := &Machine{
+ machineInShared1 := &Machine{
ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Name: "test_get_shared_nodes_1",
- NamespaceID: n1.ID,
- Namespace: *n1,
+ NamespaceID: namespaceShared1.ID,
+ Namespace: *namespaceShared1,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.1",
- AuthKeyID: uint(pak1n1.ID),
+ AuthKeyID: uint(preAuthKeyShared1.ID),
}
- h.db.Save(m1)
+ app.db.Save(machineInShared1)
- _, err = h.GetMachine(n1.Name, m1.Name)
+ _, err = app.GetMachine(namespaceShared1.Name, machineInShared1.Name)
c.Assert(err, check.IsNil)
- m2 := &Machine{
+ machineInShared2 := &Machine{
ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_2",
- NamespaceID: n2.ID,
- Namespace: *n2,
+ NamespaceID: namespaceShared2.ID,
+ Namespace: *namespaceShared2,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.2",
- AuthKeyID: uint(pak2n2.ID),
+ AuthKeyID: uint(preAuthKeyShared2.ID),
}
- h.db.Save(m2)
+ app.db.Save(machineInShared2)
- _, err = h.GetMachine(n2.Name, m2.Name)
+ _, err = app.GetMachine(namespaceShared2.Name, machineInShared2.Name)
c.Assert(err, check.IsNil)
- m3 := &Machine{
+ machineInShared3 := &Machine{
ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_3",
- NamespaceID: n3.ID,
- Namespace: *n3,
+ NamespaceID: namespaceShared3.ID,
+ Namespace: *namespaceShared3,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.3",
- AuthKeyID: uint(pak3n3.ID),
+ AuthKeyID: uint(preAuthKeyShared3.ID),
}
- h.db.Save(m3)
+ app.db.Save(machineInShared3)
- _, err = h.GetMachine(n3.Name, m3.Name)
+ _, err = app.GetMachine(namespaceShared3.Name, machineInShared3.Name)
c.Assert(err, check.IsNil)
- m4 := &Machine{
+ machine2InShared1 := &Machine{
ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_4",
- NamespaceID: n1.ID,
- Namespace: *n1,
+ NamespaceID: namespaceShared1.ID,
+ Namespace: *namespaceShared1,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.4",
- AuthKeyID: uint(pak4n1.ID),
+ AuthKeyID: uint(preAuthKey2Shared1.ID),
}
- h.db.Save(m4)
+ app.db.Save(machine2InShared1)
- err = h.AddSharedMachineToNamespace(m2, n1)
+ err = app.AddSharedMachineToNamespace(machineInShared2, namespaceShared1)
c.Assert(err, check.IsNil)
- m1peers, err := h.getPeers(m1)
+ peersOfMachine1InShared1, err := app.getPeers(machineInShared1)
c.Assert(err, check.IsNil)
- userProfiles := getMapResponseUserProfiles(*m1, m1peers)
+ userProfiles := getMapResponseUserProfiles(
+ *machineInShared1,
+ peersOfMachine1InShared1,
+ )
log.Trace().Msgf("userProfiles %#v", userProfiles)
c.Assert(len(userProfiles), check.Equals, 2)
found := false
- for _, up := range userProfiles {
- if up.DisplayName == n1.Name {
+ for _, userProfiles := range userProfiles {
+ if userProfiles.DisplayName == namespaceShared1.Name {
found = true
+
break
}
}
c.Assert(found, check.Equals, true)
found = false
- for _, up := range userProfiles {
- if up.DisplayName == n2.Name {
+ for _, userProfile := range userProfiles {
+ if userProfile.DisplayName == namespaceShared2.Name {
found = true
+
break
}
}
diff --git a/oidc.go b/oidc.go
index 51c443d..07561e8 100644
--- a/oidc.go
+++ b/oidc.go
@@ -17,6 +17,12 @@ import (
"golang.org/x/oauth2"
)
+const (
+ oidcStateCacheExpiration = time.Minute * 5
+ oidcStateCacheCleanupInterval = time.Minute * 10
+ randomByteSize = 16
+)
+
type IDTokenClaims struct {
Name string `json:"name,omitempty"`
Groups []string `json:"groups,omitempty"`
@@ -32,6 +38,7 @@ func (h *Headscale) initOIDC() error {
if err != nil {
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
+
return err
}
@@ -39,14 +46,20 @@ func (h *Headscale) initOIDC() error {
ClientID: h.cfg.OIDC.ClientID,
ClientSecret: h.cfg.OIDC.ClientSecret,
Endpoint: h.oidcProvider.Endpoint(),
- RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
- Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
+ RedirectURL: fmt.Sprintf(
+ "%s/oidc/callback",
+ strings.TrimSuffix(h.cfg.ServerURL, "/"),
+ ),
+ Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
}
// init the state cache if it hasn't been already
if h.oidcStateCache == nil {
- h.oidcStateCache = cache.New(time.Minute*5, time.Minute*10)
+ h.oidcStateCache = cache.New(
+ oidcStateCacheExpiration,
+ oidcStateCacheCleanupInterval,
+ )
}
return nil
@@ -54,50 +67,53 @@ func (h *Headscale) initOIDC() error {
// RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param
-// Listens in /oidc/register/:mKey
-func (h *Headscale) RegisterOIDC(c *gin.Context) {
- mKeyStr := c.Param("mkey")
+// Listens in /oidc/register/:mKey.
+func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
+ mKeyStr := ctx.Param("mkey")
if mKeyStr == "" {
- c.String(http.StatusBadRequest, "Wrong params")
+ ctx.String(http.StatusBadRequest, "Wrong params")
+
return
}
- b := make([]byte, 16)
- _, err := rand.Read(b)
- if err != nil {
+ randomBlob := make([]byte, randomByteSize)
+ if _, err := rand.Read(randomBlob); err != nil {
log.Error().Msg("could not read 16 bytes from rand")
- c.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
+ ctx.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
+
return
}
- stateStr := hex.EncodeToString(b)[:32]
+ stateStr := hex.EncodeToString(randomBlob)[:32]
// place the machine key into the state cache, so it can be retrieved later
- h.oidcStateCache.Set(stateStr, mKeyStr, time.Minute*5)
+ h.oidcStateCache.Set(stateStr, mKeyStr, oidcStateCacheExpiration)
- authUrl := h.oauth2Config.AuthCodeURL(stateStr)
- log.Debug().Msgf("Redirecting to %s for authentication", authUrl)
+ authURL := h.oauth2Config.AuthCodeURL(stateStr)
+ log.Debug().Msgf("Redirecting to %s for authentication", authURL)
- c.Redirect(http.StatusFound, authUrl)
+ ctx.Redirect(http.StatusFound, authURL)
}
// OIDCCallback handles the callback from the OIDC endpoint
// Retrieves the mkey from the state cache and adds the machine to the users email namespace
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into machine HostInfo
-// Listens in /oidc/callback
-func (h *Headscale) OIDCCallback(c *gin.Context) {
- code := c.Query("code")
- state := c.Query("state")
+// Listens in /oidc/callback.
+func (h *Headscale) OIDCCallback(ctx *gin.Context) {
+ code := ctx.Query("code")
+ state := ctx.Query("state")
if code == "" || state == "" {
- c.String(http.StatusBadRequest, "Wrong params")
+ ctx.String(http.StatusBadRequest, "Wrong params")
+
return
}
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
if err != nil {
- c.String(http.StatusBadRequest, "Could not exchange code for token")
+ ctx.String(http.StatusBadRequest, "Could not exchange code for token")
+
return
}
@@ -105,7 +121,8 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
if !rawIDTokenOK {
- c.String(http.StatusBadRequest, "Could not extract ID Token")
+ ctx.String(http.StatusBadRequest, "Could not extract ID Token")
+
return
}
@@ -113,21 +130,26 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
idToken, err := verifier.Verify(context.Background(), rawIDToken)
if err != nil {
- c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
+ ctx.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
+
return
}
// TODO: we can use userinfo at some point to grab additional information about the user (groups membership, etc)
- //userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token))
- //if err != nil {
- // c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo: %s", err))
- // return
- //}
+ // userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token))
+ // if err != nil {
+ // c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo: %s", err))
+ // return
+ // }
// Extract custom claims
var claims IDTokenClaims
if err = idToken.Claims(&claims); err != nil {
- c.String(http.StatusBadRequest, fmt.Sprintf("Failed to decode id token claims: %s", err))
+ ctx.String(
+ http.StatusBadRequest,
+ fmt.Sprintf("Failed to decode id token claims: %s", err),
+ )
+
return
}
@@ -135,62 +157,80 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
if !mKeyFound {
- log.Error().Msg("requested machine state key expired before authorisation completed")
- c.String(http.StatusBadRequest, "state has expired")
+ log.Error().
+ Msg("requested machine state key expired before authorisation completed")
+ ctx.String(http.StatusBadRequest, "state has expired")
+
return
}
mKeyStr, mKeyOK := mKeyIf.(string)
if !mKeyOK {
log.Error().Msg("could not get machine key from cache")
- c.String(http.StatusInternalServerError, "could not get machine key from cache")
+ ctx.String(
+ http.StatusInternalServerError,
+ "could not get machine key from cache",
+ )
+
return
}
// retrieve machine information
- m, err := h.GetMachineByMachineKey(mKeyStr)
+ machine, err := h.GetMachineByMachineKey(mKeyStr)
if err != nil {
log.Error().Msg("machine key not found in database")
- c.String(http.StatusInternalServerError, "could not get machine info from database")
+ ctx.String(
+ http.StatusInternalServerError,
+ "could not get machine info from database",
+ )
+
return
}
now := time.Now().UTC()
- if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok {
+ if namespaceName, ok := h.getNamespaceFromEmail(claims.Email); ok {
// register the machine if it's new
- if !m.Registered {
-
+ if !machine.Registered {
log.Debug().Msg("Registering new machine after successful callback")
- ns, err := h.GetNamespace(nsName)
+ namespace, err := h.GetNamespace(namespaceName)
if err != nil {
- ns, err = h.CreateNamespace(nsName)
+ namespace, err = h.CreateNamespace(namespaceName)
if err != nil {
- log.Error().Msgf("could not create new namespace '%s'", claims.Email)
- c.String(http.StatusInternalServerError, "could not create new namespace")
+ log.Error().
+ Msgf("could not create new namespace '%s'", claims.Email)
+ ctx.String(
+ http.StatusInternalServerError,
+ "could not create new namespace",
+ )
+
return
}
}
ip, err := h.getAvailableIP()
if err != nil {
- c.String(http.StatusInternalServerError, "could not get an IP from the pool")
+ ctx.String(
+ http.StatusInternalServerError,
+ "could not get an IP from the pool",
+ )
+
return
}
- m.IPAddress = ip.String()
- m.NamespaceID = ns.ID
- m.Registered = true
- m.RegisterMethod = "oidc"
- m.LastSuccessfulUpdate = &now
- h.db.Save(&m)
+ machine.IPAddress = ip.String()
+ machine.NamespaceID = namespace.ID
+ machine.Registered = true
+ machine.RegisterMethod = "oidc"
+ machine.LastSuccessfulUpdate = &now
+ h.db.Save(&machine)
}
- h.updateMachineExpiry(m)
+ h.updateMachineExpiry(machine)
- c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
+ ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
headscale
@@ -201,15 +241,17 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
`, claims.Email)))
-
}
log.Error().
Str("email", claims.Email).
Str("username", claims.Username).
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Email could not be mapped to a namespace")
- c.String(http.StatusBadRequest, "email from claim could not be mapped to a namespace")
+ ctx.String(
+ http.StatusBadRequest,
+ "email from claim could not be mapped to a namespace",
+ )
}
// getNamespaceFromEmail passes the users email through a list of "matchers"
diff --git a/oidc_test.go b/oidc_test.go
index c7a29ce..21a4357 100644
--- a/oidc_test.go
+++ b/oidc_test.go
@@ -145,29 +145,37 @@ func TestHeadscale_getNamespaceFromEmail(t *testing.T) {
},
}
//nolint
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- h := &Headscale{
- cfg: tt.fields.cfg,
- db: tt.fields.db,
- dbString: tt.fields.dbString,
- dbType: tt.fields.dbType,
- dbDebug: tt.fields.dbDebug,
- publicKey: tt.fields.publicKey,
- privateKey: tt.fields.privateKey,
- aclPolicy: tt.fields.aclPolicy,
- aclRules: tt.fields.aclRules,
- lastStateChange: tt.fields.lastStateChange,
- oidcProvider: tt.fields.oidcProvider,
- oauth2Config: tt.fields.oauth2Config,
- oidcStateCache: tt.fields.oidcStateCache,
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ app := &Headscale{
+ cfg: test.fields.cfg,
+ db: test.fields.db,
+ dbString: test.fields.dbString,
+ dbType: test.fields.dbType,
+ dbDebug: test.fields.dbDebug,
+ publicKey: test.fields.publicKey,
+ privateKey: test.fields.privateKey,
+ aclPolicy: test.fields.aclPolicy,
+ aclRules: test.fields.aclRules,
+ lastStateChange: test.fields.lastStateChange,
+ oidcProvider: test.fields.oidcProvider,
+ oauth2Config: test.fields.oauth2Config,
+ oidcStateCache: test.fields.oidcStateCache,
}
- got, got1 := h.getNamespaceFromEmail(tt.args.email)
- if got != tt.want {
- t.Errorf("Headscale.getNamespaceFromEmail() got = %v, want %v", got, tt.want)
+ got, got1 := app.getNamespaceFromEmail(test.args.email)
+ if got != test.want {
+ t.Errorf(
+ "Headscale.getNamespaceFromEmail() got = %v, want %v",
+ got,
+ test.want,
+ )
}
- if got1 != tt.want1 {
- t.Errorf("Headscale.getNamespaceFromEmail() got1 = %v, want %v", got1, tt.want1)
+ if got1 != test.want1 {
+ t.Errorf(
+ "Headscale.getNamespaceFromEmail() got1 = %v, want %v",
+ got1,
+ test.want1,
+ )
}
})
}
diff --git a/poll.go b/poll.go
index 6a65280..9cf14e7 100644
--- a/poll.go
+++ b/poll.go
@@ -15,6 +15,11 @@ import (
"tailscale.com/types/wgkey"
)
+const (
+ keepAliveInterval = 60 * time.Second
+ updateCheckInterval = 10 * time.Second
+)
+
// PollNetMapHandler takes care of /machine/:id/map
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
@@ -24,20 +29,21 @@ import (
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
-func (h *Headscale) PollNetMapHandler(c *gin.Context) {
+func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
log.Trace().
Str("handler", "PollNetMap").
- Str("id", c.Param("id")).
+ Str("id", ctx.Param("id")).
Msg("PollNetMapHandler called")
- body, _ := io.ReadAll(c.Request.Body)
- mKeyStr := c.Param("id")
+ body, _ := io.ReadAll(ctx.Request.Body)
+ mKeyStr := ctx.Param("id")
mKey, err := wgkey.ParseHex(mKeyStr)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot parse client key")
- c.String(http.StatusBadRequest, "")
+ ctx.String(http.StatusBadRequest, "")
+
return
}
req := tailcfg.MapRequest{}
@@ -47,34 +53,36 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot decode message")
- c.String(http.StatusBadRequest, "")
+ ctx.String(http.StatusBadRequest, "")
+
return
}
- m, err := h.GetMachineByMachineKey(mKey.HexString())
+ machine, err := h.GetMachineByMachineKey(mKey.HexString())
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
- c.String(http.StatusUnauthorized, "")
+ ctx.String(http.StatusUnauthorized, "")
+
return
}
log.Error().
Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", mKey.HexString())
- c.String(http.StatusInternalServerError, "")
+ ctx.String(http.StatusInternalServerError, "")
}
log.Trace().
Str("handler", "PollNetMap").
- Str("id", c.Param("id")).
- Str("machine", m.Name).
+ Str("id", ctx.Param("id")).
+ Str("machine", machine.Name).
Msg("Found machine in database")
hostinfo, _ := json.Marshal(req.Hostinfo)
- m.Name = req.Hostinfo.Hostname
- m.HostInfo = datatypes.JSON(hostinfo)
- m.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
+ machine.Name = req.Hostinfo.Hostname
+ machine.HostInfo = datatypes.JSON(hostinfo)
+ machine.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
now := time.Now().UTC()
// From Tailscale client:
@@ -87,20 +95,21 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
// before their first real endpoint update.
if !req.ReadOnly {
endpoints, _ := json.Marshal(req.Endpoints)
- m.Endpoints = datatypes.JSON(endpoints)
- m.LastSeen = &now
+ machine.Endpoints = datatypes.JSON(endpoints)
+ machine.LastSeen = &now
}
- h.db.Save(&m)
+ h.db.Save(&machine)
- data, err := h.getMapResponse(mKey, req, m)
+ data, err := h.getMapResponse(mKey, req, machine)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
- Str("id", c.Param("id")).
- Str("machine", m.Name).
+ Str("id", ctx.Param("id")).
+ Str("machine", machine.Name).
Err(err).
Msg("Failed to get Map response")
- c.String(http.StatusInternalServerError, ":(")
+ ctx.String(http.StatusInternalServerError, ":(")
+
return
}
@@ -111,8 +120,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug().
Str("handler", "PollNetMap").
- Str("id", c.Param("id")).
- Str("machine", m.Name).
+ Str("id", ctx.Param("id")).
+ Str("machine", machine.Name).
Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream).
@@ -121,15 +130,16 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
if req.ReadOnly {
log.Info().
Str("handler", "PollNetMap").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Client is starting up. Probably interested in a DERP map")
- c.Data(200, "application/json; charset=utf-8", data)
+ ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
+
return
}
// There has been an update to _any_ of the nodes that the other nodes would
// need to know about
- h.setLastStateChangeToNow(m.Namespace.Name)
+ h.setLastStateChangeToNow(machine.Namespace.Name)
// The request is not ReadOnly, so we need to set up channels for updating
// peers via longpoll
@@ -137,8 +147,8 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
// Only create update channel if it has not been created
log.Trace().
Str("handler", "PollNetMap").
- Str("id", c.Param("id")).
- Str("machine", m.Name).
+ Str("id", ctx.Param("id")).
+ Str("machine", machine.Name).
Msg("Loading or creating update channel")
updateChan := make(chan struct{})
@@ -152,46 +162,59 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
if req.OmitPeers && !req.Stream {
log.Info().
Str("handler", "PollNetMap").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Client sent endpoint update and is ok with a response without peer list")
- c.Data(200, "application/json; charset=utf-8", data)
+ ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
// It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so.
- updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "endpoint-update").Inc()
+ updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "endpoint-update").
+ Inc()
go func() { updateChan <- struct{}{} }()
+
return
} else if req.OmitPeers && req.Stream {
log.Warn().
Str("handler", "PollNetMap").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Ignoring request, don't know how to handle it")
- c.String(http.StatusBadRequest, "")
+ ctx.String(http.StatusBadRequest, "")
+
return
}
log.Info().
Str("handler", "PollNetMap").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Client is ready to access the tailnet")
log.Info().
Str("handler", "PollNetMap").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Sending initial map")
go func() { pollDataChan <- data }()
log.Info().
Str("handler", "PollNetMap").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Notifying peers")
- updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "full-update").Inc()
+ updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "full-update").
+ Inc()
go func() { updateChan <- struct{}{} }()
- h.PollNetMapStream(c, m, req, mKey, pollDataChan, keepAliveChan, updateChan, cancelKeepAlive)
+ h.PollNetMapStream(
+ ctx,
+ machine,
+ req,
+ mKey,
+ pollDataChan,
+ keepAliveChan,
+ updateChan,
+ cancelKeepAlive,
+ )
log.Trace().
Str("handler", "PollNetMap").
- Str("id", c.Param("id")).
- Str("machine", m.Name).
+ Str("id", ctx.Param("id")).
+ Str("machine", machine.Name).
Msg("Finished stream, closing PollNetMap session")
}
@@ -199,165 +222,181 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
// stream logic, ensuring we communicate updates and data
// to the connected clients.
func (h *Headscale) PollNetMapStream(
- c *gin.Context,
- m *Machine,
- req tailcfg.MapRequest,
- mKey wgkey.Key,
+ ctx *gin.Context,
+ machine *Machine,
+ mapRequest tailcfg.MapRequest,
+ machineKey wgkey.Key,
pollDataChan chan []byte,
keepAliveChan chan []byte,
updateChan chan struct{},
cancelKeepAlive chan struct{},
) {
- go h.scheduledPollWorker(cancelKeepAlive, updateChan, keepAliveChan, mKey, req, m)
+ go h.scheduledPollWorker(
+ cancelKeepAlive,
+ updateChan,
+ keepAliveChan,
+ machineKey,
+ mapRequest,
+ machine,
+ )
- c.Stream(func(w io.Writer) bool {
+ ctx.Stream(func(writer io.Writer) bool {
log.Trace().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Waiting for data to stream...")
log.Trace().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
select {
case data := <-pollDataChan:
log.Trace().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Sending data received via pollData channel")
- _, err := w.Write(data)
+ _, err := writer.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "pollData").
Err(err).
Msg("Cannot write data")
+
return false
}
log.Trace().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Data from pollData channel written successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
- err = h.UpdateMachine(m)
+ err = h.UpdateMachine(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "pollData").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
- m.LastSeen = &now
+ machine.LastSeen = &now
- lastStateUpdate.WithLabelValues(m.Namespace.Name, m.Name).Set(float64(now.Unix()))
- m.LastSuccessfulUpdate = &now
+ lastStateUpdate.WithLabelValues(machine.Namespace.Name, machine.Name).
+ Set(float64(now.Unix()))
+ machine.LastSuccessfulUpdate = &now
- h.db.Save(&m)
+ h.db.Save(&machine)
log.Trace().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Machine entry in database updated successfully after sending pollData")
+
return true
case data := <-keepAliveChan:
log.Trace().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Sending keep alive message")
- _, err := w.Write(data)
+ _, err := writer.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot write keep alive message")
+
return false
}
log.Trace().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Keep alive sent successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
- err = h.UpdateMachine(m)
+ err = h.UpdateMachine(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
- m.LastSeen = &now
- h.db.Save(&m)
+ machine.LastSeen = &now
+ h.db.Save(&machine)
log.Trace().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Machine updated successfully after sending keep alive")
+
return true
case <-updateChan:
log.Trace().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "update").
Msg("Received a request for update")
- updateRequestsReceivedOnChannel.WithLabelValues(m.Name, m.Namespace.Name).Inc()
- if h.isOutdated(m) {
+ updateRequestsReceivedOnChannel.WithLabelValues(machine.Name, machine.Namespace.Name).
+ Inc()
+ if h.isOutdated(machine) {
log.Debug().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
- Time("last_successful_update", *m.LastSuccessfulUpdate).
- Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
- Msgf("There has been updates since the last successful update to %s", m.Name)
- data, err := h.getMapResponse(mKey, req, m)
+ Str("machine", machine.Name).
+ Time("last_successful_update", *machine.LastSuccessfulUpdate).
+ Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
+ Msgf("There has been updates since the last successful update to %s", machine.Name)
+ data, err := h.getMapResponse(machineKey, mapRequest, machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "update").
Err(err).
Msg("Could not get the map update")
}
- _, err = w.Write(data)
+ _, err = writer.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "update").
Err(err).
Msg("Could not write the map response")
- updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "failed").Inc()
+ updateRequestsSentToNode.WithLabelValues(machine.Name, machine.Namespace.Name, "failed").
+ Inc()
+
return false
}
log.Trace().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "update").
Msg("Updated Map has been sent")
- updateRequestsSentToNode.WithLabelValues(m.Name, m.Namespace.Name, "success").Inc()
+ updateRequestsSentToNode.WithLabelValues(machine.Name, machine.Namespace.Name, "success").
+ Inc()
// Keep track of the last successful update,
// we sometimes end in a state were the update
@@ -366,77 +405,79 @@ func (h *Headscale) PollNetMapStream(
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
- err = h.UpdateMachine(m)
+ err = h.UpdateMachine(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "update").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
- lastStateUpdate.WithLabelValues(m.Namespace.Name, m.Name).Set(float64(now.Unix()))
- m.LastSuccessfulUpdate = &now
+ lastStateUpdate.WithLabelValues(machine.Namespace.Name, machine.Name).
+ Set(float64(now.Unix()))
+ machine.LastSuccessfulUpdate = &now
- h.db.Save(&m)
+ h.db.Save(&machine)
} else {
log.Trace().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
- Time("last_successful_update", *m.LastSuccessfulUpdate).
- Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
- Msgf("%s is up to date", m.Name)
+ Str("machine", machine.Name).
+ Time("last_successful_update", *machine.LastSuccessfulUpdate).
+ Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
+ Msgf("%s is up to date", machine.Name)
}
+
return true
- case <-c.Request.Context().Done():
+ case <-ctx.Request.Context().Done():
log.Info().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("The client has closed the connection")
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
- err := h.UpdateMachine(m)
+ err := h.UpdateMachine(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "Done").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
- m.LastSeen = &now
- h.db.Save(&m)
+ machine.LastSeen = &now
+ h.db.Save(&machine)
log.Trace().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "Done").
Msg("Cancelling keepAlive channel")
cancelKeepAlive <- struct{}{}
log.Trace().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "Done").
Msg("Closing update channel")
- //h.closeUpdateChannel(m)
+ // h.closeUpdateChannel(m)
close(updateChan)
log.Trace().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "Done").
Msg("Closing pollData channel")
close(pollDataChan)
log.Trace().
Str("handler", "PollNetMapStream").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Str("channel", "Done").
Msg("Closing keepAliveChan channel")
close(keepAliveChan)
@@ -450,12 +491,12 @@ func (h *Headscale) scheduledPollWorker(
cancelChan <-chan struct{},
updateChan chan<- struct{},
keepAliveChan chan<- []byte,
- mKey wgkey.Key,
- req tailcfg.MapRequest,
- m *Machine,
+ machineKey wgkey.Key,
+ mapRequest tailcfg.MapRequest,
+ machine *Machine,
) {
- keepAliveTicker := time.NewTicker(60 * time.Second)
- updateCheckerTicker := time.NewTicker(10 * time.Second)
+ keepAliveTicker := time.NewTicker(keepAliveInterval)
+ updateCheckerTicker := time.NewTicker(updateCheckInterval)
for {
select {
@@ -463,27 +504,29 @@ func (h *Headscale) scheduledPollWorker(
return
case <-keepAliveTicker.C:
- data, err := h.getMapKeepAliveResponse(mKey, req, m)
+ data, err := h.getMapKeepAliveResponse(machineKey, mapRequest)
if err != nil {
log.Error().
Str("func", "keepAlive").
Err(err).
Msg("Error generating the keep alive msg")
+
return
}
log.Debug().
Str("func", "keepAlive").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Sending keepalive")
keepAliveChan <- data
case <-updateCheckerTicker.C:
log.Debug().
Str("func", "scheduledPollWorker").
- Str("machine", m.Name).
+ Str("machine", machine.Name).
Msg("Sending update request")
- updateRequestsFromNode.WithLabelValues(m.Name, m.Namespace.Name, "scheduled-update").Inc()
+ updateRequestsFromNode.WithLabelValues(machine.Name, machine.Namespace.Name, "scheduled-update").
+ Inc()
updateChan <- struct{}{}
}
}
diff --git a/preauth_keys.go b/preauth_keys.go
index 41b10e3..50bc474 100644
--- a/preauth_keys.go
+++ b/preauth_keys.go
@@ -7,19 +7,19 @@ import (
"strconv"
"time"
+ v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm"
-
- v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
)
const (
- errorAuthKeyNotFound = Error("AuthKey not found")
- errorAuthKeyExpired = Error("AuthKey expired")
+ errPreAuthKeyNotFound = Error("AuthKey not found")
+ errPreAuthKeyExpired = Error("AuthKey expired")
errSingleUseAuthKeyHasBeenUsed = Error("AuthKey has already been used")
+ errNamespaceMismatch = Error("namespace mismatch")
)
-// PreAuthKey describes a pre-authorization key usable in a particular namespace
+// PreAuthKey describes a pre-authorization key usable in a particular namespace.
type PreAuthKey struct {
ID uint64 `gorm:"primary_key"`
Key string
@@ -33,14 +33,14 @@ type PreAuthKey struct {
Expiration *time.Time
}
-// CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it
+// CreatePreAuthKey creates a new PreAuthKey in a namespace, and returns it.
func (h *Headscale) CreatePreAuthKey(
namespaceName string,
reusable bool,
ephemeral bool,
expiration *time.Time,
) (*PreAuthKey, error) {
- n, err := h.GetNamespace(namespaceName)
+ namespace, err := h.GetNamespace(namespaceName)
if err != nil {
return nil, err
}
@@ -51,35 +51,36 @@ func (h *Headscale) CreatePreAuthKey(
return nil, err
}
- k := PreAuthKey{
+ key := PreAuthKey{
Key: kstr,
- NamespaceID: n.ID,
- Namespace: *n,
+ NamespaceID: namespace.ID,
+ Namespace: *namespace,
Reusable: reusable,
Ephemeral: ephemeral,
CreatedAt: &now,
Expiration: expiration,
}
- h.db.Save(&k)
+ h.db.Save(&key)
- return &k, nil
+ return &key, nil
}
-// ListPreAuthKeys returns the list of PreAuthKeys for a namespace
+// ListPreAuthKeys returns the list of PreAuthKeys for a namespace.
func (h *Headscale) ListPreAuthKeys(namespaceName string) ([]PreAuthKey, error) {
- n, err := h.GetNamespace(namespaceName)
+ namespace, err := h.GetNamespace(namespaceName)
if err != nil {
return nil, err
}
keys := []PreAuthKey{}
- if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: n.ID}).Find(&keys).Error; err != nil {
+ if err := h.db.Preload("Namespace").Where(&PreAuthKey{NamespaceID: namespace.ID}).Find(&keys).Error; err != nil {
return nil, err
}
+
return keys, nil
}
-// GetPreAuthKey returns a PreAuthKey for a given key
+// GetPreAuthKey returns a PreAuthKey for a given key.
func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, error) {
pak, err := h.checkKeyValidity(key)
if err != nil {
@@ -87,7 +88,7 @@ func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, er
}
if pak.Namespace.Name != namespace {
- return nil, errors.New("Namespace mismatch")
+ return nil, errNamespaceMismatch
}
return pak, nil
@@ -95,32 +96,36 @@ func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, er
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
// does not exist.
-func (h *Headscale) DestroyPreAuthKey(pak *PreAuthKey) error {
- if result := h.db.Unscoped().Delete(&pak); result.Error != nil {
+func (h *Headscale) DestroyPreAuthKey(pak PreAuthKey) error {
+ if result := h.db.Unscoped().Delete(pak); result.Error != nil {
return result.Error
}
return nil
}
-// MarkExpirePreAuthKey marks a PreAuthKey as expired
+// MarkExpirePreAuthKey marks a PreAuthKey as expired.
func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error {
if err := h.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err
}
+
return nil
}
// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node
-// If returns no error and a PreAuthKey, it can be used
+// If returns no error and a PreAuthKey, it can be used.
func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
pak := PreAuthKey{}
- if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(result.Error, gorm.ErrRecordNotFound) {
- return nil, errorAuthKeyNotFound
+ if result := h.db.Preload("Namespace").First(&pak, "key = ?", k); errors.Is(
+ result.Error,
+ gorm.ErrRecordNotFound,
+ ) {
+ return nil, errPreAuthKeyNotFound
}
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
- return nil, errorAuthKeyExpired
+ return nil, errPreAuthKeyExpired
}
if pak.Reusable || pak.Ephemeral { // we don't need to check if has been used before
@@ -145,13 +150,14 @@ func (h *Headscale) generateKey() (string, error) {
if _, err := rand.Read(bytes); err != nil {
return "", err
}
+
return hex.EncodeToString(bytes), nil
}
func (key *PreAuthKey) toProto() *v1.PreAuthKey {
protoKey := v1.PreAuthKey{
Namespace: key.Namespace.Name,
- Id: strconv.FormatUint(key.ID, 10),
+ Id: strconv.FormatUint(key.ID, Base10),
Key: key.Key,
Ephemeral: key.Ephemeral,
Reusable: key.Reusable,
diff --git a/preauth_keys_test.go b/preauth_keys_test.go
index dceec00..fd0feb0 100644
--- a/preauth_keys_test.go
+++ b/preauth_keys_test.go
@@ -7,189 +7,189 @@ import (
)
func (*Suite) TestCreatePreAuthKey(c *check.C) {
- _, err := h.CreatePreAuthKey("bogus", true, false, nil)
+ _, err := app.CreatePreAuthKey("bogus", true, false, nil)
c.Assert(err, check.NotNil)
- n, err := h.CreateNamespace("test")
+ namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
- k, err := h.CreatePreAuthKey(n.Name, true, false, nil)
+ key, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
c.Assert(err, check.IsNil)
// Did we get a valid key?
- c.Assert(k.Key, check.NotNil)
- c.Assert(len(k.Key), check.Equals, 48)
+ c.Assert(key.Key, check.NotNil)
+ c.Assert(len(key.Key), check.Equals, 48)
// Make sure the Namespace association is populated
- c.Assert(k.Namespace.Name, check.Equals, n.Name)
+ c.Assert(key.Namespace.Name, check.Equals, namespace.Name)
- _, err = h.ListPreAuthKeys("bogus")
+ _, err = app.ListPreAuthKeys("bogus")
c.Assert(err, check.NotNil)
- keys, err := h.ListPreAuthKeys(n.Name)
+ keys, err := app.ListPreAuthKeys(namespace.Name)
c.Assert(err, check.IsNil)
c.Assert(len(keys), check.Equals, 1)
// Make sure the Namespace association is populated
- c.Assert((keys)[0].Namespace.Name, check.Equals, n.Name)
+ c.Assert((keys)[0].Namespace.Name, check.Equals, namespace.Name)
}
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
- n, err := h.CreateNamespace("test2")
+ namespace, err := app.CreateNamespace("test2")
c.Assert(err, check.IsNil)
now := time.Now()
- pak, err := h.CreatePreAuthKey(n.Name, true, false, &now)
+ pak, err := app.CreatePreAuthKey(namespace.Name, true, false, &now)
c.Assert(err, check.IsNil)
- p, err := h.checkKeyValidity(pak.Key)
- c.Assert(err, check.Equals, errorAuthKeyExpired)
- c.Assert(p, check.IsNil)
+ key, err := app.checkKeyValidity(pak.Key)
+ c.Assert(err, check.Equals, errPreAuthKeyExpired)
+ c.Assert(key, check.IsNil)
}
func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) {
- p, err := h.checkKeyValidity("potatoKey")
- c.Assert(err, check.Equals, errorAuthKeyNotFound)
- c.Assert(p, check.IsNil)
+ key, err := app.checkKeyValidity("potatoKey")
+ c.Assert(err, check.Equals, errPreAuthKeyNotFound)
+ c.Assert(key, check.IsNil)
}
func (*Suite) TestValidateKeyOk(c *check.C) {
- n, err := h.CreateNamespace("test3")
+ namespace, err := app.CreateNamespace("test3")
c.Assert(err, check.IsNil)
- pak, err := h.CreatePreAuthKey(n.Name, true, false, nil)
+ pak, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
c.Assert(err, check.IsNil)
- p, err := h.checkKeyValidity(pak.Key)
+ key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.IsNil)
- c.Assert(p.ID, check.Equals, pak.ID)
+ c.Assert(key.ID, check.Equals, pak.ID)
}
func (*Suite) TestAlreadyUsedKey(c *check.C) {
- n, err := h.CreateNamespace("test4")
+ namespace, err := app.CreateNamespace("test4")
c.Assert(err, check.IsNil)
- pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
+ pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
- m := Machine{
+ machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testest",
- NamespaceID: n.ID,
+ NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
}
- h.db.Save(&m)
+ app.db.Save(&machine)
- p, err := h.checkKeyValidity(pak.Key)
+ key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed)
- c.Assert(p, check.IsNil)
+ c.Assert(key, check.IsNil)
}
func (*Suite) TestReusableBeingUsedKey(c *check.C) {
- n, err := h.CreateNamespace("test5")
+ namespace, err := app.CreateNamespace("test5")
c.Assert(err, check.IsNil)
- pak, err := h.CreatePreAuthKey(n.Name, true, false, nil)
+ pak, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
c.Assert(err, check.IsNil)
- m := Machine{
+ machine := Machine{
ID: 1,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testest",
- NamespaceID: n.ID,
+ NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
}
- h.db.Save(&m)
+ app.db.Save(&machine)
- p, err := h.checkKeyValidity(pak.Key)
+ key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.IsNil)
- c.Assert(p.ID, check.Equals, pak.ID)
+ c.Assert(key.ID, check.Equals, pak.ID)
}
func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
- n, err := h.CreateNamespace("test6")
+ namespace, err := app.CreateNamespace("test6")
c.Assert(err, check.IsNil)
- pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
+ pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
- p, err := h.checkKeyValidity(pak.Key)
+ key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.IsNil)
- c.Assert(p.ID, check.Equals, pak.ID)
+ c.Assert(key.ID, check.Equals, pak.ID)
}
func (*Suite) TestEphemeralKey(c *check.C) {
- n, err := h.CreateNamespace("test7")
+ namespace, err := app.CreateNamespace("test7")
c.Assert(err, check.IsNil)
- pak, err := h.CreatePreAuthKey(n.Name, false, true, nil)
+ pak, err := app.CreatePreAuthKey(namespace.Name, false, true, nil)
c.Assert(err, check.IsNil)
now := time.Now()
- m := Machine{
+ machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testest",
- NamespaceID: n.ID,
+ NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
LastSeen: &now,
AuthKeyID: uint(pak.ID),
}
- h.db.Save(&m)
+ app.db.Save(&machine)
- _, err = h.checkKeyValidity(pak.Key)
+ _, err = app.checkKeyValidity(pak.Key)
// Ephemeral keys are by definition reusable
c.Assert(err, check.IsNil)
- _, err = h.GetMachine("test7", "testest")
+ _, err = app.GetMachine("test7", "testest")
c.Assert(err, check.IsNil)
- h.expireEphemeralNodesWorker()
+ app.expireEphemeralNodesWorker()
// The machine record should have been deleted
- _, err = h.GetMachine("test7", "testest")
+ _, err = app.GetMachine("test7", "testest")
c.Assert(err, check.NotNil)
}
func (*Suite) TestExpirePreauthKey(c *check.C) {
- n, err := h.CreateNamespace("test3")
+ namespace, err := app.CreateNamespace("test3")
c.Assert(err, check.IsNil)
- pak, err := h.CreatePreAuthKey(n.Name, true, false, nil)
+ pak, err := app.CreatePreAuthKey(namespace.Name, true, false, nil)
c.Assert(err, check.IsNil)
c.Assert(pak.Expiration, check.IsNil)
- err = h.ExpirePreAuthKey(pak)
+ err = app.ExpirePreAuthKey(pak)
c.Assert(err, check.IsNil)
c.Assert(pak.Expiration, check.NotNil)
- p, err := h.checkKeyValidity(pak.Key)
- c.Assert(err, check.Equals, errorAuthKeyExpired)
- c.Assert(p, check.IsNil)
+ key, err := app.checkKeyValidity(pak.Key)
+ c.Assert(err, check.Equals, errPreAuthKeyExpired)
+ c.Assert(key, check.IsNil)
}
func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
- n, err := h.CreateNamespace("test6")
+ namespace, err := app.CreateNamespace("test6")
c.Assert(err, check.IsNil)
- pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
+ pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
pak.Used = true
- h.db.Save(&pak)
+ app.db.Save(&pak)
- _, err = h.checkKeyValidity(pak.Key)
+ _, err = app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed)
}
diff --git a/proto/headscale/v1/headscale.proto b/proto/headscale/v1/headscale.proto
index 26fe2f9..e7a63fc 100644
--- a/proto/headscale/v1/headscale.proto
+++ b/proto/headscale/v1/headscale.proto
@@ -12,115 +12,115 @@ import "headscale/v1/routes.proto";
service HeadscaleService {
// --- Namespace start ---
- rpc GetNamespace(GetNamespaceRequest) returns(GetNamespaceResponse) {
- option(google.api.http) = {
- get : "/api/v1/namespace/{name}"
+ rpc GetNamespace(GetNamespaceRequest) returns (GetNamespaceResponse) {
+ option (google.api.http) = {
+ get: "/api/v1/namespace/{name}"
};
}
- rpc CreateNamespace(CreateNamespaceRequest) returns(CreateNamespaceResponse) {
- option(google.api.http) = {
- post : "/api/v1/namespace"
- body : "*"
+ rpc CreateNamespace(CreateNamespaceRequest) returns (CreateNamespaceResponse) {
+ option (google.api.http) = {
+ post: "/api/v1/namespace"
+ body: "*"
};
}
- rpc RenameNamespace(RenameNamespaceRequest) returns(RenameNamespaceResponse) {
- option(google.api.http) = {
- post : "/api/v1/namespace/{old_name}/rename/{new_name}"
+ rpc RenameNamespace(RenameNamespaceRequest) returns (RenameNamespaceResponse) {
+ option (google.api.http) = {
+ post: "/api/v1/namespace/{old_name}/rename/{new_name}"
};
}
- rpc DeleteNamespace(DeleteNamespaceRequest) returns(DeleteNamespaceResponse) {
- option(google.api.http) = {
- delete : "/api/v1/namespace/{name}"
+ rpc DeleteNamespace(DeleteNamespaceRequest) returns (DeleteNamespaceResponse) {
+ option (google.api.http) = {
+ delete: "/api/v1/namespace/{name}"
};
}
- rpc ListNamespaces(ListNamespacesRequest) returns(ListNamespacesResponse) {
- option(google.api.http) = {
- get : "/api/v1/namespace"
+ rpc ListNamespaces(ListNamespacesRequest) returns (ListNamespacesResponse) {
+ option (google.api.http) = {
+ get: "/api/v1/namespace"
};
}
// --- Namespace end ---
// --- PreAuthKeys start ---
- rpc CreatePreAuthKey(CreatePreAuthKeyRequest) returns(CreatePreAuthKeyResponse) {
- option(google.api.http) = {
- post : "/api/v1/preauthkey"
- body : "*"
+ rpc CreatePreAuthKey(CreatePreAuthKeyRequest) returns (CreatePreAuthKeyResponse) {
+ option (google.api.http) = {
+ post: "/api/v1/preauthkey"
+ body: "*"
};
}
- rpc ExpirePreAuthKey(ExpirePreAuthKeyRequest) returns(ExpirePreAuthKeyResponse) {
- option(google.api.http) = {
- post : "/api/v1/preauthkey/expire"
- body : "*"
+ rpc ExpirePreAuthKey(ExpirePreAuthKeyRequest) returns (ExpirePreAuthKeyResponse) {
+ option (google.api.http) = {
+ post: "/api/v1/preauthkey/expire"
+ body: "*"
};
}
- rpc ListPreAuthKeys(ListPreAuthKeysRequest) returns(ListPreAuthKeysResponse) {
- option(google.api.http) = {
- get : "/api/v1/preauthkey"
+ rpc ListPreAuthKeys(ListPreAuthKeysRequest) returns (ListPreAuthKeysResponse) {
+ option (google.api.http) = {
+ get: "/api/v1/preauthkey"
};
}
// --- PreAuthKeys end ---
// --- Machine start ---
- rpc DebugCreateMachine(DebugCreateMachineRequest) returns(DebugCreateMachineResponse) {
- option(google.api.http) = {
- post : "/api/v1/debug/machine"
- body : "*"
+ rpc DebugCreateMachine(DebugCreateMachineRequest) returns (DebugCreateMachineResponse) {
+ option (google.api.http) = {
+ post: "/api/v1/debug/machine"
+ body: "*"
};
}
- rpc GetMachine(GetMachineRequest) returns(GetMachineResponse) {
- option(google.api.http) = {
- get : "/api/v1/machine/{machine_id}"
+ rpc GetMachine(GetMachineRequest) returns (GetMachineResponse) {
+ option (google.api.http) = {
+ get: "/api/v1/machine/{machine_id}"
};
}
- rpc RegisterMachine(RegisterMachineRequest) returns(RegisterMachineResponse) {
- option(google.api.http) = {
- post : "/api/v1/machine/register"
+ rpc RegisterMachine(RegisterMachineRequest) returns (RegisterMachineResponse) {
+ option (google.api.http) = {
+ post: "/api/v1/machine/register"
};
}
- rpc DeleteMachine(DeleteMachineRequest) returns(DeleteMachineResponse) {
- option(google.api.http) = {
- delete : "/api/v1/machine/{machine_id}"
+ rpc DeleteMachine(DeleteMachineRequest) returns (DeleteMachineResponse) {
+ option (google.api.http) = {
+ delete: "/api/v1/machine/{machine_id}"
};
}
- rpc ListMachines(ListMachinesRequest) returns(ListMachinesResponse) {
- option(google.api.http) = {
- get : "/api/v1/machine"
+ rpc ListMachines(ListMachinesRequest) returns (ListMachinesResponse) {
+ option (google.api.http) = {
+ get: "/api/v1/machine"
};
}
- rpc ShareMachine(ShareMachineRequest) returns(ShareMachineResponse) {
- option(google.api.http) = {
- post : "/api/v1/machine/{machine_id}/share/{namespace}"
+ rpc ShareMachine(ShareMachineRequest) returns (ShareMachineResponse) {
+ option (google.api.http) = {
+ post: "/api/v1/machine/{machine_id}/share/{namespace}"
};
}
- rpc UnshareMachine(UnshareMachineRequest) returns(UnshareMachineResponse) {
- option(google.api.http) = {
- post : "/api/v1/machine/{machine_id}/unshare/{namespace}"
+ rpc UnshareMachine(UnshareMachineRequest) returns (UnshareMachineResponse) {
+ option (google.api.http) = {
+ post: "/api/v1/machine/{machine_id}/unshare/{namespace}"
};
}
// --- Machine end ---
// --- Route start ---
- rpc GetMachineRoute(GetMachineRouteRequest) returns(GetMachineRouteResponse) {
- option(google.api.http) = {
- get : "/api/v1/machine/{machine_id}/routes"
+ rpc GetMachineRoute(GetMachineRouteRequest) returns (GetMachineRouteResponse) {
+ option (google.api.http) = {
+ get: "/api/v1/machine/{machine_id}/routes"
};
}
- rpc EnableMachineRoutes(EnableMachineRoutesRequest) returns(EnableMachineRoutesResponse) {
- option(google.api.http) = {
- post : "/api/v1/machine/{machine_id}/routes"
+ rpc EnableMachineRoutes(EnableMachineRoutesRequest) returns (EnableMachineRoutesResponse) {
+ option (google.api.http) = {
+ post: "/api/v1/machine/{machine_id}/routes"
};
}
// --- Route end ---
diff --git a/routes.go b/routes.go
index f07b709..448095a 100644
--- a/routes.go
+++ b/routes.go
@@ -2,38 +2,48 @@ package headscale
import (
"encoding/json"
- "fmt"
"gorm.io/datatypes"
"inet.af/netaddr"
)
+const (
+ errRouteIsNotAvailable = Error("route is not available")
+)
+
// Deprecated: use machine function instead
// GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by
-// namespace and node name)
-func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
- m, err := h.GetMachine(namespace, nodeName)
+// namespace and node name).
+func (h *Headscale) GetAdvertisedNodeRoutes(
+ namespace string,
+ nodeName string,
+) (*[]netaddr.IPPrefix, error) {
+ machine, err := h.GetMachine(namespace, nodeName)
if err != nil {
return nil, err
}
- hostInfo, err := m.GetHostInfo()
+ hostInfo, err := machine.GetHostInfo()
if err != nil {
return nil, err
}
+
return &hostInfo.RoutableIPs, nil
}
// Deprecated: use machine function instead
// GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by
-// namespace and node name)
-func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]netaddr.IPPrefix, error) {
- m, err := h.GetMachine(namespace, nodeName)
+// namespace and node name).
+func (h *Headscale) GetEnabledNodeRoutes(
+ namespace string,
+ nodeName string,
+) ([]netaddr.IPPrefix, error) {
+ machine, err := h.GetMachine(namespace, nodeName)
if err != nil {
return nil, err
}
- data, err := m.EnabledRoutes.MarshalJSON()
+ data, err := machine.EnabledRoutes.MarshalJSON()
if err != nil {
return nil, err
}
@@ -57,8 +67,12 @@ func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]n
}
// Deprecated: use machine function instead
-// IsNodeRouteEnabled checks if a certain route has been enabled
-func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeStr string) bool {
+// IsNodeRouteEnabled checks if a certain route has been enabled.
+func (h *Headscale) IsNodeRouteEnabled(
+ namespace string,
+ nodeName string,
+ routeStr string,
+) bool {
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return false
@@ -74,14 +88,19 @@ func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeS
return true
}
}
+
return false
}
// Deprecated: use EnableRoute in machine.go
// EnableNodeRoute enables a subnet route advertised by a node (identified by
-// namespace and node name)
-func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error {
- m, err := h.GetMachine(namespace, nodeName)
+// namespace and node name).
+func (h *Headscale) EnableNodeRoute(
+ namespace string,
+ nodeName string,
+ routeStr string,
+) error {
+ machine, err := h.GetMachine(namespace, nodeName)
if err != nil {
return err
}
@@ -113,7 +132,7 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr
}
if !available {
- return fmt.Errorf("route (%s) is not available on node %s", nodeName, routeStr)
+ return errRouteIsNotAvailable
}
routes, err := json.Marshal(enabledRoutes)
@@ -121,10 +140,10 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr
return err
}
- m.EnabledRoutes = datatypes.JSON(routes)
- h.db.Save(&m)
+ machine.EnabledRoutes = datatypes.JSON(routes)
+ h.db.Save(&machine)
- err = h.RequestMapUpdates(m.NamespaceID)
+ err = h.RequestMapUpdates(machine.NamespaceID)
if err != nil {
return err
}
diff --git a/routes_test.go b/routes_test.go
index ad16d21..18cb0ce 100644
--- a/routes_test.go
+++ b/routes_test.go
@@ -10,57 +10,60 @@ import (
)
func (s *Suite) TestGetRoutes(c *check.C) {
- n, err := h.CreateNamespace("test")
+ namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
- pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
+ pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
- _, err = h.GetMachine("test", "test_get_route_machine")
+ _, err = app.GetMachine("test", "test_get_route_machine")
c.Assert(err, check.NotNil)
route, err := netaddr.ParseIPPrefix("10.0.0.0/24")
c.Assert(err, check.IsNil)
- hi := tailcfg.Hostinfo{
+ hostInfo := tailcfg.Hostinfo{
RoutableIPs: []netaddr.IPPrefix{route},
}
- hostinfo, err := json.Marshal(hi)
+ hostinfo, err := json.Marshal(hostInfo)
c.Assert(err, check.IsNil)
- m := Machine{
+ machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "test_get_route_machine",
- NamespaceID: n.ID,
+ NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
HostInfo: datatypes.JSON(hostinfo),
}
- h.db.Save(&m)
+ app.db.Save(&machine)
- r, err := h.GetAdvertisedNodeRoutes("test", "test_get_route_machine")
+ advertisedRoutes, err := app.GetAdvertisedNodeRoutes(
+ "test",
+ "test_get_route_machine",
+ )
c.Assert(err, check.IsNil)
- c.Assert(len(*r), check.Equals, 1)
+ c.Assert(len(*advertisedRoutes), check.Equals, 1)
- err = h.EnableNodeRoute("test", "test_get_route_machine", "192.168.0.0/24")
+ err = app.EnableNodeRoute("test", "test_get_route_machine", "192.168.0.0/24")
c.Assert(err, check.NotNil)
- err = h.EnableNodeRoute("test", "test_get_route_machine", "10.0.0.0/24")
+ err = app.EnableNodeRoute("test", "test_get_route_machine", "10.0.0.0/24")
c.Assert(err, check.IsNil)
}
func (s *Suite) TestGetEnableRoutes(c *check.C) {
- n, err := h.CreateNamespace("test")
+ namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
- pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
+ pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
- _, err = h.GetMachine("test", "test_enable_route_machine")
+ _, err = app.GetMachine("test", "test_enable_route_machine")
c.Assert(err, check.NotNil)
route, err := netaddr.ParseIPPrefix(
@@ -73,56 +76,68 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
)
c.Assert(err, check.IsNil)
- hi := tailcfg.Hostinfo{
+ hostInfo := tailcfg.Hostinfo{
RoutableIPs: []netaddr.IPPrefix{route, route2},
}
- hostinfo, err := json.Marshal(hi)
+ hostinfo, err := json.Marshal(hostInfo)
c.Assert(err, check.IsNil)
- m := Machine{
+ machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "test_enable_route_machine",
- NamespaceID: n.ID,
+ NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
HostInfo: datatypes.JSON(hostinfo),
}
- h.db.Save(&m)
+ app.db.Save(&machine)
- availableRoutes, err := h.GetAdvertisedNodeRoutes("test", "test_enable_route_machine")
+ availableRoutes, err := app.GetAdvertisedNodeRoutes(
+ "test",
+ "test_enable_route_machine",
+ )
c.Assert(err, check.IsNil)
c.Assert(len(*availableRoutes), check.Equals, 2)
- enabledRoutes, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
+ noEnabledRoutes, err := app.GetEnabledNodeRoutes(
+ "test",
+ "test_enable_route_machine",
+ )
c.Assert(err, check.IsNil)
- c.Assert(len(enabledRoutes), check.Equals, 0)
+ c.Assert(len(noEnabledRoutes), check.Equals, 0)
- err = h.EnableNodeRoute("test", "test_enable_route_machine", "192.168.0.0/24")
+ err = app.EnableNodeRoute("test", "test_enable_route_machine", "192.168.0.0/24")
c.Assert(err, check.NotNil)
- err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
+ err = app.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
c.Assert(err, check.IsNil)
- enabledRoutes1, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
+ enabledRoutes, err := app.GetEnabledNodeRoutes("test", "test_enable_route_machine")
c.Assert(err, check.IsNil)
- c.Assert(len(enabledRoutes1), check.Equals, 1)
+ c.Assert(len(enabledRoutes), check.Equals, 1)
// Adding it twice will just let it pass through
- err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
+ err = app.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
c.Assert(err, check.IsNil)
- enabledRoutes2, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
+ enableRoutesAfterDoubleApply, err := app.GetEnabledNodeRoutes(
+ "test",
+ "test_enable_route_machine",
+ )
c.Assert(err, check.IsNil)
- c.Assert(len(enabledRoutes2), check.Equals, 1)
+ c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
- err = h.EnableNodeRoute("test", "test_enable_route_machine", "150.0.10.0/25")
+ err = app.EnableNodeRoute("test", "test_enable_route_machine", "150.0.10.0/25")
c.Assert(err, check.IsNil)
- enabledRoutes3, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
+ enabledRoutesWithAdditionalRoute, err := app.GetEnabledNodeRoutes(
+ "test",
+ "test_enable_route_machine",
+ )
c.Assert(err, check.IsNil)
- c.Assert(len(enabledRoutes3), check.Equals, 2)
+ c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
}
diff --git a/sharing.go b/sharing.go
index 5f6a8f4..be1689d 100644
--- a/sharing.go
+++ b/sharing.go
@@ -2,11 +2,13 @@ package headscale
import "gorm.io/gorm"
-const errorSameNamespace = Error("Destination namespace same as origin")
-const errorMachineAlreadyShared = Error("Node already shared to this namespace")
-const errorMachineNotShared = Error("Machine not shared to this namespace")
+const (
+ errSameNamespace = Error("Destination namespace same as origin")
+ errMachineAlreadyShared = Error("Node already shared to this namespace")
+ errMachineNotShared = Error("Machine not shared to this namespace")
+)
-// SharedMachine is a join table to support sharing nodes between namespaces
+// SharedMachine is a join table to support sharing nodes between namespaces.
type SharedMachine struct {
gorm.Model
MachineID uint64
@@ -15,49 +17,57 @@ type SharedMachine struct {
Namespace Namespace
}
-// AddSharedMachineToNamespace adds a machine as a shared node to a namespace
-func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error {
- if m.NamespaceID == ns.ID {
- return errorSameNamespace
+// AddSharedMachineToNamespace adds a machine as a shared node to a namespace.
+func (h *Headscale) AddSharedMachineToNamespace(
+ machine *Machine,
+ namespace *Namespace,
+) error {
+ if machine.NamespaceID == namespace.ID {
+ return errSameNamespace
}
sharedMachines := []SharedMachine{}
- if err := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Find(&sharedMachines).Error; err != nil {
+ if err := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).Find(&sharedMachines).Error; err != nil {
return err
}
if len(sharedMachines) > 0 {
- return errorMachineAlreadyShared
+ return errMachineAlreadyShared
}
sharedMachine := SharedMachine{
- MachineID: m.ID,
- Machine: *m,
- NamespaceID: ns.ID,
- Namespace: *ns,
+ MachineID: machine.ID,
+ Machine: *machine,
+ NamespaceID: namespace.ID,
+ Namespace: *namespace,
}
h.db.Save(&sharedMachine)
return nil
}
-// RemoveSharedMachineFromNamespace removes a shared machine from a namespace
-func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace) error {
- if m.NamespaceID == ns.ID {
+// RemoveSharedMachineFromNamespace removes a shared machine from a namespace.
+func (h *Headscale) RemoveSharedMachineFromNamespace(
+ machine *Machine,
+ namespace *Namespace,
+) error {
+ if machine.NamespaceID == namespace.ID {
// Can't unshare from primary namespace
- return errorMachineNotShared
+ return errMachineNotShared
}
sharedMachine := SharedMachine{}
- result := h.db.Where("machine_id = ? AND namespace_id = ?", m.ID, ns.ID).Unscoped().Delete(&sharedMachine)
+ result := h.db.Where("machine_id = ? AND namespace_id = ?", machine.ID, namespace.ID).
+ Unscoped().
+ Delete(&sharedMachine)
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
- return errorMachineNotShared
+ return errMachineNotShared
}
- err := h.RequestMapUpdates(ns.ID)
+ err := h.RequestMapUpdates(namespace.ID)
if err != nil {
return err
}
@@ -65,10 +75,10 @@ func (h *Headscale) RemoveSharedMachineFromNamespace(m *Machine, ns *Namespace)
return nil
}
-// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces
-func (h *Headscale) RemoveSharedMachineFromAllNamespaces(m *Machine) error {
+// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces.
+func (h *Headscale) RemoveSharedMachineFromAllNamespaces(machine *Machine) error {
sharedMachine := SharedMachine{}
- if result := h.db.Where("machine_id = ?", m.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
+ if result := h.db.Where("machine_id = ?", machine.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
return result.Error
}
diff --git a/sharing_test.go b/sharing_test.go
index 4d9e409..7ec1b0e 100644
--- a/sharing_test.go
+++ b/sharing_test.go
@@ -4,45 +4,48 @@ import (
"gopkg.in/check.v1"
)
-func CreateNodeNamespace(c *check.C, namespace, node, key, IP string) (*Namespace, *Machine) {
- n1, err := h.CreateNamespace(namespace)
+func CreateNodeNamespace(
+ c *check.C,
+ namespaceName, node, key, ip string,
+) (*Namespace, *Machine) {
+ namespace, err := app.CreateNamespace(namespaceName)
c.Assert(err, check.IsNil)
- pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
+ pak1, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
- _, err = h.GetMachine(n1.Name, node)
+ _, err = app.GetMachine(namespace.Name, node)
c.Assert(err, check.NotNil)
- m1 := &Machine{
+ machine := &Machine{
ID: 0,
MachineKey: key,
NodeKey: key,
DiscoKey: key,
Name: node,
- NamespaceID: n1.ID,
+ NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
- IPAddress: IP,
+ IPAddress: ip,
AuthKeyID: uint(pak1.ID),
}
- h.db.Save(m1)
+ app.db.Save(machine)
- _, err = h.GetMachine(n1.Name, m1.Name)
+ _, err = app.GetMachine(namespace.Name, machine.Name)
c.Assert(err, check.IsNil)
- return n1, m1
+ return namespace, machine
}
func (s *Suite) TestBasicSharedNodesInNamespace(c *check.C) {
- n1, m1 := CreateNodeNamespace(
+ namespace1, machine1 := CreateNodeNamespace(
c,
"shared1",
"test_get_shared_nodes_1",
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
"100.64.0.1",
)
- _, m2 := CreateNodeNamespace(
+ _, machine2 := CreateNodeNamespace(
c,
"shared2",
"test_get_shared_nodes_2",
@@ -50,21 +53,21 @@ func (s *Suite) TestBasicSharedNodesInNamespace(c *check.C) {
"100.64.0.2",
)
- p1s, err := h.getPeers(m1)
+ peersOfMachine1BeforeShared, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
- c.Assert(len(p1s), check.Equals, 0)
+ c.Assert(len(peersOfMachine1BeforeShared), check.Equals, 0)
- err = h.AddSharedMachineToNamespace(m2, n1)
+ err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.IsNil)
- p1sAfter, err := h.getPeers(m1)
+ peersOfMachine1AfterShared, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
- c.Assert(len(p1sAfter), check.Equals, 1)
- c.Assert(p1sAfter[0].ID, check.Equals, m2.ID)
+ c.Assert(len(peersOfMachine1AfterShared), check.Equals, 1)
+ c.Assert(peersOfMachine1AfterShared[0].ID, check.Equals, machine2.ID)
}
func (s *Suite) TestSameNamespace(c *check.C) {
- n1, m1 := CreateNodeNamespace(
+ namespace1, machine1 := CreateNodeNamespace(
c,
"shared1",
"test_get_shared_nodes_1",
@@ -72,23 +75,23 @@ func (s *Suite) TestSameNamespace(c *check.C) {
"100.64.0.1",
)
- p1s, err := h.getPeers(m1)
+ peersOfMachine1BeforeShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
- c.Assert(len(p1s), check.Equals, 0)
+ c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
- err = h.AddSharedMachineToNamespace(m1, n1)
- c.Assert(err, check.Equals, errorSameNamespace)
+ err = app.AddSharedMachineToNamespace(machine1, namespace1)
+ c.Assert(err, check.Equals, errSameNamespace)
}
func (s *Suite) TestUnshare(c *check.C) {
- n1, m1 := CreateNodeNamespace(
+ namespace1, machine1 := CreateNodeNamespace(
c,
"shared1",
"test_unshare_1",
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
"100.64.0.1",
)
- _, m2 := CreateNodeNamespace(
+ _, machine2 := CreateNodeNamespace(
c,
"shared2",
"test_unshare_2",
@@ -96,40 +99,40 @@ func (s *Suite) TestUnshare(c *check.C) {
"100.64.0.2",
)
- p1s, err := h.getPeers(m1)
+ peersOfMachine1BeforeShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
- c.Assert(len(p1s), check.Equals, 0)
+ c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
- err = h.AddSharedMachineToNamespace(m2, n1)
+ err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.IsNil)
- p1s, err = h.getShared(m1)
+ peersOfMachine1BeforeShare, err = app.getShared(machine1)
c.Assert(err, check.IsNil)
- c.Assert(len(p1s), check.Equals, 1)
+ c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 1)
- err = h.RemoveSharedMachineFromNamespace(m2, n1)
+ err = app.RemoveSharedMachineFromNamespace(machine2, namespace1)
c.Assert(err, check.IsNil)
- p1s, err = h.getShared(m1)
+ peersOfMachine1BeforeShare, err = app.getShared(machine1)
c.Assert(err, check.IsNil)
- c.Assert(len(p1s), check.Equals, 0)
+ c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
- err = h.RemoveSharedMachineFromNamespace(m2, n1)
- c.Assert(err, check.Equals, errorMachineNotShared)
+ err = app.RemoveSharedMachineFromNamespace(machine2, namespace1)
+ c.Assert(err, check.Equals, errMachineNotShared)
- err = h.RemoveSharedMachineFromNamespace(m1, n1)
- c.Assert(err, check.Equals, errorMachineNotShared)
+ err = app.RemoveSharedMachineFromNamespace(machine1, namespace1)
+ c.Assert(err, check.Equals, errMachineNotShared)
}
func (s *Suite) TestAlreadyShared(c *check.C) {
- n1, m1 := CreateNodeNamespace(
+ namespace1, machine1 := CreateNodeNamespace(
c,
"shared1",
"test_get_shared_nodes_1",
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
"100.64.0.1",
)
- _, m2 := CreateNodeNamespace(
+ _, machine2 := CreateNodeNamespace(
c,
"shared2",
"test_get_shared_nodes_2",
@@ -137,25 +140,25 @@ func (s *Suite) TestAlreadyShared(c *check.C) {
"100.64.0.2",
)
- p1s, err := h.getPeers(m1)
+ peersOfMachine1BeforeShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
- c.Assert(len(p1s), check.Equals, 0)
+ c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
- err = h.AddSharedMachineToNamespace(m2, n1)
+ err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.IsNil)
- err = h.AddSharedMachineToNamespace(m2, n1)
- c.Assert(err, check.Equals, errorMachineAlreadyShared)
+ err = app.AddSharedMachineToNamespace(machine2, namespace1)
+ c.Assert(err, check.Equals, errMachineAlreadyShared)
}
func (s *Suite) TestDoNotIncludeRoutesOnShared(c *check.C) {
- n1, m1 := CreateNodeNamespace(
+ namespace1, machine1 := CreateNodeNamespace(
c,
"shared1",
"test_get_shared_nodes_1",
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
"100.64.0.1",
)
- _, m2 := CreateNodeNamespace(
+ _, machine2 := CreateNodeNamespace(
c,
"shared2",
"test_get_shared_nodes_2",
@@ -163,35 +166,35 @@ func (s *Suite) TestDoNotIncludeRoutesOnShared(c *check.C) {
"100.64.0.2",
)
- p1s, err := h.getPeers(m1)
+ peersOfMachine1BeforeShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
- c.Assert(len(p1s), check.Equals, 0)
+ c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 0)
- err = h.AddSharedMachineToNamespace(m2, n1)
+ err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.IsNil)
- p1sAfter, err := h.getPeers(m1)
+ peersOfMachine1AfterShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
- c.Assert(len(p1sAfter), check.Equals, 1)
- c.Assert(p1sAfter[0].Name, check.Equals, "test_get_shared_nodes_2")
+ c.Assert(len(peersOfMachine1AfterShare), check.Equals, 1)
+ c.Assert(peersOfMachine1AfterShare[0].Name, check.Equals, "test_get_shared_nodes_2")
}
func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
- n1, m1 := CreateNodeNamespace(
+ namespace1, machine1 := CreateNodeNamespace(
c,
"shared1",
"test_get_shared_nodes_1",
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
"100.64.0.1",
)
- _, m2 := CreateNodeNamespace(
+ _, machine2 := CreateNodeNamespace(
c,
"shared2",
"test_get_shared_nodes_2",
"dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
"100.64.0.2",
)
- _, m3 := CreateNodeNamespace(
+ _, machine3 := CreateNodeNamespace(
c,
"shared3",
"test_get_shared_nodes_3",
@@ -199,76 +202,80 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
"100.64.0.3",
)
- pak4, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
+ pak4, err := app.CreatePreAuthKey(namespace1.Name, false, false, nil)
c.Assert(err, check.IsNil)
- m4 := &Machine{
+ machine4 := &Machine{
ID: 4,
MachineKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
NodeKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
DiscoKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
Name: "test_get_shared_nodes_4",
- NamespaceID: n1.ID,
+ NamespaceID: namespace1.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.4",
AuthKeyID: uint(pak4.ID),
}
- h.db.Save(m4)
+ app.db.Save(machine4)
- _, err = h.GetMachine(n1.Name, m4.Name)
+ _, err = app.GetMachine(namespace1.Name, machine4.Name)
c.Assert(err, check.IsNil)
- p1s, err := h.getPeers(m1)
+ peersOfMachine1BeforeShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
- c.Assert(len(p1s), check.Equals, 1) // node1 can see node4
- c.Assert(p1s[0].Name, check.Equals, m4.Name)
+ c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 1) // node1 can see node4
+ c.Assert(peersOfMachine1BeforeShare[0].Name, check.Equals, machine4.Name)
- err = h.AddSharedMachineToNamespace(m2, n1)
+ err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.IsNil)
- p1sAfter, err := h.getPeers(m1)
- c.Assert(err, check.IsNil)
- c.Assert(len(p1sAfter), check.Equals, 2) // node1 can see node2 (shared) and node4 (same namespace)
- c.Assert(p1sAfter[0].Name, check.Equals, m2.Name)
- c.Assert(p1sAfter[1].Name, check.Equals, m4.Name)
-
- node1shared, err := h.getShared(m1)
- c.Assert(err, check.IsNil)
- c.Assert(len(node1shared), check.Equals, 1) // node1 can see node2 as shared
- c.Assert(node1shared[0].Name, check.Equals, m2.Name)
-
- pAlone, err := h.getPeers(m3)
- c.Assert(err, check.IsNil)
- c.Assert(len(pAlone), check.Equals, 0) // node3 is alone
-
- pSharedTo, err := h.getPeers(m2)
+ peersOfMachine1AfterShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
c.Assert(
- len(pSharedTo),
+ len(peersOfMachine1AfterShare),
+ check.Equals,
+ 2,
+ ) // node1 can see node2 (shared) and node4 (same namespace)
+ c.Assert(peersOfMachine1AfterShare[0].Name, check.Equals, machine2.Name)
+ c.Assert(peersOfMachine1AfterShare[1].Name, check.Equals, machine4.Name)
+
+ sharedOfMachine1, err := app.getShared(machine1)
+ c.Assert(err, check.IsNil)
+ c.Assert(len(sharedOfMachine1), check.Equals, 1) // node1 can see node2 as shared
+ c.Assert(sharedOfMachine1[0].Name, check.Equals, machine2.Name)
+
+ peersOfMachine3, err := app.getPeers(machine3)
+ c.Assert(err, check.IsNil)
+ c.Assert(len(peersOfMachine3), check.Equals, 0) // node3 is alone
+
+ peersOfMachine2, err := app.getPeers(machine2)
+ c.Assert(err, check.IsNil)
+ c.Assert(
+ len(peersOfMachine2),
check.Equals,
2,
) // node2 should see node1 (sharedTo) and node4 (sharedTo), as is shared in namespace1
- c.Assert(pSharedTo[0].Name, check.Equals, m1.Name)
- c.Assert(pSharedTo[1].Name, check.Equals, m4.Name)
+ c.Assert(peersOfMachine2[0].Name, check.Equals, machine1.Name)
+ c.Assert(peersOfMachine2[1].Name, check.Equals, machine4.Name)
}
func (s *Suite) TestDeleteSharedMachine(c *check.C) {
- n1, m1 := CreateNodeNamespace(
+ namespace1, machine1 := CreateNodeNamespace(
c,
"shared1",
"test_get_shared_nodes_1",
"686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
"100.64.0.1",
)
- _, m2 := CreateNodeNamespace(
+ _, machine2 := CreateNodeNamespace(
c,
"shared2",
"test_get_shared_nodes_2",
"dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
"100.64.0.2",
)
- _, m3 := CreateNodeNamespace(
+ _, machine3 := CreateNodeNamespace(
c,
"shared3",
"test_get_shared_nodes_3",
@@ -276,56 +283,58 @@ func (s *Suite) TestDeleteSharedMachine(c *check.C) {
"100.64.0.3",
)
- pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
+ pak4n1, err := app.CreatePreAuthKey(namespace1.Name, false, false, nil)
c.Assert(err, check.IsNil)
- m4 := &Machine{
+ machine4 := &Machine{
ID: 4,
MachineKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
NodeKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
DiscoKey: "4c3e07c3ecd40e9c945bb6797557c451850691c0409740578325e17009dd298f",
Name: "test_get_shared_nodes_4",
- NamespaceID: n1.ID,
+ NamespaceID: namespace1.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.4",
AuthKeyID: uint(pak4n1.ID),
}
- h.db.Save(m4)
+ app.db.Save(machine4)
- _, err = h.GetMachine(n1.Name, m4.Name)
+ _, err = app.GetMachine(namespace1.Name, machine4.Name)
c.Assert(err, check.IsNil)
- p1s, err := h.getPeers(m1)
+ peersOfMachine1BeforeShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
- c.Assert(len(p1s), check.Equals, 1) // nodes 1 and 4
- c.Assert(p1s[0].Name, check.Equals, m4.Name)
+ c.Assert(len(peersOfMachine1BeforeShare), check.Equals, 1) // nodes 1 and 4
+ c.Assert(peersOfMachine1BeforeShare[0].Name, check.Equals, machine4.Name)
- err = h.AddSharedMachineToNamespace(m2, n1)
+ err = app.AddSharedMachineToNamespace(machine2, namespace1)
c.Assert(err, check.IsNil)
- p1sAfter, err := h.getPeers(m1)
+ peersOfMachine1AfterShare, err := app.getPeers(machine1)
c.Assert(err, check.IsNil)
- c.Assert(len(p1sAfter), check.Equals, 2) // nodes 1, 2, 4
- c.Assert(p1sAfter[0].Name, check.Equals, m2.Name)
- c.Assert(p1sAfter[1].Name, check.Equals, m4.Name)
+ c.Assert(len(peersOfMachine1AfterShare), check.Equals, 2) // nodes 1, 2, 4
+ c.Assert(peersOfMachine1AfterShare[0].Name, check.Equals, machine2.Name)
+ c.Assert(peersOfMachine1AfterShare[1].Name, check.Equals, machine4.Name)
- node1shared, err := h.getShared(m1)
+ sharedOfMachine1, err := app.getShared(machine1)
c.Assert(err, check.IsNil)
- c.Assert(len(node1shared), check.Equals, 1) // nodes 1, 2, 4
- c.Assert(node1shared[0].Name, check.Equals, m2.Name)
+ c.Assert(len(sharedOfMachine1), check.Equals, 1) // nodes 1, 2, 4
+ c.Assert(sharedOfMachine1[0].Name, check.Equals, machine2.Name)
- pAlone, err := h.getPeers(m3)
+ peersOfMachine3, err := app.getPeers(machine3)
c.Assert(err, check.IsNil)
- c.Assert(len(pAlone), check.Equals, 0) // node 3 is alone
+ c.Assert(len(peersOfMachine3), check.Equals, 0) // node 3 is alone
- sharedMachines, err := h.ListSharedMachinesInNamespace(n1.Name)
+ sharedMachinesInNamespace1, err := app.ListSharedMachinesInNamespace(
+ namespace1.Name,
+ )
c.Assert(err, check.IsNil)
- c.Assert(len(sharedMachines), check.Equals, 1)
+ c.Assert(len(sharedMachinesInNamespace1), check.Equals, 1)
- err = h.DeleteMachine(m2)
+ err = app.DeleteMachine(machine2)
c.Assert(err, check.IsNil)
- sharedMachines, err = h.ListSharedMachinesInNamespace(n1.Name)
+ sharedMachinesInNamespace1, err = app.ListSharedMachinesInNamespace(namespace1.Name)
c.Assert(err, check.IsNil)
- c.Assert(len(sharedMachines), check.Equals, 0)
+ c.Assert(len(sharedMachinesInNamespace1), check.Equals, 0)
}
diff --git a/swagger.go b/swagger.go
index 17f5769..9e62d39 100644
--- a/swagger.go
+++ b/swagger.go
@@ -6,16 +6,15 @@ import (
"net/http"
"text/template"
- "github.com/rs/zerolog/log"
-
"github.com/gin-gonic/gin"
+ "github.com/rs/zerolog/log"
)
//go:embed gen/openapiv2/headscale/v1/headscale.swagger.json
var apiV1JSON []byte
-func SwaggerUI(c *gin.Context) {
- t := template.Must(template.New("swagger").Parse(`
+func SwaggerUI(ctx *gin.Context) {
+ swaggerTemplate := template.Must(template.New("swagger").Parse(`
@@ -48,18 +47,23 @@ func SwaggerUI(c *gin.Context) {
`))
var payload bytes.Buffer
- if err := t.Execute(&payload, struct{}{}); err != nil {
+ if err := swaggerTemplate.Execute(&payload, struct{}{}); err != nil {
log.Error().
Caller().
Err(err).
Msg("Could not render Swagger")
- c.Data(http.StatusInternalServerError, "text/html; charset=utf-8", []byte("Could not render Swagger"))
+ ctx.Data(
+ http.StatusInternalServerError,
+ "text/html; charset=utf-8",
+ []byte("Could not render Swagger"),
+ )
+
return
}
- c.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
+ ctx.Data(http.StatusOK, "text/html; charset=utf-8", payload.Bytes())
}
-func SwaggerAPIv1(c *gin.Context) {
- c.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON)
+func SwaggerAPIv1(ctx *gin.Context) {
+ ctx.Data(http.StatusOK, "application/json; charset=utf-8", apiV1JSON)
}
diff --git a/utils.go b/utils.go
index ad0a72d..9f7849e 100644
--- a/utils.go
+++ b/utils.go
@@ -20,31 +20,48 @@ import (
"tailscale.com/types/wgkey"
)
+const (
+ errCannotDecryptReponse = Error("cannot decrypt response")
+ errResponseMissingNonce = Error("response missing nonce")
+ errCouldNotAllocateIP = Error("could not find any suitable IP")
+)
+
// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
type Error string
func (e Error) Error() string { return string(e) }
-func decode(msg []byte, v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) error {
+func decode(
+ msg []byte,
+ v interface{},
+ pubKey *wgkey.Key,
+ privKey *wgkey.Private,
+) error {
return decodeMsg(msg, v, pubKey, privKey)
}
-func decodeMsg(msg []byte, v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) error {
+func decodeMsg(
+ msg []byte,
+ output interface{},
+ pubKey *wgkey.Key,
+ privKey *wgkey.Private,
+) error {
decrypted, err := decryptMsg(msg, pubKey, privKey)
if err != nil {
return err
}
// fmt.Println(string(decrypted))
- if err := json.Unmarshal(decrypted, v); err != nil {
- return fmt.Errorf("response: %v", err)
+ if err := json.Unmarshal(decrypted, output); err != nil {
+ return err
}
+
return nil
}
func decryptMsg(msg []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
var nonce [24]byte
if len(msg) < len(nonce)+1 {
- return nil, fmt.Errorf("response missing nonce, len=%d", len(msg))
+ return nil, errResponseMissingNonce
}
copy(nonce[:], msg)
msg = msg[len(nonce):]
@@ -52,8 +69,9 @@ func decryptMsg(msg []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte,
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
decrypted, ok := box.Open(nil, msg, &nonce, pub, pri)
if !ok {
- return nil, fmt.Errorf("cannot decrypt response")
+ return nil, errCannotDecryptReponse
}
+
return decrypted, nil
}
@@ -66,13 +84,18 @@ func encode(v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, e
return encodeMsg(b, pubKey, privKey)
}
-func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
+func encodeMsg(
+ payload []byte,
+ pubKey *wgkey.Key,
+ privKey *wgkey.Private,
+) ([]byte, error) {
var nonce [24]byte
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
panic(err)
}
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
- msg := box.Seal(nonce[:], b, &nonce, pub, pri)
+ msg := box.Seal(nonce[:], payload, &nonce, pub, pri)
+
return msg, nil
}
@@ -89,7 +112,7 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
for {
if !ipPrefix.Contains(ip) {
- return nil, fmt.Errorf("could not find any suitable IP in %s", ipPrefix)
+ return nil, errCouldNotAllocateIP
}
// Some OS (including Linux) does not like when IPs ends with 0 or 255, which
@@ -98,13 +121,14 @@ func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
ipRaw := ip.As4()
if ipRaw[3] == 0 || ipRaw[3] == 255 {
ip = ip.Next()
+
continue
}
if ip.IsZero() &&
ip.IsLoopback() {
-
ip = ip.Next()
+
continue
}
@@ -125,7 +149,7 @@ func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
if addr != "" {
ip, err := netaddr.ParseIP(addr)
if err != nil {
- return nil, fmt.Errorf("failed to parse ip from database, %w", err)
+ return nil, fmt.Errorf("failed to parse ip from database: %w", err)
}
ips[index] = ip
@@ -156,11 +180,16 @@ func tailNodesToString(nodes []*tailcfg.Node) string {
}
func tailMapResponseToString(resp tailcfg.MapResponse) string {
- return fmt.Sprintf("{ Node: %s, Peers: %s }", resp.Node.Name, tailNodesToString(resp.Peers))
+ return fmt.Sprintf(
+ "{ Node: %s, Peers: %s }",
+ resp.Node.Name,
+ tailNodesToString(resp.Peers),
+ )
}
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
var d net.Dialer
+
return d.DialContext(ctx, "unix", addr)
}
@@ -174,7 +203,7 @@ func ipPrefixToString(prefixes []netaddr.IPPrefix) []string {
return result
}
-func stringToIpPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
+func stringToIPPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
result := make([]netaddr.IPPrefix, len(prefixes))
for index, prefixStr := range prefixes {
@@ -189,7 +218,7 @@ func stringToIpPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
return result, nil
}
-func containsIpPrefix(prefixes []netaddr.IPPrefix, prefix netaddr.IPPrefix) bool {
+func containsIPPrefix(prefixes []netaddr.IPPrefix, prefix netaddr.IPPrefix) bool {
for _, p := range prefixes {
if prefix == p {
return true
diff --git a/utils_test.go b/utils_test.go
index f50cd11..dcda613 100644
--- a/utils_test.go
+++ b/utils_test.go
@@ -6,7 +6,7 @@ import (
)
func (s *Suite) TestGetAvailableIp(c *check.C) {
- ip, err := h.getAvailableIP()
+ ip, err := app.getAvailableIP()
c.Assert(err, check.IsNil)
@@ -16,33 +16,33 @@ func (s *Suite) TestGetAvailableIp(c *check.C) {
}
func (s *Suite) TestGetUsedIps(c *check.C) {
- ip, err := h.getAvailableIP()
+ ip, err := app.getAvailableIP()
c.Assert(err, check.IsNil)
- n, err := h.CreateNamespace("test_ip")
+ namespace, err := app.CreateNamespace("test_ip")
c.Assert(err, check.IsNil)
- pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
+ pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
- _, err = h.GetMachine("test", "testmachine")
+ _, err = app.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
- m := Machine{
+ machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
- NamespaceID: n.ID,
+ NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
IPAddress: ip.String(),
}
- h.db.Save(&m)
+ app.db.Save(&machine)
- ips, err := h.getUsedIPs()
+ ips, err := app.getUsedIPs()
c.Assert(err, check.IsNil)
@@ -50,42 +50,42 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
c.Assert(ips[0], check.Equals, expected)
- m1, err := h.GetMachineByID(0)
+ machine1, err := app.GetMachineByID(0)
c.Assert(err, check.IsNil)
- c.Assert(m1.IPAddress, check.Equals, expected.String())
+ c.Assert(machine1.IPAddress, check.Equals, expected.String())
}
func (s *Suite) TestGetMultiIp(c *check.C) {
- n, err := h.CreateNamespace("test-ip-multi")
+ namespace, err := app.CreateNamespace("test-ip-multi")
c.Assert(err, check.IsNil)
- for i := 1; i <= 350; i++ {
- ip, err := h.getAvailableIP()
+ for index := 1; index <= 350; index++ {
+ ip, err := app.getAvailableIP()
c.Assert(err, check.IsNil)
- pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
+ pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
- _, err = h.GetMachine("test", "testmachine")
+ _, err = app.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
- m := Machine{
- ID: uint64(i),
+ machine := Machine{
+ ID: uint64(index),
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
- NamespaceID: n.ID,
+ NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
IPAddress: ip.String(),
}
- h.db.Save(&m)
+ app.db.Save(&machine)
}
- ips, err := h.getUsedIPs()
+ ips, err := app.getUsedIPs()
c.Assert(err, check.IsNil)
@@ -96,59 +96,67 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.47"))
// Check that we can read back the IPs
- m1, err := h.GetMachineByID(1)
+ machine1, err := app.GetMachineByID(1)
c.Assert(err, check.IsNil)
- c.Assert(m1.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.1").String())
+ c.Assert(
+ machine1.IPAddress,
+ check.Equals,
+ netaddr.MustParseIP("10.27.0.1").String(),
+ )
- m50, err := h.GetMachineByID(50)
+ machine50, err := app.GetMachineByID(50)
c.Assert(err, check.IsNil)
- c.Assert(m50.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.50").String())
+ c.Assert(
+ machine50.IPAddress,
+ check.Equals,
+ netaddr.MustParseIP("10.27.0.50").String(),
+ )
expectedNextIP := netaddr.MustParseIP("10.27.1.97")
- nextIP, err := h.getAvailableIP()
+ nextIP, err := app.getAvailableIP()
c.Assert(err, check.IsNil)
c.Assert(nextIP.String(), check.Equals, expectedNextIP.String())
// If we call get Available again, we should receive
// the same IP, as it has not been reserved.
- nextIP2, err := h.getAvailableIP()
+ nextIP2, err := app.getAvailableIP()
c.Assert(err, check.IsNil)
c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String())
}
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
- ip, err := h.getAvailableIP()
+ ip, err := app.getAvailableIP()
c.Assert(err, check.IsNil)
expected := netaddr.MustParseIP("10.27.0.1")
c.Assert(ip.String(), check.Equals, expected.String())
- n, err := h.CreateNamespace("test_ip")
+ namespace, err := app.CreateNamespace("test_ip")
c.Assert(err, check.IsNil)
- pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
+ pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
- _, err = h.GetMachine("test", "testmachine")
+ _, err = app.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
- m := Machine{
+ machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
- NamespaceID: n.ID,
+ NamespaceID: namespace.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
}
- h.db.Save(&m)
+ app.db.Save(&machine)
- ip2, err := h.getAvailableIP()
+ ip2, err := app.getAvailableIP()
c.Assert(err, check.IsNil)
c.Assert(ip2.String(), check.Equals, expected.String())